Mirror of https://github.com/danswer-ai/danswer.git

Commit: improvements
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 import torch.nn.functional as F
 from fastapi import APIRouter
@@ -12,6 +13,9 @@ from model_server.constants import MODEL_WARM_UP_STRING
 from model_server.onyx_torch_model import ConnectorClassifier
 from model_server.onyx_torch_model import HybridClassifier
 from model_server.utils import simple_log_function_time
+from onyx.configs.model_configs import INDEXING_CONTENT_CLASSIFICATION_MAX
+from onyx.configs.model_configs import INDEXING_CONTENT_CLASSIFICATION_MIN
+from onyx.configs.model_configs import INDEXING_CONTENT_CLASSIFICATION_TEMPERATURE
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
 from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
@@ -37,7 +41,9 @@ _INTENT_MODEL: HybridClassifier | None = None
 
 _CONTENT_MODEL: SetFitModel | None = None
 
-_TEMPERATURE_CONTENT_CLASSIFICATION = 4.0
+_CONTENT_MODEL_PROMPT_PREFIX: str = (
+    "Does this sentence have very specific information: "  # spec to model version!
+)
 
 
 def get_connector_classifier_tokenizer() -> AutoTokenizer:
@@ -247,22 +253,47 @@ def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
 def run_content_classification_inference(
     text_inputs: list[str],
 ) -> list[tuple[int, float]]:
-    get_local_content_model()
+    def _prob_to_score(prob: float) -> float:
+        if prob < 0.25:
+            raw_score = 0.0
+        elif prob < 0.75:
+            raw_score = (prob - 0.25) / 0.5
+        else:
+            raw_score = 1.0
+        return (
+            INDEXING_CONTENT_CLASSIFICATION_MIN
+            + (
+                INDEXING_CONTENT_CLASSIFICATION_MAX
+                - INDEXING_CONTENT_CLASSIFICATION_MIN
+            )
+            * raw_score
+        )
 
     # output_classes = list([x.numpy() for x in content_model(text_inputs)])
     # base_output_probabilities = list([x[1].numpy() for x in content_model.predict_proba(text_inputs)])
     # logits = [np.log(p/(1-p)) for p in base_output_probabilities]
     # scaled_logits = [l/_TEMPERATURE_CONTENT_CLASSIFICATION for l in logits]
     # output_probabilities_with_temp = [np.exp(scaled_logit)/(1 + np.exp(scaled_logit)) for scaled_logit in scaled_logits]
+    content_model = get_local_content_model()
 
-    output_classes = [1] * len(text_inputs)
-    output_probabilities_with_temp = [0.9] * len(text_inputs)
-    output_scores = [0.9] * len(text_inputs)
+    output_classes = list([x.numpy() for x in content_model(text_inputs)])
+    base_output_probabilities = list(
+        [x[1].numpy() for x in content_model.predict_proba(text_inputs)]
+    )
+    logits = [np.log(p / (1 - p)) for p in base_output_probabilities]
+    scaled_logits = [
+        logit / INDEXING_CONTENT_CLASSIFICATION_TEMPERATURE for logit in logits
+    ]
+    output_probabilities_with_temp = [
+        np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
+        for scaled_logit in scaled_logits
+    ]
+
+    output_scores = [
+        _prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
+    ]
 
     return [
-        (predicted_label, predicted_probability)
-        for predicted_label, predicted_probability in zip(
-            output_classes, output_probabilities_with_temp
-        )
+        (predicted_label, output_score)
+        for predicted_label, output_score in zip(output_classes, output_scores)
     ]
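Note: the new active path mirrors the previously commented-out sketch. Each SetFit probability is converted to a logit, divided by INDEXING_CONTENT_CLASSIFICATION_TEMPERATURE (which softens confident predictions toward 0.5), mapped back through a sigmoid, and then passed through _prob_to_score, which is flat below 0.25, ramps linearly up to 0.75, and rescales into [INDEXING_CONTENT_CLASSIFICATION_MIN, INDEXING_CONTENT_CLASSIFICATION_MAX]. A self-contained sketch of the same arithmetic, using the default values from the config hunk later in this diff (illustrative only, not the module's code):

    import numpy as np

    def temper(p: float, temperature: float = 4.0) -> float:
        # logit -> divide by temperature -> sigmoid; pulls p toward 0.5
        logit = np.log(p / (1 - p))
        return float(np.exp(logit / temperature) / (1 + np.exp(logit / temperature)))

    def prob_to_score(prob: float, lo: float = 0.7, hi: float = 1.0) -> float:
        # flat below 0.25, linear ramp up to 0.75, flat above; rescaled to [lo, hi]
        raw = 0.0 if prob < 0.25 else 1.0 if prob >= 0.75 else (prob - 0.25) / 0.5
        return lo + (hi - lo) * raw

    p = temper(0.9)          # ~0.63 after temperature softening
    print(prob_to_score(p))  # ~0.93, the per-chunk boost factor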
@@ -417,6 +448,6 @@ async def process_content_classification_request(
     content_classification_requests: list[str],
 ) -> list[tuple[int, float]]:
     content_classification_result = run_content_classification_inference(
-        content_classification_requests
+        [_CONTENT_MODEL_PROMPT_PREFIX + req for req in content_classification_requests]
     )
     return content_classification_result
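Note: with this change every incoming text gets _CONTENT_MODEL_PROMPT_PREFIX prepended before inference, matching the prompt the classifier model apparently expects. An illustrative example (the request body is made up):

    _CONTENT_MODEL_PROMPT_PREFIX = "Does this sentence have very specific information: "

    texts = ["Revenue grew 12% in Q3"]  # hypothetical request body
    prefixed = [_CONTENT_MODEL_PROMPT_PREFIX + req for req in texts]
    # prefixed[0] == "Does this sentence have very specific information: Revenue grew 12% in Q3"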
@@ -132,3 +132,20 @@ if _LITELLM_EXTRA_BODY_RAW:
         LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
     except Exception:
         pass
+
+# Whether and how to lower scores for short chunks w/o relevant context
+# Evaluated via custom ML model
+
+USE_CONTENT_CLASSIFICATION = (
+    os.environ.get("USE_CONTENT_CLASSIFICATION") or "true"
+).lower() == "true"
+
+INDEXING_CONTENT_CLASSIFICATION_MIN = float(
+    os.environ.get("INDEXING_CONTENT_CLASSIFICATION_MIN") or 0.7
+)
+INDEXING_CONTENT_CLASSIFICATION_MAX = float(
+    os.environ.get("INDEXING_CONTENT_CLASSIFICATION_MAX") or 1.0
+)
+INDEXING_CONTENT_CLASSIFICATION_TEMPERATURE = float(
+    os.environ.get("INDEXING_CONTENT_CLASSIFICATION_TEMPERATURE") or 4.0
+)
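Note: the "value or default" pattern used here treats an env var that is set but empty the same as an unset one, unlike os.environ.get(key, default), which would return the empty string. A quick standalone demonstration of that behavior:

    import os

    # empty string is falsy, so the default still applies
    os.environ["INDEXING_CONTENT_CLASSIFICATION_MIN"] = ""
    print(float(os.environ.get("INDEXING_CONTENT_CLASSIFICATION_MIN") or 0.7))  # 0.7

    # any value other than "true" (case-insensitive) disables the feature
    os.environ["USE_CONTENT_CLASSIFICATION"] = "False"
    print((os.environ.get("USE_CONTENT_CLASSIFICATION") or "true").lower() == "true")  # False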
@@ -10,6 +10,7 @@ from onyx.access.access import get_access_for_documents
 from onyx.access.models import DocumentAccess
 from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
 from onyx.configs.constants import DEFAULT_BOOST
+from onyx.configs.model_configs import USE_CONTENT_CLASSIFICATION
 from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
 from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
     get_experts_stores_representations,
@@ -603,9 +604,13 @@ def index_doc_batch(
         chunks_with_embeddings_scores,
         chunk_content_scores,
         chunk_content_classification_failures,
-    ) = _get_aggregated_boost_factor(
-        chunks_with_embeddings, content_classification_model
-    )
+    ) = (
+        _get_aggregated_boost_factor(
+            chunks_with_embeddings, content_classification_model
+        )
+        if USE_CONTENT_CLASSIFICATION
+        else (chunks_with_embeddings, [1.0] * len(chunks_with_embeddings), [])
+    )
 
     updatable_ids = [doc.id for doc in ctx.updatable_docs]
 
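Note: when USE_CONTENT_CLASSIFICATION is false, the pipeline now skips the classifier call entirely and substitutes a neutral result with the same three-tuple shape: the untouched chunks, a boost factor of 1.0 per chunk, and an empty failure list, so downstream code is unchanged either way. A toy sketch of that fallback shape (stand-in values, not the pipeline's real chunk types):

    chunks_with_embeddings = ["chunk_a", "chunk_b"]  # stand-ins for real chunk objects

    # shape of the disabled-path result from the hunk above
    fallback = (chunks_with_embeddings, [1.0] * len(chunks_with_embeddings), [])
    chunks, chunk_content_scores, failures = fallback
    assert chunk_content_scores == [1.0, 1.0] and failures == []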