Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-06-30 01:30:45 +02:00)
Commit: improvements
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 import torch.nn.functional as F
 from fastapi import APIRouter
@@ -12,6 +13,9 @@ from model_server.constants import MODEL_WARM_UP_STRING
 from model_server.onyx_torch_model import ConnectorClassifier
 from model_server.onyx_torch_model import HybridClassifier
 from model_server.utils import simple_log_function_time
+from onyx.configs.model_configs import INDEXING_CONTENT_CLASSIFICATION_MAX
+from onyx.configs.model_configs import INDEXING_CONTENT_CLASSIFICATION_MIN
+from onyx.configs.model_configs import INDEXING_CONTENT_CLASSIFICATION_TEMPERATURE
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
 from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
@@ -37,7 +41,9 @@ _INTENT_MODEL: HybridClassifier | None = None

 _CONTENT_MODEL: SetFitModel | None = None

-_TEMPERATURE_CONTENT_CLASSIFICATION = 4.0
+_CONTENT_MODEL_PROMPT_PREFIX: str = (
+    "Does this sentence have very specific information: "  # spec to model version!
+)


 def get_connector_classifier_tokenizer() -> AutoTokenizer:
@@ -247,22 +253,47 @@ def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
 def run_content_classification_inference(
     text_inputs: list[str],
 ) -> list[tuple[int, float]]:
-    get_local_content_model()
-    # output_classes = list([x.numpy() for x in content_model(text_inputs)])
-    # base_output_probabilities = list([x[1].numpy() for x in content_model.predict_proba(text_inputs)])
-    # logits = [np.log(p/(1-p)) for p in base_output_probabilities]
-    # scaled_logits = [l/_TEMPERATURE_CONTENT_CLASSIFICATION for l in logits]
-    # output_probabilities_with_temp = [np.exp(scaled_logit)/(1 + np.exp(scaled_logit)) for scaled_logit in scaled_logits]
+    def _prob_to_score(prob: float) -> float:
+        if prob < 0.25:
+            raw_score = 0.0
+        elif prob < 0.75:
+            raw_score = (prob - 0.25) / 0.5
+        else:
+            raw_score = 1.0
+        return (
+            INDEXING_CONTENT_CLASSIFICATION_MIN
+            + (
+                INDEXING_CONTENT_CLASSIFICATION_MAX
+                - INDEXING_CONTENT_CLASSIFICATION_MIN
+            )
+            * raw_score
+        )
+
+    content_model = get_local_content_model()
+
+    output_classes = list([x.numpy() for x in content_model(text_inputs)])
+    base_output_probabilities = list(
+        [x[1].numpy() for x in content_model.predict_proba(text_inputs)]
+    )
+    logits = [np.log(p / (1 - p)) for p in base_output_probabilities]
+    scaled_logits = [
+        logit / INDEXING_CONTENT_CLASSIFICATION_TEMPERATURE for logit in logits
+    ]
+    output_probabilities_with_temp = [
+        np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
+        for scaled_logit in scaled_logits
+    ]
+
+    output_scores = [
+        _prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
+    ]
+
     output_classes = [1] * len(text_inputs)
-    output_probabilities_with_temp = [0.9] * len(text_inputs)
+    output_scores = [0.9] * len(text_inputs)

     return [
-        (predicted_label, predicted_probability)
-        for predicted_label, predicted_probability in zip(
-            output_classes, output_probabilities_with_temp
-        )
+        (predicted_label, output_score)
+        for predicted_label, output_score in zip(output_classes, output_scores)
     ]

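The rewritten inference path does three things: it converts the SetFit probabilities to logits, temperature-scales them back into softened probabilities, and maps the result onto a bounded score range via the piecewise `_prob_to_score` ramp. A minimal standalone sketch of that math (not part of the commit), using the defaults added in the config hunk further down (min 0.7, max 1.0, temperature 4.0):

```python
import math

# Defaults from the config hunk below; env-var overridable in the real code.
MIN_SCORE, MAX_SCORE, TEMPERATURE = 0.7, 1.0, 4.0

def temper(prob: float) -> float:
    # logit -> divide by T -> sigmoid; T > 1 pulls probabilities toward 0.5,
    # softening overconfident model outputs.
    logit = math.log(prob / (1 - prob))
    return 1 / (1 + math.exp(-logit / TEMPERATURE))

def prob_to_score(prob: float) -> float:
    # Piecewise-linear ramp: < 0.25 -> 0.0, 0.25..0.75 -> linear, >= 0.75 -> 1.0,
    # then rescaled into [MIN_SCORE, MAX_SCORE].
    raw = 0.0 if prob < 0.25 else 1.0 if prob >= 0.75 else (prob - 0.25) / 0.5
    return MIN_SCORE + (MAX_SCORE - MIN_SCORE) * raw

for p in (0.05, 0.5, 0.95):
    t = temper(p)
    print(f"p={p:.2f} -> tempered={t:.3f} -> score={prob_to_score(t):.3f}")
# p=0.05 -> tempered=0.324 -> score=0.744
# p=0.50 -> tempered=0.500 -> score=0.850
# p=0.95 -> tempered=0.676 -> score=0.956
```

Note that as committed, the trailing stub lines still overwrite `output_classes` and `output_scores` with constants, so the computed values are not yet returned.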
@@ -417,6 +448,6 @@ async def process_content_classification_request(
     content_classification_requests: list[str],
 ) -> list[tuple[int, float]]:
     content_classification_result = run_content_classification_inference(
-        content_classification_requests
+        [_CONTENT_MODEL_PROMPT_PREFIX + req for req in content_classification_requests]
     )
     return content_classification_result
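The handler change is small but load-bearing: every classification request is now prefixed with the fixed prompt before it reaches the SetFit model. A tiny illustration with made-up request strings (the prefix is the constant defined earlier in this commit):

```python
_CONTENT_MODEL_PROMPT_PREFIX = "Does this sentence have very specific information: "

requests = ["Q3 revenue grew 12% year over year.", "hello!"]
model_inputs = [_CONTENT_MODEL_PROMPT_PREFIX + req for req in requests]

print(model_inputs[0])
# Does this sentence have very specific information: Q3 revenue grew 12% year over year.
```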
@@ -132,3 +132,20 @@ if _LITELLM_EXTRA_BODY_RAW:
         LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
     except Exception:
         pass
+
+# Whether and how to lower scores for short chunks w/o relevant context
+# Evaluated via custom ML model
+
+USE_CONTENT_CLASSIFICATION = (
+    os.environ.get("USE_CONTENT_CLASSIFICATION") or "true"
+).lower() == "true"
+
+INDEXING_CONTENT_CLASSIFICATION_MIN = float(
+    os.environ.get("INDEXING_CONTENT_CLASSIFICATION_MIN") or 0.7
+)
+INDEXING_CONTENT_CLASSIFICATION_MAX = float(
+    os.environ.get("INDEXING_CONTENT_CLASSIFICATION_MAX") or 1.0
+)
+INDEXING_CONTENT_CLASSIFICATION_TEMPERATURE = float(
+    os.environ.get("INDEXING_CONTENT_CLASSIFICATION_TEMPERATURE") or 4.0
+)
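The new settings follow the file's existing `os.environ.get(...) or default` idiom. A hedged sketch, not from the commit, of how that parsing behaves:

```python
import os

# Simulate an operator override; unset variables fall back to their defaults.
os.environ["INDEXING_CONTENT_CLASSIFICATION_MIN"] = "0.5"

content_min = float(os.environ.get("INDEXING_CONTENT_CLASSIFICATION_MIN") or 0.7)
use_classification = (
    os.environ.get("USE_CONTENT_CLASSIFICATION") or "true"
).lower() == "true"

print(content_min)         # 0.5 (env var wins over the 0.7 default)
print(use_classification)  # True (unset, so the "true" default applies)
```

One quirk of the `or` idiom: an empty-string value is treated the same as unset, so exporting a variable as "" silently falls back to the default rather than erroring.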
@@ -10,6 +10,7 @@ from onyx.access.access import get_access_for_documents
 from onyx.access.models import DocumentAccess
 from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
 from onyx.configs.constants import DEFAULT_BOOST
+from onyx.configs.model_configs import USE_CONTENT_CLASSIFICATION
 from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
 from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
     get_experts_stores_representations,
@@ -603,9 +604,13 @@ def index_doc_batch(
         chunks_with_embeddings_scores,
         chunk_content_scores,
         chunk_content_classification_failures,
-    ) = _get_aggregated_boost_factor(
-        chunks_with_embeddings, content_classification_model
-    )
+    ) = (
+        _get_aggregated_boost_factor(
+            chunks_with_embeddings, content_classification_model
+        )
+        if USE_CONTENT_CLASSIFICATION
+        else (chunks_with_embeddings, [1.0] * len(chunks_with_embeddings), [])
+    )

     updatable_ids = [doc.id for doc in ctx.updatable_docs]

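In the pipeline, the conditional expression makes the feature flag a true no-op: when `USE_CONTENT_CLASSIFICATION` is false, the `else` arm supplies a tuple of the same shape with neutral 1.0 scores and an empty failure list. A self-contained sketch with hypothetical stand-ins for the pipeline names:

```python
def classify(chunks: list[str]) -> tuple[list[str], list[float], list[str]]:
    # Stand-in for _get_aggregated_boost_factor: chunks, per-chunk scores, failures.
    return chunks, [0.85 for _ in chunks], []

USE_CONTENT_CLASSIFICATION = False
chunks = ["chunk-a", "chunk-b"]

processed, content_scores, failures = (
    classify(chunks)
    if USE_CONTENT_CLASSIFICATION
    else (chunks, [1.0] * len(chunks), [])  # neutral scores, no failures
)

print(content_scores)  # [1.0, 1.0] -- downstream boost math is unaffected
```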