improvements

This commit is contained in:
joachim-danswer
2025-02-21 12:44:32 -08:00
parent 324b8e42a5
commit 60d2c8c86c
3 changed files with 68 additions and 15 deletions

View File

@ -1,3 +1,4 @@
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import APIRouter
@ -12,6 +13,9 @@ from model_server.constants import MODEL_WARM_UP_STRING
from model_server.onyx_torch_model import ConnectorClassifier
from model_server.onyx_torch_model import HybridClassifier
from model_server.utils import simple_log_function_time
from onyx.configs.model_configs import INDEXING_CONTENT_CLASSIFICATION_MAX
from onyx.configs.model_configs import INDEXING_CONTENT_CLASSIFICATION_MIN
from onyx.configs.model_configs import INDEXING_CONTENT_CLASSIFICATION_TEMPERATURE
from onyx.utils.logger import setup_logger
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
@ -37,7 +41,9 @@ _INTENT_MODEL: HybridClassifier | None = None
_CONTENT_MODEL: SetFitModel | None = None
_TEMPERATURE_CONTENT_CLASSIFICATION = 4.0
_CONTENT_MODEL_PROMPT_PREFIX: str = (
"Does this sentence have very specific information: " # spec to model version!
)
def get_connector_classifier_tokenizer() -> AutoTokenizer:
@ -247,22 +253,47 @@ def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
def run_content_classification_inference(
text_inputs: list[str],
) -> list[tuple[int, float]]:
get_local_content_model()
def _prob_to_score(prob: float) -> float:
if prob < 0.25:
raw_score = 0.0
elif prob < 0.75:
raw_score = (prob - 0.25) / 0.5
else:
raw_score = 1.0
return (
INDEXING_CONTENT_CLASSIFICATION_MIN
+ (
INDEXING_CONTENT_CLASSIFICATION_MAX
- INDEXING_CONTENT_CLASSIFICATION_MIN
)
* raw_score
)
# output_classes = list([x.numpy() for x in content_model(text_inputs)])
# base_output_probabilities = list([x[1].numpy() for x in content_model.predict_proba(text_inputs)])
# logits = [np.log(p/(1-p)) for p in base_output_probabilities]
# scaled_logits = [l/_TEMPERATURE_CONTENT_CLASSIFICATION for l in logits]
# output_probabilities_with_temp = [np.exp(scaled_logit)/(1 + np.exp(scaled_logit)) for scaled_logit in scaled_logits]
content_model = get_local_content_model()
output_classes = list([x.numpy() for x in content_model(text_inputs)])
base_output_probabilities = list(
[x[1].numpy() for x in content_model.predict_proba(text_inputs)]
)
logits = [np.log(p / (1 - p)) for p in base_output_probabilities]
scaled_logits = [
logit / INDEXING_CONTENT_CLASSIFICATION_TEMPERATURE for logit in logits
]
output_probabilities_with_temp = [
np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
for scaled_logit in scaled_logits
]
output_scores = [
_prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
]
output_classes = [1] * len(text_inputs)
output_probabilities_with_temp = [0.9] * len(text_inputs)
output_scores = [0.9] * len(text_inputs)
return [
(predicted_label, predicted_probability)
for predicted_label, predicted_probability in zip(
output_classes, output_probabilities_with_temp
)
(predicted_label, output_score)
for predicted_label, output_score in zip(output_classes, output_scores)
]
@ -417,6 +448,6 @@ async def process_content_classification_request(
content_classification_requests: list[str],
) -> list[tuple[int, float]]:
content_classification_result = run_content_classification_inference(
content_classification_requests
[_CONTENT_MODEL_PROMPT_PREFIX + req for req in content_classification_requests]
)
return content_classification_result

View File

@ -132,3 +132,20 @@ if _LITELLM_EXTRA_BODY_RAW:
LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
except Exception:
pass
# Whether and how to lower scores for short chunks w/o relevant context
# Evaluated via custom ML model
USE_CONTENT_CLASSIFICATION = (
os.environ.get("USE_CONTENT_CLASSIFICATION") or "true"
).lower() == "true"
INDEXING_CONTENT_CLASSIFICATION_MIN = float(
os.environ.get("INDEXING_CONTENT_CLASSIFICATION_MIN") or 0.7
)
INDEXING_CONTENT_CLASSIFICATION_MAX = float(
os.environ.get("INDEXING_CONTENT_CLASSIFICATION_MAX") or 1.0
)
INDEXING_CONTENT_CLASSIFICATION_TEMPERATURE = float(
os.environ.get("INDEXING_CONTENT_CLASSIFICATION_TEMPERATURE") or 4.0
)

View File

@ -10,6 +10,7 @@ from onyx.access.access import get_access_for_documents
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.model_configs import USE_CONTENT_CLASSIFICATION
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
@ -603,9 +604,13 @@ def index_doc_batch(
chunks_with_embeddings_scores,
chunk_content_scores,
chunk_content_classification_failures,
) = _get_aggregated_boost_factor(
) = (
_get_aggregated_boost_factor(
chunks_with_embeddings, content_classification_model
)
if USE_CONTENT_CLASSIFICATION
else (chunks_with_embeddings, [1.0] * len(chunks_with_embeddings), [])
)
updatable_ids = [doc.id for doc in ctx.updatable_docs]