YS comments - pt 1

joachim-danswer 2025-03-09 16:18:02 -07:00
parent b8f64d10a2
commit ef291fcf0c
13 changed files with 383 additions and 76 deletions

View File

@@ -16,6 +16,9 @@ from model_server.utils import simple_log_function_time
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
 from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
+from shared_configs.configs import (
+    INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
+)
 from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
 from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
 from shared_configs.configs import (
@@ -47,20 +50,6 @@ _INFORMATION_CONTENT_MODEL: SetFitModel | None = None
 _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = ""  # spec to model version!
 
-def _create_local_path(
-    model_name_or_path: str, tag: str | None, local_files_only: bool
-) -> str:
-    if tag is None:
-        local_path = snapshot_download(
-            repo_id=model_name_or_path, local_files_only=local_files_only
-        )
-    else:
-        local_path = snapshot_download(
-            repo_id=model_name_or_path, revision=tag, local_files_only=local_files_only
-        )
-    return local_path
-
 def get_connector_classifier_tokenizer() -> AutoTokenizer:
     global _CONNECTOR_CLASSIFIER_TOKENIZER
     if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
@@ -298,11 +287,23 @@ def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
 def run_content_classification_inference(
     text_inputs: list[str],
 ) -> list[ContentClassificationPrediction]:
+    """
+    Assign a score to the segments in question. The model stored in get_local_information_content_model()
+    creates the 'model score' based on its training, and the scores are then converted to a 0.0-1.0 scale.
+    In the code outside of the model/inference model servers that score will be converted into the actual
+    boost factor.
+    """
+
     def _prob_to_score(prob: float) -> float:
-        if prob < 0.25:
+        """
+        Conversion of base score to 0.0 - 1.0 score. Note that the min/max values depend on the model!
+        """
+        _MIN_BASE_SCORE = 0.25
+        _MAX_BASE_SCORE = 0.75
+        if prob < _MIN_BASE_SCORE:
             raw_score = 0.0
-        elif prob < 0.75:
-            raw_score = (prob - 0.25) / 0.5
+        elif prob < _MAX_BASE_SCORE:
+            raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
         else:
             raw_score = 1.0
         return (
@ -314,13 +315,62 @@ def run_content_classification_inference(
* raw_score * raw_score
) )
_BATCH_SIZE = 32
content_model = get_local_information_content_model() content_model = get_local_information_content_model()
output_classes = list([x.numpy() for x in content_model(text_inputs)]) # Process inputs in batches
base_output_probabilities = list( all_output_classes: list[int] = []
[x[1].numpy() for x in content_model.predict_proba(text_inputs)] all_base_output_probabilities: list[float] = []
)
logits = [np.log(p / (1 - p)) for p in base_output_probabilities] for i in range(0, len(text_inputs), _BATCH_SIZE):
batch = text_inputs[i : i + _BATCH_SIZE]
batch_with_prefix = []
batch_indices = []
# Pre-allocate results for this batch
batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch)
batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch)
# Pre-process batch to handle long input exceptions
for j, text in enumerate(batch):
if len(text) == 0:
# if no input, treat as non-informative from the model's perspective
batch_output_classes[j] = np.array(0)
batch_probabilities[j] = np.array(0.0)
logger.warning("Input for Content Information Model is empty")
elif (
len(text.split())
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
):
# if input is short, use the model
batch_with_prefix.append(
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text
)
batch_indices.append(j)
else:
# if longer than cutoff, treat as informative (stay with default), but issue warning
logger.warning("Input for Content Information Model too long")
if batch_with_prefix: # Only run model if we have valid inputs
# Get predictions for the batch
model_output_classes = content_model(batch_with_prefix)
model_output_probabilities = content_model.predict_proba(batch_with_prefix)
# Place results in the correct positions
for idx, batch_idx in enumerate(batch_indices):
batch_output_classes[batch_idx] = model_output_classes[idx].numpy()
batch_probabilities[batch_idx] = model_output_probabilities[idx][
1
].numpy() # x[1] is prob of the positive class
all_output_classes.extend([int(x) for x in batch_output_classes])
all_base_output_probabilities.extend([float(x) for x in batch_probabilities])
logits = [
np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100)
for p in all_base_output_probabilities
]
scaled_logits = [ scaled_logits = [
logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE
for logit in logits for logit in logits
@@ -338,7 +388,7 @@ def run_content_classification_inference(
         ContentClassificationPrediction(
             predicted_label=predicted_label, content_boost_factor=output_score
         )
-        for predicted_label, output_score in zip(output_classes, prediction_scores)
+        for predicted_label, output_score in zip(all_output_classes, prediction_scores)
     ]
 
     return content_classification_predictions
@@ -494,9 +544,4 @@ async def process_analysis_request(
 async def process_content_classification_request(
     content_classification_requests: list[str],
 ) -> list[ContentClassificationPrediction]:
-    return run_content_classification_inference(
-        [
-            _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + req
-            for req in content_classification_requests
-        ]
-    )
+    return run_content_classification_inference(content_classification_requests)
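
Side note on the scoring math introduced above: a minimal standalone sketch of the two visible pieces, the piecewise base-score mapping and the clamped logit. The surrounding min/max/temperature scaling appears only partially in the hunk, so it is left out; constant names mirror the diff.

import numpy as np

_MIN_BASE_SCORE = 0.25
_MAX_BASE_SCORE = 0.75


def prob_to_raw_score(prob: float) -> float:
    # Piecewise mapping from the hunk above: clamp below 0.25 to 0.0,
    # clamp above 0.75 to 1.0, scale linearly in between.
    if prob < _MIN_BASE_SCORE:
        return 0.0
    elif prob < _MAX_BASE_SCORE:
        return (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
    return 1.0


def safe_logit(p: float) -> float:
    # Clamped logit from the hunk above: degenerate probabilities are pinned
    # to +/-100 instead of producing log(0) or a division by zero.
    if p == 0.0:
        return -100.0
    if p == 1.0:
        return 100.0
    return float(np.log(p / (1 - p)))


assert prob_to_raw_score(0.5) == 0.5  # midpoint of the 0.25-0.75 band
assert prob_to_raw_score(0.1) == 0.0  # below the floor
assert prob_to_raw_score(0.9) == 1.0  # above the ceiling
print(safe_logit(0.7))  # ~0.85
print(safe_logit(1.0))  # 100.0 (clamped)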

View File

@@ -75,11 +75,15 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
     logger.notice(f"Torch Threads: {torch.get_num_threads()}")
 
     if not INDEXING_ONLY:
+        logger.notice(
+            "The intent model should run on the model server. The information content model should not run here."
+        )
         warm_up_intent_model()
     else:
-        logger.notice("This model server should only run document indexing.")
+        logger.notice(
+            "The content information model should run on the indexing model server. The intent model should not run here."
+        )
         warm_up_information_content_model()
 
     yield

View File

@@ -563,6 +563,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
             access=doc_access,
             boost=doc.boost,
             hidden=doc.hidden,
+            # aggregated_boost_factor=doc.aggregated_boost_factor,
         )
 
         # update Vespa. OK if doc doesn't exist. Raises exception otherwise.

View File

@@ -60,8 +60,6 @@ from onyx.utils.logger import setup_logger
 from onyx.utils.logger import TaskAttemptSingleton
 from onyx.utils.telemetry import create_milestone_and_report
 from onyx.utils.variable_functionality import global_version
-from shared_configs.configs import MODEL_SERVER_HOST
-from shared_configs.configs import MODEL_SERVER_PORT
 from shared_configs.configs import MULTI_TENANT
 
 logger = setup_logger()
@@ -353,9 +351,7 @@ def _run_indexing(
         callback=callback,
     )
 
-    information_content_classification_model = InformationContentClassificationModel(
-        model_server_host=MODEL_SERVER_HOST, model_server_port=MODEL_SERVER_PORT
-    )
+    information_content_classification_model = InformationContentClassificationModel()
 
     document_index = get_default_document_index(
         index_attempt_start.search_settings,

View File

@@ -139,8 +139,3 @@ if _LITELLM_EXTRA_BODY_RAW:
 USE_INFORMATION_CONTENT_CLASSIFICATION = (
     os.environ.get("USE_INFORMATION_CONTENT_CLASSIFICATION", "false").lower() == "true"
 )
-
-# Cutoff below which we start using the information content classification model
-INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH = float(
-    os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH") or 10
-)

View File

@@ -10,9 +10,6 @@ from onyx.access.access import get_access_for_documents
 from onyx.access.models import DocumentAccess
 from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
 from onyx.configs.constants import DEFAULT_BOOST
-from onyx.configs.model_configs import (
-    INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
-)
 from onyx.configs.model_configs import USE_INFORMATION_CONTENT_CLASSIFICATION
 from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
 from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
@@ -64,6 +61,9 @@ from onyx.natural_language_processing.search_nlp_models import (
 from onyx.llm.factory import get_default_llm_with_vision
 from onyx.utils.logger import setup_logger
 from onyx.utils.timing import log_function_time
+from shared_configs.configs import (
+    INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
+)
 
 logger = setup_logger()
@@ -147,7 +147,7 @@ def _upsert_documents_in_db(
 def _get_aggregated_boost_factor(
     chunks: list[IndexChunk],
     information_content_classification_model: InformationContentClassificationModel,
-) -> tuple[list[IndexChunk], list[float], list[ConnectorFailure]]:
+) -> list[float]:
     """Calculates the aggregated boost factor for a chunk based on its content."""
 
     short_chunk_content_dict = {
@@ -171,7 +171,7 @@ def _get_aggregated_boost_factor(
         # Default to 1.0 for longer chunks, use predicted score for short chunks
         chunk_content_scores = [score_map.get(i, 1.0) for i in range(len(chunks))]
 
-        return chunks, chunk_content_scores, []
+        return chunk_content_scores
 
     except Exception as e:
         logger.exception(
@@ -180,7 +180,6 @@ def _get_aggregated_boost_factor(
 
         chunks_with_scores: list[IndexChunk] = []
         chunk_content_scores = []
-        failures: list[ConnectorFailure] = []
 
         for chunk in chunks:
             if (
@@ -200,25 +199,15 @@ def _get_aggregated_boost_factor(
                 chunks_with_scores.append(chunk)
             except Exception as e:
                 logger.exception(
-                    f"Error predicting content classification for chunk: {e}. Adding to missed content classifications."
+                    f"Error predicting content classification for chunk: {e}."
                 )
-                failures.append(
-                    ConnectorFailure(
-                        failed_document=DocumentFailure(
-                            document_id=chunk.source_document.id,
-                            document_link=(
-                                chunk.source_document.sections[0].link
-                                if chunk.source_document.sections
-                                else None
-                            ),
-                        ),
-                        failure_message=str(e),
-                        exception=e,
-                    )
-                )
+                raise Exception(
+                    f"Failed to predict content classification for chunk {chunk.chunk_id} "
+                    f"from document {chunk.source_document.id}"
+                ) from e
 
-    return chunks_with_scores, chunk_content_scores, failures
+    return chunk_content_scores
 
 def get_doc_ids_to_update(
@@ -619,11 +608,13 @@ def index_doc_batch(
         chunk_content_scores,
         chunk_content_classification_failures,
     ) = (
-        chunks_with_embeddings,
         _get_aggregated_boost_factor(
             chunks_with_embeddings, information_content_classification_model
         )
         if USE_INFORMATION_CONTENT_CLASSIFICATION
-        else (chunks_with_embeddings, [1.0] * len(chunks_with_embeddings), [])
+        else [1.0] * len(chunks_with_embeddings),
+        embedding_failures,
     )
 
     updatable_ids = [doc.id for doc in ctx.updatable_docs]
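
For context on the signature change above: the helper now returns exactly one boost score per chunk and raises on prediction errors instead of collecting ConnectorFailures. Below is a minimal sketch of that contract, with a hypothetical predict callable standing in for the classification model and an assumed default cutoff of 10 words.

from typing import Callable

# Stand-in for INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH (assumed default)
CUTOFF_WORDS = 10


def aggregated_boost_factors(
    chunk_texts: list[str],
    predict: Callable[[list[str]], list[float]],
) -> list[float]:
    # Only chunks at or below the word-count cutoff are sent to the model.
    short_indices = [
        i for i, text in enumerate(chunk_texts) if len(text.split()) <= CUTOFF_WORDS
    ]
    # Any prediction error now propagates to the caller (no partial failure list).
    predicted = predict([chunk_texts[i] for i in short_indices])
    score_map = dict(zip(short_indices, predicted))
    # Longer chunks keep the neutral default of 1.0.
    return [score_map.get(i, 1.0) for i in range(len(chunk_texts))]


scores = aggregated_boost_factors(
    ["a short chunk", "word " * 50], lambda texts: [0.8] * len(texts)
)
print(scores)  # [0.8, 1.0]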

View File

@@ -85,14 +85,14 @@ class DocMetadataAwareIndexChunk(IndexChunk):
     boost: influences the ranking of this chunk at query time. Positive -> ranked higher,
         negative -> ranked lower. Not included in aggregated boost calculation
         for legacy reasons.
-    aggregated_boost_factor: represents non-user-specific aggregated boost calculation
+    aggregated_boost_factor: represents the content information boost calculation
     """
 
     tenant_id: str
     access: "DocumentAccess"
     document_sets: set[str]
     boost: int
-    aggregated_boost_factor: float = 1.0
+    aggregated_boost_factor: float
 
     @classmethod
     def from_index_chunk(
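
Because the default is dropped above, aggregated_boost_factor becomes a required field, which is why construction sites such as the seeding code later in this commit now pass it explicitly. A toy pydantic-style sketch of the consequence (the real DocMetadataAwareIndexChunk has many more required fields, elided here):

from pydantic import BaseModel


class ChunkSketch(BaseModel):
    # stand-in for DocMetadataAwareIndexChunk with most fields elided
    boost: int
    aggregated_boost_factor: float  # no default anymore -> must be passed explicitly


ChunkSketch(boost=0, aggregated_boost_factor=1.0)  # ok
# ChunkSketch(boost=0)  # would now raise a validation error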

View File

@@ -29,6 +29,8 @@ from onyx.natural_language_processing.exceptions import (
 from onyx.natural_language_processing.utils import get_tokenizer
 from onyx.natural_language_processing.utils import tokenizer_trim_content
 from onyx.utils.logger import setup_logger
+from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
+from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
 from shared_configs.configs import MODEL_SERVER_HOST
 from shared_configs.configs import MODEL_SERVER_PORT
 from shared_configs.enums import EmbeddingProvider
@@ -382,8 +384,8 @@ class QueryAnalysisModel:
 class InformationContentClassificationModel:
     def __init__(
         self,
-        model_server_host: str = MODEL_SERVER_HOST,
-        model_server_port: int = MODEL_SERVER_PORT,
+        model_server_host: str = INDEXING_MODEL_SERVER_HOST,
+        model_server_port: int = INDEXING_MODEL_SERVER_PORT,
     ) -> None:
         model_server_url = build_model_server_url(model_server_host, model_server_port)
         self.content_server_endpoint = (
@@ -397,11 +399,11 @@ class InformationContentClassificationModel:
         response = requests.post(self.content_server_endpoint, json=queries)
         response.raise_for_status()
 
-        response_model = InformationContentClassificationResponses(
+        model_responses = InformationContentClassificationResponses(
             information_content_classifications=response.json()
         )
 
-        return response_model.information_content_classifications
+        return model_responses.information_content_classifications
 
 class ConnectorClassificationModel:
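
For orientation, a minimal usage sketch after the change above: the client now defaults to the indexing model server, so the indexing-side call sites in this commit construct it without arguments. The predict call shape is inferred from the mocked usage in the new tests and is an assumption here.

from onyx.natural_language_processing.search_nlp_models import (
    InformationContentClassificationModel,
)

# Defaults now resolve to INDEXING_MODEL_SERVER_HOST / INDEXING_MODEL_SERVER_PORT.
model = InformationContentClassificationModel()

# Assumed call shape: one prediction per input text, mirroring the mocked
# predict(...) usage in the new indexing pipeline tests.
predictions = model.predict(["a short chunk of text", "another chunk"])
for prediction in predictions:
    print(prediction.predicted_label, prediction.content_boost_factor)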

View File

@@ -98,6 +98,7 @@ def _create_indexable_chunks(
             boost=DEFAULT_BOOST,
             large_chunk_id=None,
             image_file_name=None,
+            aggregated_boost_factor=1.0,
         )
         chunks.append(chunk)

View File

@@ -26,8 +26,6 @@ from onyx.server.onyx_api.models import DocMinimalInfo
 from onyx.server.onyx_api.models import IngestionDocument
 from onyx.server.onyx_api.models import IngestionResult
 from onyx.utils.logger import setup_logger
-from shared_configs.configs import MODEL_SERVER_HOST
-from shared_configs.configs import MODEL_SERVER_PORT
 from shared_configs.contextvars import get_current_tenant_id
 
 logger = setup_logger()
@@ -107,9 +105,7 @@ def upsert_ingestion_doc(
         search_settings=search_settings
     )
 
-    information_content_classification_model = InformationContentClassificationModel(
-        model_server_host=MODEL_SERVER_HOST, model_server_port=MODEL_SERVER_PORT
-    )
+    information_content_classification_model = InformationContentClassificationModel()
 
     indexing_pipeline = build_indexing_pipeline(
         embedder=index_embedding_model,

View File

@@ -291,3 +291,8 @@ INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN = float(
 INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE = float(
     os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE") or 4.0
 )
+# Cutoff below which we start using the information content classification model
+# (the cutoff length itself is still considered 'short')
+INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH = int(
+    os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH") or 10
+)
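
A small illustrative check of the inclusive cutoff semantics noted in the comment above (assuming the default of 10 words): a text with exactly the cutoff word count is still "short" and goes to the classifier.

INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH = 10  # default shown above


def goes_to_classifier(text: str) -> bool:
    # "<=" means the cutoff itself still counts as short.
    return len(text.split()) <= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH


assert goes_to_classifier("w " * 10)       # exactly 10 words -> classified
assert not goes_to_classifier("w " * 11)   # 11 words -> default boost of 1.0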

View File

@@ -0,0 +1,145 @@
+from typing import Any
+from unittest.mock import Mock
+from unittest.mock import patch
+
+import numpy as np
+import numpy.typing as npt
+import pytest
+
+from model_server.custom_models import run_content_classification_inference
+from shared_configs.configs import (
+    INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
+)
+from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
+from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
+from shared_configs.model_server_models import ContentClassificationPrediction
+
+
+@pytest.fixture
+def mock_content_model() -> Mock:
+    model = Mock()
+
+    # Create actual numpy arrays for the mock returns
+    predict_output = np.array(
+        [1, 0] * 50, dtype=np.int64
+    )  # Pre-allocate enough elements
+    proba_output = np.array(
+        [[0.3, 0.7], [0.7, 0.3]] * 50, dtype=np.float64
+    )  # Pre-allocate enough elements
+
+    # Create a mock tensor that has a numpy method and supports indexing
+    class MockTensor:
+        def __init__(self, value: npt.NDArray[Any]) -> None:
+            self.value = value
+
+        def numpy(self) -> npt.NDArray[Any]:
+            return self.value
+
+        def __getitem__(self, idx: Any) -> Any:
+            result = self.value[idx]
+            # Wrap scalar values back in MockTensor
+            if isinstance(result, (np.float64, np.int64)):
+                return MockTensor(np.array([result]))
+            return MockTensor(result)
+
+    # Mock the direct call to return a MockTensor for each input
+    def model_call(inputs: list[str]) -> list[MockTensor]:
+        batch_size = len(inputs)
+        return [MockTensor(predict_output[i : i + 1]) for i in range(batch_size)]
+
+    model.side_effect = model_call
+
+    # Mock predict_proba to return MockTensor-wrapped numpy array
+    def predict_proba_call(x: list[str]) -> MockTensor:
+        batch_size = len(x)
+        return MockTensor(proba_output[:batch_size])
+
+    model.predict_proba.side_effect = predict_proba_call
+
+    return model
+
+
+@patch("model_server.custom_models.get_local_information_content_model")
+def test_run_content_classification_inference(
+    mock_get_model: Mock,
+    mock_content_model: Mock,
+) -> None:
+    """
+    Test the content classification inference function.
+    Verifies that the function correctly processes text inputs and returns appropriate predictions.
+    """
+    # Setup
+    mock_get_model.return_value = mock_content_model
+
+    test_inputs = [
+        "Imagine a short text with content",
+        "Imagine a short text without content",
+        "x "
+        * (
+            INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH + 1
+        ),  # Long input that exceeds maximal length for when the model should be applied
+        "",  # Empty input
+    ]
+
+    # Execute
+    results = run_content_classification_inference(test_inputs)
+
+    # Assert
+    assert len(results) == len(test_inputs)
+    assert all(isinstance(r, ContentClassificationPrediction) for r in results)
+
+    # Check each prediction has expected attributes and ranges
+    for result_num, result in enumerate(results):
+        assert hasattr(result, "predicted_label")
+        assert hasattr(result, "content_boost_factor")
+        assert isinstance(result.predicted_label, int)
+        assert isinstance(result.content_boost_factor, float)
+        assert (
+            INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
+            <= result.content_boost_factor
+            <= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
+        )
+        if result_num == 2:
+            assert (
+                result.content_boost_factor
+                == INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
+            )
+            assert result.predicted_label == 1
+        elif result_num == 3:
+            assert (
+                result.content_boost_factor
+                == INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
+            )
+            assert result.predicted_label == 0
+
+    # Verify model handling of long inputs
+    mock_content_model.predict_proba.reset_mock()
+    long_input = ["x " * 1000]  # Definitely exceeds MAX_LENGTH
+    results = run_content_classification_inference(long_input)
+    assert len(results) == 1
+    assert (
+        mock_content_model.predict_proba.call_count == 0
+    )  # Should skip model call for too-long input
+
+
+@patch("model_server.custom_models.get_local_information_content_model")
+def test_batch_processing(
+    mock_get_model: Mock,
+    mock_content_model: Mock,
+) -> None:
+    """
+    Test that the function correctly handles batch processing of inputs.
+    """
+    # Setup
+    mock_get_model.return_value = mock_content_model
+
+    # Create test input larger than batch size
+    test_inputs = [f"Test input {i}" for i in range(40)]  # > BATCH_SIZE (32)
+
+    # Execute
+    results = run_content_classification_inference(test_inputs)
+
+    # Assert
+    assert len(results) == 40
+    # Verify batching occurred (should have called predict_proba twice)
+    assert mock_content_model.predict_proba.call_count == 2

View File

@@ -1,12 +1,24 @@
 from typing import cast
 from typing import List
+from unittest.mock import Mock
+
+import pytest
 
 from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
 from onyx.connectors.models import Document
 from onyx.connectors.models import DocumentSource
 from onyx.connectors.models import ImageSection
 from onyx.connectors.models import TextSection
+from onyx.indexing.indexing_pipeline import _get_aggregated_boost_factor
 from onyx.indexing.indexing_pipeline import filter_documents
+from onyx.indexing.models import ChunkEmbedding
+from onyx.indexing.models import IndexChunk
+from onyx.natural_language_processing.search_nlp_models import (
+    ContentClassificationPrediction,
+)
+from shared_configs.configs import (
+    INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
+)
 
 
 def create_test_document(
@@ -123,3 +135,117 @@ def test_filter_documents_multiple_documents() -> None:
 def test_filter_documents_empty_batch() -> None:
     result = filter_documents([])
     assert len(result) == 0
+
+
+# Tests for get_aggregated_boost_factor
+def create_test_chunk(
+    content: str, chunk_id: int = 0, doc_id: str = "test_doc"
+) -> IndexChunk:
+    doc = Document(
+        id=doc_id,
+        semantic_identifier="test doc",
+        sections=[],
+        source=DocumentSource.FILE,
+        metadata={},
+    )
+    return IndexChunk(
+        chunk_id=chunk_id,
+        content=content,
+        source_document=doc,
+        blurb=content[:50],  # First 50 chars as blurb
+        source_links={0: "test_link"},
+        section_continuation=False,
+        title_prefix="",
+        metadata_suffix_semantic="",
+        metadata_suffix_keyword="",
+        mini_chunk_texts=None,
+        large_chunk_id=None,
+        large_chunk_reference_ids=[],
+        embeddings=ChunkEmbedding(full_embedding=[], mini_chunk_embeddings=[]),
+        title_embedding=None,
+        image_file_name=None,
+    )
+
+
+def test_get_aggregated_boost_factor() -> None:
+    # Create test chunks - mix of short and long content
+    chunks = [
+        create_test_chunk("Short content", 0),
+        create_test_chunk(
+            "Long " * (INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH + 1), 1
+        ),
+        create_test_chunk("Another short chunk", 2),
+    ]
+
+    # Mock the classification model
+    mock_model = Mock()
+    mock_model.predict.return_value = [
+        ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.8),
+        ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.9),
+    ]
+
+    # Execute the function
+    boost_scores = _get_aggregated_boost_factor(
+        chunks=chunks, information_content_classification_model=mock_model
+    )
+
+    # Assertions
+    assert len(boost_scores) == 3
+
+    # Check that long content got default boost
+    assert boost_scores[1] == 1.0
+
+    # Check that short content got predicted boosts
+    assert boost_scores[0] == 0.8
+    assert boost_scores[2] == 0.9
+
+    # Verify model was only called once with the short chunks
+    mock_model.predict.assert_called_once()
+    assert len(mock_model.predict.call_args[0][0]) == 2
+
+
+def test_get_aggregated_boost_factor_failure() -> None:
+    chunks = [
+        create_test_chunk("Short content 1", 0),
+        create_test_chunk("Short content 2", 1),
+    ]
+
+    # Mock model to fail on batch prediction but succeed on individual predictions
+    mock_model = Mock()
+    mock_model.predict.side_effect = [
+        Exception("Batch prediction failed"),  # First call fails
+        [
+            ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.7)
+        ],  # Individual calls succeed
+        [ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.8)],
+    ]
+
+    # Execute
+    boost_scores = _get_aggregated_boost_factor(
+        chunks=chunks, information_content_classification_model=mock_model
+    )
+
+    # Assertions
+    assert len(boost_scores) == 2
+    assert boost_scores == [0.7, 0.8]
+
+
+def test_get_aggregated_boost_factor_individual_failure() -> None:
+    chunks = [
+        create_test_chunk("Short content", 0),
+        create_test_chunk("Short content", 1),
+    ]
+
+    # Mock model to fail on both batch and individual prediction
+    mock_model = Mock()
+    mock_model.predict.side_effect = Exception("Prediction failed")
+
+    # Execute and verify it raises an exception
+    with pytest.raises(Exception) as exc_info:
+        _get_aggregated_boost_factor(
+            chunks=chunks, information_content_classification_model=mock_model
+        )
+
+    assert "Failed to predict content classification for chunk" in str(exc_info.value)