Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-03-26 17:51:54 +01:00)

YS comments - pt 1
parent b8f64d10a2
commit ef291fcf0c
@@ -16,6 +16,9 @@ from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
from shared_configs.configs import (
    INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
from shared_configs.configs import (
@@ -47,20 +50,6 @@ _INFORMATION_CONTENT_MODEL: SetFitModel | None = None
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = ""  # spec to model version!


def _create_local_path(
    model_name_or_path: str, tag: str | None, local_files_only: bool
) -> str:
    if tag is None:
        local_path = snapshot_download(
            repo_id=model_name_or_path, local_files_only=local_files_only
        )
    else:
        local_path = snapshot_download(
            repo_id=model_name_or_path, revision=tag, local_files_only=local_files_only
        )
    return local_path


def get_connector_classifier_tokenizer() -> AutoTokenizer:
    global _CONNECTOR_CLASSIFIER_TOKENIZER
    if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
@@ -298,11 +287,23 @@ def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
def run_content_classification_inference(
    text_inputs: list[str],
) -> list[ContentClassificationPrediction]:
    """
    Assign a score to the segments in question. The model stored in get_local_information_content_model()
    creates the 'model score' based on its training, and the scores are then converted to a 0.0-1.0 scale.
    In the code outside of the model/inference model servers that score will be converted into the actual
    boost factor.
    """

    def _prob_to_score(prob: float) -> float:
        if prob < 0.25:
        """
        Conversion of base score to 0.0 - 1.0 score. Note that the min/max values depend on the model!
        """
        _MIN_BASE_SCORE = 0.25
        _MAX_BASE_SCORE = 0.75
        if prob < _MIN_BASE_SCORE:
            raw_score = 0.0
        elif prob < 0.75:
            raw_score = (prob - 0.25) / 0.5
        elif prob < _MAX_BASE_SCORE:
            raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
        else:
            raw_score = 1.0
        return (
@@ -314,13 +315,62 @@ def run_content_classification_inference(
            * raw_score
        )

    _BATCH_SIZE = 32
    content_model = get_local_information_content_model()

    output_classes = list([x.numpy() for x in content_model(text_inputs)])
    base_output_probabilities = list(
        [x[1].numpy() for x in content_model.predict_proba(text_inputs)]
    )
    logits = [np.log(p / (1 - p)) for p in base_output_probabilities]
    # Process inputs in batches
    all_output_classes: list[int] = []
    all_base_output_probabilities: list[float] = []

    for i in range(0, len(text_inputs), _BATCH_SIZE):
        batch = text_inputs[i : i + _BATCH_SIZE]
        batch_with_prefix = []
        batch_indices = []

        # Pre-allocate results for this batch
        batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch)
        batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch)

        # Pre-process batch to handle long input exceptions
        for j, text in enumerate(batch):
            if len(text) == 0:
                # if no input, treat as non-informative from the model's perspective
                batch_output_classes[j] = np.array(0)
                batch_probabilities[j] = np.array(0.0)
                logger.warning("Input for Content Information Model is empty")

            elif (
                len(text.split())
                <= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
            ):
                # if input is short, use the model
                batch_with_prefix.append(
                    _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text
                )
                batch_indices.append(j)
            else:
                # if longer than cutoff, treat as informative (stay with default), but issue warning
                logger.warning("Input for Content Information Model too long")

        if batch_with_prefix:  # Only run model if we have valid inputs
            # Get predictions for the batch
            model_output_classes = content_model(batch_with_prefix)
            model_output_probabilities = content_model.predict_proba(batch_with_prefix)

            # Place results in the correct positions
            for idx, batch_idx in enumerate(batch_indices):
                batch_output_classes[batch_idx] = model_output_classes[idx].numpy()
                batch_probabilities[batch_idx] = model_output_probabilities[idx][
                    1
                ].numpy()  # x[1] is prob of the positive class

        all_output_classes.extend([int(x) for x in batch_output_classes])
        all_base_output_probabilities.extend([float(x) for x in batch_probabilities])

    logits = [
        np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100)
        for p in all_base_output_probabilities
    ]
    scaled_logits = [
        logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE
        for logit in logits
@@ -338,7 +388,7 @@ def run_content_classification_inference(
        ContentClassificationPrediction(
            predicted_label=predicted_label, content_boost_factor=output_score
        )
        for predicted_label, output_score in zip(output_classes, prediction_scores)
        for predicted_label, output_score in zip(all_output_classes, prediction_scores)
    ]

    return content_classification_predictions
@@ -494,9 +544,4 @@ async def process_analysis_request(
async def process_content_classification_request(
    content_classification_requests: list[str],
) -> list[ContentClassificationPrediction]:
    return run_content_classification_inference(
        [
            _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + req
            for req in content_classification_requests
        ]
    )
    return run_content_classification_inference(content_classification_requests)
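For orientation, the piecewise conversion in _prob_to_score above can be traced with a few sample probabilities. The sketch below is illustrative only: the body of the elided return statement is assumed to interpolate linearly between the configured MIN and MAX boost factors (consistent with the imports and with the unit tests added later in this commit), and the 0.7/1.0 values are placeholder stand-ins rather than the repository's actual defaults.

# Illustrative sketch of the prob -> boost-score mapping above.
# _MIN_BASE_SCORE/_MAX_BASE_SCORE come from the hunk; the MIN/MAX values
# below are assumed placeholders, not the configured defaults.
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN = 0.7  # assumed
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX = 1.0  # assumed


def prob_to_score(prob: float) -> float:
    _MIN_BASE_SCORE = 0.25
    _MAX_BASE_SCORE = 0.75
    if prob < _MIN_BASE_SCORE:
        raw_score = 0.0
    elif prob < _MAX_BASE_SCORE:
        raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
    else:
        raw_score = 1.0
    # assumed: linear interpolation between the configured MIN and MAX boost factors
    return (
        INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
        + (
            INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
            - INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
        )
        * raw_score
    )


print(prob_to_score(0.2))  # 0.7  - below the base-score floor, minimum boost
print(prob_to_score(0.5))  # 0.85 - midway between MIN and MAX
print(prob_to_score(0.9))  # 1.0  - above the base-score ceiling, maximum boost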
@@ -75,11 +75,15 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
    logger.notice(f"Torch Threads: {torch.get_num_threads()}")

    if not INDEXING_ONLY:
        logger.notice(
            "The intent model should run on the model server. The information content model should not run here."
        )
        warm_up_intent_model()
    else:
        logger.notice("This model server should only run document indexing.")

        warm_up_information_content_model()
        logger.notice(
            "The content information model should run on the indexing model server. The intent model should not run here."
        )
        warm_up_information_content_model()

    yield
@@ -563,6 +563,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
            access=doc_access,
            boost=doc.boost,
            hidden=doc.hidden,
            # aggregated_boost_factor=doc.aggregated_boost_factor,
        )

        # update Vespa. OK if doc doesn't exist. Raises exception otherwise.
@@ -60,8 +60,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()
@@ -353,9 +351,7 @@ def _run_indexing(
        callback=callback,
    )

    information_content_classification_model = InformationContentClassificationModel(
        model_server_host=MODEL_SERVER_HOST, model_server_port=MODEL_SERVER_PORT
    )
    information_content_classification_model = InformationContentClassificationModel()

    document_index = get_default_document_index(
        index_attempt_start.search_settings,
@@ -139,8 +139,3 @@ if _LITELLM_EXTRA_BODY_RAW:
USE_INFORMATION_CONTENT_CLASSIFICATION = (
    os.environ.get("USE_INFORMATION_CONTENT_CLASSIFICATION", "false").lower() == "true"
)

# Cutoff below which we start using the information content classification model
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH = float(
    os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH") or 10
)
@@ -10,9 +10,6 @@ from onyx.access.access import get_access_for_documents
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.model_configs import (
    INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)
from onyx.configs.model_configs import USE_INFORMATION_CONTENT_CLASSIFICATION
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
@@ -64,6 +61,9 @@ from onyx.natural_language_processing.search_nlp_models import (
from onyx.llm.factory import get_default_llm_with_vision
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
from shared_configs.configs import (
    INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)

logger = setup_logger()

@@ -147,7 +147,7 @@ def _upsert_documents_in_db(
def _get_aggregated_boost_factor(
    chunks: list[IndexChunk],
    information_content_classification_model: InformationContentClassificationModel,
) -> tuple[list[IndexChunk], list[float], list[ConnectorFailure]]:
) -> list[float]:
    """Calculates the aggregated boost factor for a chunk based on its content."""

    short_chunk_content_dict = {
@@ -171,7 +171,7 @@ def _get_aggregated_boost_factor(
        # Default to 1.0 for longer chunks, use predicted score for short chunks
        chunk_content_scores = [score_map.get(i, 1.0) for i in range(len(chunks))]

        return chunks, chunk_content_scores, []
        return chunk_content_scores

    except Exception as e:
        logger.exception(
@@ -180,7 +180,6 @@ def _get_aggregated_boost_factor(

    chunks_with_scores: list[IndexChunk] = []
    chunk_content_scores = []
    failures: list[ConnectorFailure] = []

    for chunk in chunks:
        if (
@@ -200,25 +199,15 @@ def _get_aggregated_boost_factor(
                chunks_with_scores.append(chunk)
            except Exception as e:
                logger.exception(
                    f"Error predicting content classification for chunk: {e}. Adding to missed content classifications."
                    f"Error predicting content classification for chunk: {e}."
                )

                failures.append(
                    ConnectorFailure(
                        failed_document=DocumentFailure(
                            document_id=chunk.source_document.id,
                            document_link=(
                                chunk.source_document.sections[0].link
                                if chunk.source_document.sections
                                else None
                            ),
                        ),
                        failure_message=str(e),
                        exception=e,
                    )
                )
                raise Exception(
                    f"Failed to predict content classification for chunk {chunk.chunk_id} "
                    f"from document {chunk.source_document.id}"
                ) from e

    return chunks_with_scores, chunk_content_scores, failures
    return chunk_content_scores


def get_doc_ids_to_update(
@@ -619,11 +608,13 @@ def index_doc_batch(
        chunk_content_scores,
        chunk_content_classification_failures,
    ) = (
        chunks_with_embeddings,
        _get_aggregated_boost_factor(
            chunks_with_embeddings, information_content_classification_model
        )
        if USE_INFORMATION_CONTENT_CLASSIFICATION
        else (chunks_with_embeddings, [1.0] * len(chunks_with_embeddings), [])
        else [1.0] * len(chunks_with_embeddings),
        embedding_failures,
    )

    updatable_ids = [doc.id for doc in ctx.updatable_docs]
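The short/long split that _get_aggregated_boost_factor implements can be summarized in a compact sketch. This is a paraphrase, not the repository function: it assumes the predict() call exercised by the unit tests in this commit and it omits the batch-failure fallback and the per-chunk raise path shown in the hunks above.

# Hedged sketch of the core logic: only short chunks are scored by the model,
# everything else keeps the neutral default of 1.0.
from onyx.indexing.models import IndexChunk
from onyx.natural_language_processing.search_nlp_models import (
    InformationContentClassificationModel,
)
from shared_configs.configs import (
    INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)


def aggregated_boost_factors_sketch(
    chunks: list[IndexChunk],
    model: InformationContentClassificationModel,
) -> list[float]:
    # Only chunks at or below the word-count cutoff are sent to the model.
    short_chunk_content = {
        i: chunk.content
        for i, chunk in enumerate(chunks)
        if len(chunk.content.split())
        <= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
    }
    predictions = (
        model.predict(list(short_chunk_content.values()))
        if short_chunk_content
        else []
    )
    score_map = {
        idx: prediction.content_boost_factor
        for idx, prediction in zip(short_chunk_content.keys(), predictions)
    }
    # Longer chunks keep the default boost factor of 1.0.
    return [score_map.get(i, 1.0) for i in range(len(chunks))]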
@@ -85,14 +85,14 @@ class DocMetadataAwareIndexChunk(IndexChunk):
    boost: influences the ranking of this chunk at query time. Positive -> ranked higher,
        negative -> ranked lower. Not included in aggregated boost calculation
        for legacy reasons.
    aggregated_boost_factor: represents non-user-specific aggregated boost calculation
    aggregated_boost_factor: represents the content information boost calculation
    """

    tenant_id: str
    access: "DocumentAccess"
    document_sets: set[str]
    boost: int
    aggregated_boost_factor: float = 1.0
    aggregated_boost_factor: float

    @classmethod
    def from_index_chunk(
@@ -29,6 +29,8 @@ from onyx.natural_language_processing.exceptions import (
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content
from onyx.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
@@ -382,8 +384,8 @@ class QueryAnalysisModel:
class InformationContentClassificationModel:
    def __init__(
        self,
        model_server_host: str = MODEL_SERVER_HOST,
        model_server_port: int = MODEL_SERVER_PORT,
        model_server_host: str = INDEXING_MODEL_SERVER_HOST,
        model_server_port: int = INDEXING_MODEL_SERVER_PORT,
    ) -> None:
        model_server_url = build_model_server_url(model_server_host, model_server_port)
        self.content_server_endpoint = (
@@ -397,11 +399,11 @@ class InformationContentClassificationModel:
        response = requests.post(self.content_server_endpoint, json=queries)
        response.raise_for_status()

        response_model = InformationContentClassificationResponses(
        model_responses = InformationContentClassificationResponses(
            information_content_classifications=response.json()
        )

        return response_model.information_content_classifications
        return model_responses.information_content_classifications


class ConnectorClassificationModel:
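As a usage illustration only (not code from this commit): the client now defaults to the indexing model server host and port, and its predict() call, as exercised by the unit tests added in this commit, POSTs the raw strings to the content classification endpoint and returns one prediction per input. Assumes a reachable indexing model server.

from onyx.natural_language_processing.search_nlp_models import (
    InformationContentClassificationModel,
)

# Defaults now point at INDEXING_MODEL_SERVER_HOST / INDEXING_MODEL_SERVER_PORT.
model = InformationContentClassificationModel()
predictions = model.predict(["a short chunk", "another short chunk"])
for prediction in predictions:
    print(prediction.predicted_label, prediction.content_boost_factor)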
@@ -98,6 +98,7 @@ def _create_indexable_chunks(
            boost=DEFAULT_BOOST,
            large_chunk_id=None,
            image_file_name=None,
            aggregated_boost_factor=1.0,
        )

        chunks.append(chunk)
@@ -26,8 +26,6 @@ from onyx.server.onyx_api.models import DocMinimalInfo
from onyx.server.onyx_api.models import IngestionDocument
from onyx.server.onyx_api.models import IngestionResult
from onyx.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()
@@ -107,9 +105,7 @@ def upsert_ingestion_doc(
        search_settings=search_settings
    )

    information_content_classification_model = InformationContentClassificationModel(
        model_server_host=MODEL_SERVER_HOST, model_server_port=MODEL_SERVER_PORT
    )
    information_content_classification_model = InformationContentClassificationModel()

    indexing_pipeline = build_indexing_pipeline(
        embedder=index_embedding_model,
@@ -291,3 +291,8 @@ INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN = float(
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE = float(
    os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE") or 4.0
)
# Cutoff below which we start using the information content classification model
# (cutoff length number itself is still considered 'short')
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH = int(
    os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH") or 10
)
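To make the boundary comment concrete: the model server compares the whitespace-tokenized length with <=, so an input of exactly the cutoff number of words is still treated as 'short' and scored by the classifier, while longer inputs keep the default boost and only trigger a warning. A minimal sketch mirroring that check (the separate empty-input handling aside):

from shared_configs.configs import (
    INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH as CUTOFF,
)


def is_scored_by_model(text: str) -> bool:
    # Same comparison as in run_content_classification_inference above.
    return len(text.split()) <= CUTOFF


print(is_scored_by_model("word " * CUTOFF))        # True: exactly the cutoff is still 'short'
print(is_scored_by_model("word " * (CUTOFF + 1)))  # False: one word over the cutoff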
backend/tests/unit/model_server/test_custom_models.py (new file, 145 lines)
@@ -0,0 +1,145 @@
from typing import Any
from unittest.mock import Mock
from unittest.mock import patch

import numpy as np
import numpy.typing as npt
import pytest

from model_server.custom_models import run_content_classification_inference
from shared_configs.configs import (
    INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
from shared_configs.model_server_models import ContentClassificationPrediction


@pytest.fixture
def mock_content_model() -> Mock:
    model = Mock()

    # Create actual numpy arrays for the mock returns
    predict_output = np.array(
        [1, 0] * 50, dtype=np.int64
    )  # Pre-allocate enough elements
    proba_output = np.array(
        [[0.3, 0.7], [0.7, 0.3]] * 50, dtype=np.float64
    )  # Pre-allocate enough elements

    # Create a mock tensor that has a numpy method and supports indexing
    class MockTensor:
        def __init__(self, value: npt.NDArray[Any]) -> None:
            self.value = value

        def numpy(self) -> npt.NDArray[Any]:
            return self.value

        def __getitem__(self, idx: Any) -> Any:
            result = self.value[idx]
            # Wrap scalar values back in MockTensor
            if isinstance(result, (np.float64, np.int64)):
                return MockTensor(np.array([result]))
            return MockTensor(result)

    # Mock the direct call to return a MockTensor for each input
    def model_call(inputs: list[str]) -> list[MockTensor]:
        batch_size = len(inputs)
        return [MockTensor(predict_output[i : i + 1]) for i in range(batch_size)]

    model.side_effect = model_call

    # Mock predict_proba to return MockTensor-wrapped numpy array
    def predict_proba_call(x: list[str]) -> MockTensor:
        batch_size = len(x)
        return MockTensor(proba_output[:batch_size])

    model.predict_proba.side_effect = predict_proba_call

    return model


@patch("model_server.custom_models.get_local_information_content_model")
def test_run_content_classification_inference(
    mock_get_model: Mock,
    mock_content_model: Mock,
) -> None:
    """
    Test the content classification inference function.
    Verifies that the function correctly processes text inputs and returns appropriate predictions.
    """
    # Setup
    mock_get_model.return_value = mock_content_model

    test_inputs = [
        "Imagine a short text with content",
        "Imagine a short text without content",
        "x "
        * (
            INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH + 1
        ),  # Long input that exceeds maximal length for when the model should be applied
        "",  # Empty input
    ]

    # Execute
    results = run_content_classification_inference(test_inputs)

    # Assert
    assert len(results) == len(test_inputs)
    assert all(isinstance(r, ContentClassificationPrediction) for r in results)

    # Check each prediction has expected attributes and ranges
    for result_num, result in enumerate(results):
        assert hasattr(result, "predicted_label")
        assert hasattr(result, "content_boost_factor")
        assert isinstance(result.predicted_label, int)
        assert isinstance(result.content_boost_factor, float)
        assert (
            INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
            <= result.content_boost_factor
            <= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
        )
        if result_num == 2:
            assert (
                result.content_boost_factor
                == INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
            )
            assert result.predicted_label == 1
        elif result_num == 3:
            assert (
                result.content_boost_factor
                == INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
            )
            assert result.predicted_label == 0

    # Verify model handling of long inputs
    mock_content_model.predict_proba.reset_mock()
    long_input = ["x " * 1000]  # Definitely exceeds MAX_LENGTH
    results = run_content_classification_inference(long_input)
    assert len(results) == 1
    assert (
        mock_content_model.predict_proba.call_count == 0
    )  # Should skip model call for too-long input


@patch("model_server.custom_models.get_local_information_content_model")
def test_batch_processing(
    mock_get_model: Mock,
    mock_content_model: Mock,
) -> None:
    """
    Test that the function correctly handles batch processing of inputs.
    """
    # Setup
    mock_get_model.return_value = mock_content_model

    # Create test input larger than batch size
    test_inputs = [f"Test input {i}" for i in range(40)]  # > BATCH_SIZE (32)

    # Execute
    results = run_content_classification_inference(test_inputs)

    # Assert
    assert len(results) == 40
    # Verify batching occurred (should have called predict_proba twice)
    assert mock_content_model.predict_proba.call_count == 2
@@ -1,12 +1,24 @@
from typing import cast
from typing import List
from unittest.mock import Mock

import pytest

from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentSource
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.indexing.indexing_pipeline import _get_aggregated_boost_factor
from onyx.indexing.indexing_pipeline import filter_documents
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import IndexChunk
from onyx.natural_language_processing.search_nlp_models import (
    ContentClassificationPrediction,
)
from shared_configs.configs import (
    INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)


def create_test_document(
@@ -123,3 +135,117 @@ def test_filter_documents_multiple_documents() -> None:
def test_filter_documents_empty_batch() -> None:
    result = filter_documents([])
    assert len(result) == 0


# Tests for get_aggregated_boost_factor


def create_test_chunk(
    content: str, chunk_id: int = 0, doc_id: str = "test_doc"
) -> IndexChunk:
    doc = Document(
        id=doc_id,
        semantic_identifier="test doc",
        sections=[],
        source=DocumentSource.FILE,
        metadata={},
    )
    return IndexChunk(
        chunk_id=chunk_id,
        content=content,
        source_document=doc,
        blurb=content[:50],  # First 50 chars as blurb
        source_links={0: "test_link"},
        section_continuation=False,
        title_prefix="",
        metadata_suffix_semantic="",
        metadata_suffix_keyword="",
        mini_chunk_texts=None,
        large_chunk_id=None,
        large_chunk_reference_ids=[],
        embeddings=ChunkEmbedding(full_embedding=[], mini_chunk_embeddings=[]),
        title_embedding=None,
        image_file_name=None,
    )


def test_get_aggregated_boost_factor() -> None:
    # Create test chunks - mix of short and long content
    chunks = [
        create_test_chunk("Short content", 0),
        create_test_chunk(
            "Long " * (INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH + 1), 1
        ),
        create_test_chunk("Another short chunk", 2),
    ]

    # Mock the classification model
    mock_model = Mock()
    mock_model.predict.return_value = [
        ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.8),
        ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.9),
    ]

    # Execute the function
    boost_scores = _get_aggregated_boost_factor(
        chunks=chunks, information_content_classification_model=mock_model
    )

    # Assertions
    assert len(boost_scores) == 3

    # Check that long content got default boost
    assert boost_scores[1] == 1.0

    # Check that short content got predicted boosts
    assert boost_scores[0] == 0.8
    assert boost_scores[2] == 0.9

    # Verify model was only called once with the short chunks
    mock_model.predict.assert_called_once()
    assert len(mock_model.predict.call_args[0][0]) == 2


def test_get_aggregated_boost_factor_failure() -> None:
    chunks = [
        create_test_chunk("Short content 1", 0),
        create_test_chunk("Short content 2", 1),
    ]

    # Mock model to fail on batch prediction but succeed on individual predictions
    mock_model = Mock()
    mock_model.predict.side_effect = [
        Exception("Batch prediction failed"),  # First call fails
        [
            ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.7)
        ],  # Individual calls succeed
        [ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.8)],
    ]

    # Execute
    boost_scores = _get_aggregated_boost_factor(
        chunks=chunks, information_content_classification_model=mock_model
    )

    # Assertions
    assert len(boost_scores) == 2
    assert boost_scores == [0.7, 0.8]


def test_get_aggregated_boost_factor_individual_failure() -> None:
    chunks = [
        create_test_chunk("Short content", 0),
        create_test_chunk("Short content", 1),
    ]

    # Mock model to fail on both batch and individual prediction
    mock_model = Mock()
    mock_model.predict.side_effect = Exception("Prediction failed")

    # Execute and verify it raises an exception
    with pytest.raises(Exception) as exc_info:
        _get_aggregated_boost_factor(
            chunks=chunks, information_content_classification_model=mock_model
        )

    assert "Failed to predict content classification for chunk" in str(exc_info.value)