From 60d2c8c86c14cd7d16aa8de71ca0974ab474dea8 Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Fri, 21 Feb 2025 12:44:32 -0800 Subject: [PATCH] improvements --- backend/model_server/custom_models.py | 57 +++++++++++++++++----- backend/onyx/configs/model_configs.py | 17 +++++++ backend/onyx/indexing/indexing_pipeline.py | 9 +++- 3 files changed, 68 insertions(+), 15 deletions(-) diff --git a/backend/model_server/custom_models.py b/backend/model_server/custom_models.py index 2b614f81f5..aa2a1cc2ef 100644 --- a/backend/model_server/custom_models.py +++ b/backend/model_server/custom_models.py @@ -1,3 +1,4 @@ +import numpy as np import torch import torch.nn.functional as F from fastapi import APIRouter @@ -12,6 +13,9 @@ from model_server.constants import MODEL_WARM_UP_STRING from model_server.onyx_torch_model import ConnectorClassifier from model_server.onyx_torch_model import HybridClassifier from model_server.utils import simple_log_function_time +from onyx.configs.model_configs import INDEXING_CONTENT_CLASSIFICATION_MAX +from onyx.configs.model_configs import INDEXING_CONTENT_CLASSIFICATION_MIN +from onyx.configs.model_configs import INDEXING_CONTENT_CLASSIFICATION_TEMPERATURE from onyx.utils.logger import setup_logger from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG @@ -37,7 +41,9 @@ _INTENT_MODEL: HybridClassifier | None = None _CONTENT_MODEL: SetFitModel | None = None -_TEMPERATURE_CONTENT_CLASSIFICATION = 4.0 +_CONTENT_MODEL_PROMPT_PREFIX: str = ( + "Does this sentence have very specific information: " # spec to model version! +) def get_connector_classifier_tokenizer() -> AutoTokenizer: @@ -247,22 +253,47 @@ def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]: def run_content_classification_inference( text_inputs: list[str], ) -> list[tuple[int, float]]: - get_local_content_model() + def _prob_to_score(prob: float) -> float: + if prob < 0.25: + raw_score = 0.0 + elif prob < 0.75: + raw_score = (prob - 0.25) / 0.5 + else: + raw_score = 1.0 + return ( + INDEXING_CONTENT_CLASSIFICATION_MIN + + ( + INDEXING_CONTENT_CLASSIFICATION_MAX + - INDEXING_CONTENT_CLASSIFICATION_MIN + ) + * raw_score + ) - # output_classes = list([x.numpy() for x in content_model(text_inputs)]) - # base_output_probabilities = list([x[1].numpy() for x in content_model.predict_proba(text_inputs)]) - # logits = [np.log(p/(1-p)) for p in base_output_probabilities] - # scaled_logits = [l/_TEMPERATURE_CONTENT_CLASSIFICATION for l in logits] - # output_probabilities_with_temp = [np.exp(scaled_logit)/(1 + np.exp(scaled_logit)) for scaled_logit in scaled_logits] + content_model = get_local_content_model() + + output_classes = list([x.numpy() for x in content_model(text_inputs)]) + base_output_probabilities = list( + [x[1].numpy() for x in content_model.predict_proba(text_inputs)] + ) + logits = [np.log(p / (1 - p)) for p in base_output_probabilities] + scaled_logits = [ + logit / INDEXING_CONTENT_CLASSIFICATION_TEMPERATURE for logit in logits + ] + output_probabilities_with_temp = [ + np.exp(scaled_logit) / (1 + np.exp(scaled_logit)) + for scaled_logit in scaled_logits + ] + + output_scores = [ + _prob_to_score(p_temp) for p_temp in output_probabilities_with_temp + ] output_classes = [1] * len(text_inputs) - output_probabilities_with_temp = [0.9] * len(text_inputs) + output_scores = [0.9] * len(text_inputs) return [ - (predicted_label, predicted_probability) - for predicted_label, predicted_probability in zip( - output_classes, output_probabilities_with_temp - ) + (predicted_label, output_score) + for predicted_label, output_score in zip(output_classes, output_scores) ] @@ -417,6 +448,6 @@ async def process_content_classification_request( content_classification_requests: list[str], ) -> list[tuple[int, float]]: content_classification_result = run_content_classification_inference( - content_classification_requests + [_CONTENT_MODEL_PROMPT_PREFIX + req for req in content_classification_requests] ) return content_classification_result diff --git a/backend/onyx/configs/model_configs.py b/backend/onyx/configs/model_configs.py index 0c85661d6b..b68c609923 100644 --- a/backend/onyx/configs/model_configs.py +++ b/backend/onyx/configs/model_configs.py @@ -132,3 +132,20 @@ if _LITELLM_EXTRA_BODY_RAW: LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW) except Exception: pass + +# Whether and how to lower scores for short chunks w/o relevant context +# Evaluated via custom ML model + +USE_CONTENT_CLASSIFICATION = ( + os.environ.get("USE_CONTENT_CLASSIFICATION") or "true" +).lower() == "true" + +INDEXING_CONTENT_CLASSIFICATION_MIN = float( + os.environ.get("INDEXING_CONTENT_CLASSIFICATION_MIN") or 0.7 +) +INDEXING_CONTENT_CLASSIFICATION_MAX = float( + os.environ.get("INDEXING_CONTENT_CLASSIFICATION_MAX") or 1.0 +) +INDEXING_CONTENT_CLASSIFICATION_TEMPERATURE = float( + os.environ.get("INDEXING_CONTENT_CLASSIFICATION_TEMPERATURE") or 4.0 +) diff --git a/backend/onyx/indexing/indexing_pipeline.py b/backend/onyx/indexing/indexing_pipeline.py index b9c4150e3a..3223419ddc 100644 --- a/backend/onyx/indexing/indexing_pipeline.py +++ b/backend/onyx/indexing/indexing_pipeline.py @@ -10,6 +10,7 @@ from onyx.access.access import get_access_for_documents from onyx.access.models import DocumentAccess from onyx.configs.app_configs import MAX_DOCUMENT_CHARS from onyx.configs.constants import DEFAULT_BOOST +from onyx.configs.model_configs import USE_CONTENT_CLASSIFICATION from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled from onyx.connectors.cross_connector_utils.miscellaneous_utils import ( get_experts_stores_representations, @@ -603,8 +604,12 @@ def index_doc_batch( chunks_with_embeddings_scores, chunk_content_scores, chunk_content_classification_failures, - ) = _get_aggregated_boost_factor( - chunks_with_embeddings, content_classification_model + ) = ( + _get_aggregated_boost_factor( + chunks_with_embeddings, content_classification_model + ) + if USE_CONTENT_CLASSIFICATION + else (chunks_with_embeddings, [1.0] * len(chunks_with_embeddings), []) ) updatable_ids = [doc.id for doc in ctx.updatable_docs]