Reduce ranking scores for short chunks without actual information (#4098)

* remove title for slack

* initial working code

* simplification

* improvements

* name change to information_content_model

* avoid boost_score > 1.0

* nit

* EL comments and improvements

Improvements:
  - proper import of information content model from cache or HF
  - warm up for information content model

Other:
  - EL PR review comments

* nit

* requirements version update

* fixed docker file

* new home for model_server configs

* default off

* small updates

* YS comments - pt 1

* renaming to chunk_boost & chunk table def

* saving and deleting chunk stats in new table

* saving and updating chunk stats

* improved dict score update

* create columns for individual boost factors

* RK comments

* Update migration

* manual import reordering
This commit is contained in:
joachim-danswer 2025-03-13 10:35:45 -07:00 committed by GitHub
parent ba82888e1e
commit 463340b8a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 898 additions and 34 deletions

View File

@ -31,7 +31,8 @@ RUN python -c "from transformers import AutoTokenizer; \
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from huggingface_hub import snapshot_download; \
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
snapshot_download(repo_id='onyx-dot-app/hybrid-intent-token-classifier'); \
snapshot_download(repo_id='onyx-dot-app/information-content-model'); \
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from sentence_transformers import SentenceTransformer; \

View File

@ -0,0 +1,51 @@
"""add chunk stats table
Revision ID: 3781a5eb12cb
Revises: df46c75b714e
Create Date: 2025-03-10 10:02:30.586666
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "3781a5eb12cb"
down_revision = "df46c75b714e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"chunk_stats",
sa.Column("id", sa.String(), primary_key=True, index=True),
sa.Column(
"document_id",
sa.String(),
sa.ForeignKey("document.id"),
nullable=False,
index=True,
),
sa.Column("chunk_in_doc_id", sa.Integer(), nullable=False),
sa.Column("information_content_boost", sa.Float(), nullable=True),
sa.Column(
"last_modified",
sa.DateTime(timezone=True),
nullable=False,
index=True,
server_default=sa.func.now(),
),
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True, index=True),
sa.UniqueConstraint(
"document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk"
),
)
op.create_index(
"ix_chunk_sync_status", "chunk_stats", ["last_modified", "last_synced"]
)
def downgrade() -> None:
op.drop_index("ix_chunk_sync_status", table_name="chunk_stats")
op.drop_table("chunk_stats")

View File

@ -3,6 +3,7 @@ from shared_configs.enums import EmbedTextType
MODEL_WARM_UP_STRING = "hi " * 512
INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"

View File

@ -1,11 +1,14 @@
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import APIRouter
from huggingface_hub import snapshot_download # type: ignore
from setfit import SetFitModel # type: ignore[import]
from transformers import AutoTokenizer # type: ignore
from transformers import BatchEncoding # type: ignore
from transformers import PreTrainedTokenizer # type: ignore
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.onyx_torch_model import ConnectorClassifier
from model_server.onyx_torch_model import HybridClassifier
@ -13,11 +16,22 @@ from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE,
)
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import INFORMATION_CONTENT_MODEL_TAG
from shared_configs.configs import INFORMATION_CONTENT_MODEL_VERSION
from shared_configs.configs import INTENT_MODEL_TAG
from shared_configs.configs import INTENT_MODEL_VERSION
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import ContentClassificationPrediction
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
@ -31,6 +45,10 @@ _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
_INTENT_TOKENIZER: AutoTokenizer | None = None
_INTENT_MODEL: HybridClassifier | None = None
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
def get_connector_classifier_tokenizer() -> AutoTokenizer:
global _CONNECTOR_CLASSIFIER_TOKENIZER
@ -85,7 +103,7 @@ def get_intent_model_tokenizer() -> AutoTokenizer:
def get_local_intent_model(
model_name_or_path: str = INTENT_MODEL_VERSION,
tag: str = INTENT_MODEL_TAG,
tag: str | None = INTENT_MODEL_TAG,
) -> HybridClassifier:
global _INTENT_MODEL
if _INTENT_MODEL is None:
@ -102,7 +120,9 @@ def get_local_intent_model(
try:
# Attempt to download the model snapshot
logger.notice(f"Downloading model snapshot for {model_name_or_path}")
local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=False
)
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
except Exception as e:
logger.error(
@ -112,6 +132,44 @@ def get_local_intent_model(
return _INTENT_MODEL
def get_local_information_content_model(
model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
) -> SetFitModel:
global _INFORMATION_CONTENT_MODEL
if _INFORMATION_CONTENT_MODEL is None:
try:
# Calculate where the cache should be, then load from local if available
logger.notice(
f"Loading content information model from local cache: {model_name_or_path}"
)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=True
)
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
logger.notice(
f"Loaded content information model from local cache: {local_path}"
)
except Exception as e:
logger.warning(f"Failed to load content information model directly: {e}")
try:
# Attempt to download the model snapshot
logger.notice(
f"Downloading content information model snapshot for {model_name_or_path}"
)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=False
)
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
except Exception as e:
logger.error(
f"Failed to load content information model even after attempted snapshot download: {e}"
)
raise
return _INFORMATION_CONTENT_MODEL
def tokenize_connector_classification_query(
connectors: list[str],
query: str,
@ -195,6 +253,13 @@ def warm_up_intent_model() -> None:
)
def warm_up_information_content_model() -> None:
logger.notice("Warming up Content Model") # TODO: add version if needed
information_content_model = get_local_information_content_model()
information_content_model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING)
@simple_log_function_time()
def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
intent_model = get_local_intent_model()
@ -218,6 +283,117 @@ def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
return intent_probabilities.tolist(), token_positive_probs
@simple_log_function_time()
def run_content_classification_inference(
text_inputs: list[str],
) -> list[ContentClassificationPrediction]:
"""
Assign a score to the segments in question. The model stored in get_local_information_content_model()
creates the 'model score' based on its training, and the scores are then converted to a 0.0-1.0 scale.
In the code outside of the model/inference model servers that score will be converted into the actual
boost factor.
"""
def _prob_to_score(prob: float) -> float:
"""
Conversion of base score to 0.0 - 1.0 score. Note that the min/max values depend on the model!
"""
_MIN_BASE_SCORE = 0.25
_MAX_BASE_SCORE = 0.75
if prob < _MIN_BASE_SCORE:
raw_score = 0.0
elif prob < _MAX_BASE_SCORE:
raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
else:
raw_score = 1.0
return (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
+ (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
- INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
)
* raw_score
)
_BATCH_SIZE = 32
content_model = get_local_information_content_model()
# Process inputs in batches
all_output_classes: list[int] = []
all_base_output_probabilities: list[float] = []
for i in range(0, len(text_inputs), _BATCH_SIZE):
batch = text_inputs[i : i + _BATCH_SIZE]
batch_with_prefix = []
batch_indices = []
# Pre-allocate results for this batch
batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch)
batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch)
# Pre-process batch to handle long input exceptions
for j, text in enumerate(batch):
if len(text) == 0:
# if no input, treat as non-informative from the model's perspective
batch_output_classes[j] = np.array(0)
batch_probabilities[j] = np.array(0.0)
logger.warning("Input for Content Information Model is empty")
elif (
len(text.split())
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
):
# if input is short, use the model
batch_with_prefix.append(
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text
)
batch_indices.append(j)
else:
# if longer than cutoff, treat as informative (stay with default), but issue warning
logger.warning("Input for Content Information Model too long")
if batch_with_prefix: # Only run model if we have valid inputs
# Get predictions for the batch
model_output_classes = content_model(batch_with_prefix)
model_output_probabilities = content_model.predict_proba(batch_with_prefix)
# Place results in the correct positions
for idx, batch_idx in enumerate(batch_indices):
batch_output_classes[batch_idx] = model_output_classes[idx].numpy()
batch_probabilities[batch_idx] = model_output_probabilities[idx][
1
].numpy() # x[1] is prob of the positive class
all_output_classes.extend([int(x) for x in batch_output_classes])
all_base_output_probabilities.extend([float(x) for x in batch_probabilities])
logits = [
np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100)
for p in all_base_output_probabilities
]
scaled_logits = [
logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE
for logit in logits
]
output_probabilities_with_temp = [
np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
for scaled_logit in scaled_logits
]
prediction_scores = [
_prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
]
content_classification_predictions = [
ContentClassificationPrediction(
predicted_label=predicted_label, content_boost_factor=output_score
)
for predicted_label, output_score in zip(all_output_classes, prediction_scores)
]
return content_classification_predictions
def map_keywords(
input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
) -> list[str]:
@ -362,3 +538,10 @@ async def process_analysis_request(
is_keyword, keywords = run_analysis(intent_request)
return IntentResponse(is_keyword=is_keyword, keywords=keywords)
@router.post("/content-classification")
async def process_content_classification_request(
content_classification_requests: list[str],
) -> list[ContentClassificationPrediction]:
return run_content_classification_inference(content_classification_requests)

View File

@ -13,6 +13,7 @@ from sentry_sdk.integrations.starlette import StarletteIntegration
from transformers import logging as transformer_logging # type:ignore
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_information_content_model
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.management_endpoints import router as management_router
@ -74,9 +75,15 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
if not INDEXING_ONLY:
logger.notice(
"The intent model should run on the model server. The information content model should not run here."
)
warm_up_intent_model()
else:
logger.notice("This model server should only run document indexing.")
logger.notice(
"The content information model should run on the indexing model server. The intent model should not run here."
)
warm_up_information_content_model()
yield

View File

@ -563,6 +563,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
# aggregated_boost_factor=doc.aggregated_boost_factor,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.

View File

@ -53,6 +53,9 @@ from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
@ -348,6 +351,8 @@ def _run_indexing(
callback=callback,
)
information_content_classification_model = InformationContentClassificationModel()
document_index = get_default_document_index(
index_attempt_start.search_settings,
None,
@ -356,6 +361,7 @@ def _run_indexing(
indexing_pipeline = build_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=(
ctx.from_beginning

View File

@ -132,3 +132,10 @@ if _LITELLM_EXTRA_BODY_RAW:
LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
except Exception:
pass
# Whether and how to lower scores for short chunks w/o relevant context
# Evaluated via custom ML model
USE_INFORMATION_CONTENT_CLASSIFICATION = (
os.environ.get("USE_INFORMATION_CONTENT_CLASSIFICATION", "false").lower() == "true"
)

View File

@ -220,7 +220,6 @@ def thread_to_doc(
source=DocumentSource.SLACK,
semantic_identifier=doc_sem_id,
doc_updated_at=get_latest_message_time(thread),
title="", # slack docs don't really have a "title"
primary_owners=valid_experts,
metadata={"Channel": channel["name"]},
)

63
backend/onyx/db/chunk.py Normal file
View File

@ -0,0 +1,63 @@
from datetime import datetime
from datetime import timezone
from sqlalchemy import delete
from sqlalchemy.orm import Session
from onyx.db.models import ChunkStats
from onyx.indexing.models import UpdatableChunkData
def update_chunk_boost_components__no_commit(
chunk_data: list[UpdatableChunkData],
db_session: Session,
) -> None:
"""Updates the chunk_boost_components for chunks in the database.
Args:
chunk_data: List of dicts containing chunk_id, document_id, and boost_score
db_session: SQLAlchemy database session
"""
if not chunk_data:
return
for data in chunk_data:
chunk_in_doc_id = int(data.chunk_id)
if chunk_in_doc_id < 0:
raise ValueError(f"Chunk ID is empty for chunk {data}")
chunk_document_id = f"{data.document_id}" f"__{chunk_in_doc_id}"
chunk_stats = (
db_session.query(ChunkStats)
.filter(
ChunkStats.id == chunk_document_id,
)
.first()
)
score = data.boost_score
if chunk_stats:
chunk_stats.information_content_boost = score
chunk_stats.last_modified = datetime.now(timezone.utc)
db_session.add(chunk_stats)
else:
# do not save new chunks with a neutral boost score
if score == 1.0:
continue
# Create new record
chunk_stats = ChunkStats(
document_id=data.document_id,
chunk_in_doc_id=chunk_in_doc_id,
information_content_boost=score,
)
db_session.add(chunk_stats)
def delete_chunk_stats_by_connector_credential_pair__no_commit(
db_session: Session, document_ids: list[str]
) -> None:
"""This deletes just chunk stats in postgres."""
stmt = delete(ChunkStats).where(ChunkStats.document_id.in_(document_ids))
db_session.execute(stmt)

View File

@ -23,6 +23,7 @@ from sqlalchemy.sql.expression import null
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
@ -562,6 +563,18 @@ def delete_documents_complete__no_commit(
db_session: Session, document_ids: list[str]
) -> None:
"""This completely deletes the documents from the db, including all foreign key relationships"""
# Start by deleting the chunk stats for the documents
delete_chunk_stats_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids,
)
delete_chunk_stats_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids,
)
delete_documents_by_connector_credential_pair__no_commit(db_session, document_ids)
delete_document_feedback_for_documents__no_commit(
document_ids=document_ids, db_session=db_session

View File

@ -591,6 +591,55 @@ class Document(Base):
)
class ChunkStats(Base):
__tablename__ = "chunk_stats"
# NOTE: if more sensitive data is added here for display, make sure to add user/group permission
# this should correspond to the ID of the document
# (as is passed around in Onyx)
id: Mapped[str] = mapped_column(
NullFilteredString,
primary_key=True,
default=lambda context: (
f"{context.get_current_parameters()['document_id']}"
f"__{context.get_current_parameters()['chunk_in_doc_id']}"
),
index=True,
)
# Reference to parent document
document_id: Mapped[str] = mapped_column(
NullFilteredString, ForeignKey("document.id"), nullable=False, index=True
)
chunk_in_doc_id: Mapped[int] = mapped_column(
Integer,
nullable=False,
)
information_content_boost: Mapped[float | None] = mapped_column(
Float, nullable=True
)
last_modified: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=False, index=True, default=func.now()
)
last_synced: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True, index=True
)
__table_args__ = (
Index(
"ix_chunk_sync_status",
last_modified,
last_synced,
),
UniqueConstraint(
"document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk"
),
)
class Tag(Base):
__tablename__ = "tag"

View File

@ -101,6 +101,7 @@ class VespaDocumentFields:
document_sets: set[str] | None = None
boost: float | None = None
hidden: bool | None = None
aggregated_chunk_boost_factor: float | None = None
@dataclass

View File

@ -80,6 +80,11 @@ schema DANSWER_CHUNK_NAME {
indexing: summary | attribute
rank: filter
}
# Field to indicate whether a short chunk is a low content chunk
field aggregated_chunk_boost_factor type float {
indexing: attribute
}
# Needs to have a separate Attribute list for efficient filtering
field metadata_list type array<string> {
indexing: summary | attribute
@ -142,6 +147,11 @@ schema DANSWER_CHUNK_NAME {
expression: max(if(isNan(attribute(doc_updated_at)) == 1, 7890000, now() - attribute(doc_updated_at)) / 31536000, 0)
}
function inline aggregated_chunk_boost() {
# Aggregated boost factor, currently only used for information content classification
expression: if(isNan(attribute(aggregated_chunk_boost_factor)) == 1, 1.0, attribute(aggregated_chunk_boost_factor))
}
# Document score decays from 1 to 0.75 as age of last updated time increases
function inline recency_bias() {
expression: max(1 / (1 + query(decay_factor) * document_age), 0.75)
@ -199,6 +209,8 @@ schema DANSWER_CHUNK_NAME {
* document_boost
# Decay factor based on time document was last updated
* recency_bias
# Boost based on aggregated boost calculation
* aggregated_chunk_boost
}
rerank-count: 1000
}
@ -210,6 +222,7 @@ schema DANSWER_CHUNK_NAME {
closeness(field, embeddings)
document_boost
recency_bias
aggregated_chunk_boost
closest(embeddings)
}
}

View File

@ -22,6 +22,7 @@ from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import AGGREGATED_CHUNK_BOOST_FACTOR
from onyx.document_index.vespa_constants import BLURB
from onyx.document_index.vespa_constants import BOOST
from onyx.document_index.vespa_constants import CHUNK_ID
@ -201,6 +202,7 @@ def _index_vespa_chunk(
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
IMAGE_FILE_NAME: chunk.image_file_name,
BOOST: chunk.boost,
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
}
if multitenant:

View File

@ -72,6 +72,7 @@ METADATA = "metadata"
METADATA_LIST = "metadata_list"
METADATA_SUFFIX = "metadata_suffix"
BOOST = "boost"
AGGREGATED_CHUNK_BOOST_FACTOR = "aggregated_chunk_boost_factor"
DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
PRIMARY_OWNERS = "primary_owners"
SECONDARY_OWNERS = "secondary_owners"
@ -97,6 +98,7 @@ YQL_BASE = (
f"{SECTION_CONTINUATION}, "
f"{IMAGE_FILE_NAME}, "
f"{BOOST}, "
f"{AGGREGATED_CHUNK_BOOST_FACTOR}, "
f"{HIDDEN}, "
f"{DOC_UPDATED_AT}, "
f"{PRIMARY_OWNERS}, "

View File

@ -11,6 +11,7 @@ from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.configs.model_configs import USE_INFORMATION_CONTENT_CLASSIFICATION
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
)
@ -22,6 +23,7 @@ from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.db.chunk import update_chunk_boost_components__no_commit
from onyx.db.document import fetch_chunk_counts_for_documents
from onyx.db.document import get_documents_by_ids
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
@ -52,10 +54,19 @@ from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.llm.factory import get_default_llm_with_vision
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)
logger = setup_logger()
@ -136,6 +147,72 @@ def _upsert_documents_in_db(
)
def _get_aggregated_chunk_boost_factor(
chunks: list[IndexChunk],
information_content_classification_model: InformationContentClassificationModel,
) -> list[float]:
"""Calculates the aggregated boost factor for a chunk based on its content."""
short_chunk_content_dict = {
chunk_num: chunk.content
for chunk_num, chunk in enumerate(chunks)
if len(chunk.content.split())
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
}
short_chunk_contents = list(short_chunk_content_dict.values())
short_chunk_keys = list(short_chunk_content_dict.keys())
try:
predictions = information_content_classification_model.predict(
short_chunk_contents
)
# Create a mapping of chunk positions to their scores
score_map = {
short_chunk_keys[i]: prediction.content_boost_factor
for i, prediction in enumerate(predictions)
}
# Default to 1.0 for longer chunks, use predicted score for short chunks
chunk_content_scores = [score_map.get(i, 1.0) for i in range(len(chunks))]
return chunk_content_scores
except Exception as e:
logger.exception(
f"Error predicting content classification for chunks: {e}. Falling back to individual examples."
)
chunks_with_scores: list[IndexChunk] = []
chunk_content_scores = []
for chunk in chunks:
if (
len(chunk.content.split())
> INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
):
chunk_content_scores.append(1.0)
chunks_with_scores.append(chunk)
continue
try:
chunk_content_scores.append(
information_content_classification_model.predict([chunk.content])[
0
].content_boost_factor
)
chunks_with_scores.append(chunk)
except Exception as e:
logger.exception(
f"Error predicting content classification for chunk: {e}."
)
raise Exception(
f"Failed to predict content classification for chunk {chunk.chunk_id} "
f"from document {chunk.source_document.id}"
) from e
return chunk_content_scores
def get_doc_ids_to_update(
documents: list[Document], db_docs: list[DBDocument]
) -> list[Document]:
@ -165,6 +242,7 @@ def index_doc_batch_with_handler(
*,
chunker: Chunker,
embedder: IndexingEmbedder,
information_content_classification_model: InformationContentClassificationModel,
document_index: DocumentIndex,
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
@ -176,6 +254,7 @@ def index_doc_batch_with_handler(
index_pipeline_result = index_doc_batch(
chunker=chunker,
embedder=embedder,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
document_batch=document_batch,
index_attempt_metadata=index_attempt_metadata,
@ -450,6 +529,7 @@ def index_doc_batch(
document_batch: list[Document],
chunker: Chunker,
embedder: IndexingEmbedder,
information_content_classification_model: InformationContentClassificationModel,
document_index: DocumentIndex,
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
@ -526,7 +606,23 @@ def index_doc_batch(
else ([], [])
)
chunk_content_scores = (
_get_aggregated_chunk_boost_factor(
chunks_with_embeddings, information_content_classification_model
)
if USE_INFORMATION_CONTENT_CLASSIFICATION
else [1.0] * len(chunks_with_embeddings)
)
updatable_ids = [doc.id for doc in ctx.updatable_docs]
updatable_chunk_data = [
UpdatableChunkData(
chunk_id=chunk.chunk_id,
document_id=chunk.source_document.id,
boost_score=score,
)
for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
]
# Acquires a lock on the documents so that no other process can modify them
# NOTE: don't need to acquire till here, since this is when the actual race condition
@ -579,8 +675,9 @@ def index_doc_batch(
else DEFAULT_BOOST
),
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
)
for chunk in chunks_with_embeddings
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
logger.debug(
@ -665,6 +762,11 @@ def index_doc_batch(
db_session=db_session,
)
# save the chunk boost components to postgres
update_chunk_boost_components__no_commit(
chunk_data=updatable_chunk_data, db_session=db_session
)
db_session.commit()
result = IndexingPipelineResult(
@ -680,6 +782,7 @@ def index_doc_batch(
def build_indexing_pipeline(
*,
embedder: IndexingEmbedder,
information_content_classification_model: InformationContentClassificationModel,
document_index: DocumentIndex,
db_session: Session,
tenant_id: str,
@ -703,6 +806,7 @@ def build_indexing_pipeline(
index_doc_batch_with_handler,
chunker=chunker,
embedder=embedder,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=ignore_time_skip,
db_session=db_session,

View File

@ -83,13 +83,16 @@ class DocMetadataAwareIndexChunk(IndexChunk):
document_sets: all document sets the source document for this chunk is a part
of. This is used for filtering / personas.
boost: influences the ranking of this chunk at query time. Positive -> ranked higher,
negative -> ranked lower.
negative -> ranked lower. Not included in aggregated boost calculation
for legacy reasons.
aggregated_chunk_boost_factor: represents the aggregated chunk-level boost (currently: information content)
"""
tenant_id: str
access: "DocumentAccess"
document_sets: set[str]
boost: int
aggregated_chunk_boost_factor: float
@classmethod
def from_index_chunk(
@ -98,6 +101,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access: "DocumentAccess",
document_sets: set[str],
boost: int,
aggregated_chunk_boost_factor: float,
tenant_id: str,
) -> "DocMetadataAwareIndexChunk":
index_chunk_data = index_chunk.model_dump()
@ -106,6 +110,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access=access,
document_sets=document_sets,
boost=boost,
aggregated_chunk_boost_factor=aggregated_chunk_boost_factor,
tenant_id=tenant_id,
)
@ -179,3 +184,9 @@ class IndexingSetting(EmbeddingModelDetail):
class MultipassConfig(BaseModel):
multipass_indexing: bool
enable_large_chunks: bool
class UpdatableChunkData(BaseModel):
chunk_id: int
document_id: str
boost_score: float

View File

@ -29,6 +29,8 @@ from onyx.natural_language_processing.exceptions import (
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content
from onyx.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
@ -36,9 +38,11 @@ from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import ContentClassificationPrediction
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import InformationContentClassificationResponses
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
from shared_configs.model_server_models import RerankRequest
@ -377,6 +381,31 @@ class QueryAnalysisModel:
return response_model.is_keyword, response_model.keywords
class InformationContentClassificationModel:
def __init__(
self,
model_server_host: str = INDEXING_MODEL_SERVER_HOST,
model_server_port: int = INDEXING_MODEL_SERVER_PORT,
) -> None:
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.content_server_endpoint = (
model_server_url + "/custom/content-classification"
)
def predict(
self,
queries: list[str],
) -> list[ContentClassificationPrediction]:
response = requests.post(self.content_server_endpoint, json=queries)
response.raise_for_status()
model_responses = InformationContentClassificationResponses(
information_content_classifications=response.json()
)
return model_responses.information_content_classifications
class ConnectorClassificationModel:
def __init__(
self,

View File

@ -98,6 +98,7 @@ def _create_indexable_chunks(
boost=DEFAULT_BOOST,
large_chunk_id=None,
image_file_name=None,
aggregated_chunk_boost_factor=1.0,
)
chunks.append(chunk)

View File

@ -19,6 +19,9 @@ from onyx.db.search_settings import get_secondary_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.server.onyx_api.models import DocMinimalInfo
from onyx.server.onyx_api.models import IngestionDocument
from onyx.server.onyx_api.models import IngestionResult
@ -102,8 +105,11 @@ def upsert_ingestion_doc(
search_settings=search_settings
)
information_content_classification_model = InformationContentClassificationModel()
indexing_pipeline = build_indexing_pipeline(
embedder=index_embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=curr_doc_index,
ignore_time_skip=True,
db_session=db_session,
@ -138,6 +144,7 @@ def upsert_ingestion_doc(
sec_ind_pipeline = build_indexing_pipeline(
embedder=new_index_embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=sec_doc_index,
ignore_time_skip=True,
db_session=db_session,

View File

@ -25,7 +25,7 @@ google-auth-oauthlib==1.0.0
httpcore==1.0.5
httpx[http2]==0.27.0
httpx-oauth==0.15.1
huggingface-hub==0.20.1
huggingface-hub==0.29.0
inflection==0.5.1
jira==3.5.1
jsonref==1.1.0
@ -71,6 +71,7 @@ requests==2.32.2
requests-oauthlib==1.3.1
retry==0.9.2 # This pulls in py which is in CVE-2022-42969, must remove py from image
rfc3986==1.5.0
setfit==1.1.1
simple-salesforce==1.12.6
slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.15
@ -78,7 +79,7 @@ starlette==0.36.3
supervisor==4.2.5
tiktoken==0.7.0
timeago==1.0.16
transformers==4.39.2
transformers==4.49.0
unstructured==0.15.1
unstructured-client==0.25.4
uvicorn==0.21.1

View File

@ -14,7 +14,7 @@ pytest-asyncio==0.22.0
pytest==7.4.4
reorder-python-imports==3.9.0
ruff==0.0.286
sentence-transformers==2.6.1
sentence-transformers==3.4.1
trafilatura==1.12.2
types-beautifulsoup4==4.12.0.3
types-html5lib==1.1.11.13

View File

@ -7,9 +7,10 @@ openai==1.61.0
pydantic==2.8.2
retry==0.9.2
safetensors==0.4.2
sentence-transformers==2.6.1
sentence-transformers==3.4.1
setfit==1.1.1
torch==2.2.0
transformers==4.39.2
transformers==4.49.0
uvicorn==0.21.1
voyageai==0.2.3
litellm==1.61.16

View File

@ -161,17 +161,21 @@ overview_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/overview",
title=overview_title,
content=overview,
title_embedding=model.encode(f"search_document: {overview_title}"),
content_embedding=model.encode(f"search_document: {overview_title}\n{overview}"),
title_embedding=list(model.encode(f"search_document: {overview_title}")),
content_embedding=list(
model.encode(f"search_document: {overview_title}\n{overview}")
),
)
enterprise_search_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/enterprise_search",
title=enterprise_search_title,
content=enterprise_search_1,
title_embedding=model.encode(f"search_document: {enterprise_search_title}"),
content_embedding=model.encode(
f"search_document: {enterprise_search_title}\n{enterprise_search_1}"
title_embedding=list(model.encode(f"search_document: {enterprise_search_title}")),
content_embedding=list(
model.encode(
f"search_document: {enterprise_search_title}\n{enterprise_search_1}"
)
),
)
@ -179,9 +183,11 @@ enterprise_search_doc_2 = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/enterprise_search",
title=enterprise_search_title,
content=enterprise_search_2,
title_embedding=model.encode(f"search_document: {enterprise_search_title}"),
content_embedding=model.encode(
f"search_document: {enterprise_search_title}\n{enterprise_search_2}"
title_embedding=list(model.encode(f"search_document: {enterprise_search_title}")),
content_embedding=list(
model.encode(
f"search_document: {enterprise_search_title}\n{enterprise_search_2}"
)
),
chunk_ind=1,
)
@ -190,9 +196,9 @@ ai_platform_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/ai_platform",
title=ai_platform_title,
content=ai_platform,
title_embedding=model.encode(f"search_document: {ai_platform_title}"),
content_embedding=model.encode(
f"search_document: {ai_platform_title}\n{ai_platform}"
title_embedding=list(model.encode(f"search_document: {ai_platform_title}")),
content_embedding=list(
model.encode(f"search_document: {ai_platform_title}\n{ai_platform}")
),
)
@ -200,9 +206,9 @@ customer_support_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/support",
title=customer_support_title,
content=customer_support,
title_embedding=model.encode(f"search_document: {customer_support_title}"),
content_embedding=model.encode(
f"search_document: {customer_support_title}\n{customer_support}"
title_embedding=list(model.encode(f"search_document: {customer_support_title}")),
content_embedding=list(
model.encode(f"search_document: {customer_support_title}\n{customer_support}")
),
)
@ -210,17 +216,17 @@ sales_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/sales",
title=sales_title,
content=sales,
title_embedding=model.encode(f"search_document: {sales_title}"),
content_embedding=model.encode(f"search_document: {sales_title}\n{sales}"),
title_embedding=list(model.encode(f"search_document: {sales_title}")),
content_embedding=list(model.encode(f"search_document: {sales_title}\n{sales}")),
)
operations_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/operations",
title=operations_title,
content=operations,
title_embedding=model.encode(f"search_document: {operations_title}"),
content_embedding=model.encode(
f"search_document: {operations_title}\n{operations}"
title_embedding=list(model.encode(f"search_document: {operations_title}")),
content_embedding=list(
model.encode(f"search_document: {operations_title}\n{operations}")
),
)

View File

@ -99,6 +99,7 @@ def generate_dummy_chunk(
),
document_sets={document_set for document_set in document_set_names},
boost=random.randint(-1, 1),
aggregated_chunk_boost_factor=random.random(),
tenant_id=POSTGRES_DEFAULT_SCHEMA,
)

View File

@ -23,9 +23,11 @@ INDEXING_MODEL_SERVER_PORT = int(
# Onyx custom Deep Learning Models
CONNECTOR_CLASSIFIER_MODEL_REPO = "Danswer/filter-extraction-model"
CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
INTENT_MODEL_TAG = "v1.0.3"
INTENT_MODEL_VERSION = "onyx-dot-app/hybrid-intent-token-classifier"
# INTENT_MODEL_TAG = "v1.0.3"
INTENT_MODEL_TAG: str | None = None
INFORMATION_CONTENT_MODEL_VERSION = "onyx-dot-app/information-content-model"
INFORMATION_CONTENT_MODEL_TAG: str | None = None
# Bi-Encoder, other details
DOC_EMBEDDING_CONTEXT_SIZE = 512
@ -277,3 +279,20 @@ SUPPORTED_EMBEDDING_MODELS = [
index_name="danswer_chunk_intfloat_multilingual_e5_small",
),
]
# Maximum (least severe) downgrade factor for chunks above the cutoff
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX = float(
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX") or 1.0
)
# Minimum (most severe) downgrade factor for short chunks below the cutoff if no content
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN = float(
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN") or 0.7
)
# Temperature for the information content classification model
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE = float(
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE") or 4.0
)
# Cutoff below which we start using the information content classification model
# (cutoff length number itself is still considered 'short'))
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH = int(
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH") or 10
)

View File

@ -4,6 +4,7 @@ from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
Embedding = list[float]
@ -73,7 +74,20 @@ class IntentResponse(BaseModel):
keywords: list[str]
class InformationContentClassificationRequests(BaseModel):
queries: list[str]
class SupportedEmbeddingModel(BaseModel):
name: str
dim: int
index_name: str
class ContentClassificationPrediction(BaseModel):
predicted_label: int
content_boost_factor: float
class InformationContentClassificationResponses(BaseModel):
information_content_classifications: list[ContentClassificationPrediction]

View File

@ -0,0 +1,145 @@
from typing import Any
from unittest.mock import Mock
from unittest.mock import patch
import numpy as np
import numpy.typing as npt
import pytest
from model_server.custom_models import run_content_classification_inference
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
from shared_configs.model_server_models import ContentClassificationPrediction
@pytest.fixture
def mock_content_model() -> Mock:
model = Mock()
# Create actual numpy arrays for the mock returns
predict_output = np.array(
[1, 0] * 50, dtype=np.int64
) # Pre-allocate enough elements
proba_output = np.array(
[[0.3, 0.7], [0.7, 0.3]] * 50, dtype=np.float64
) # Pre-allocate enough elements
# Create a mock tensor that has a numpy method and supports indexing
class MockTensor:
def __init__(self, value: npt.NDArray[Any]) -> None:
self.value = value
def numpy(self) -> npt.NDArray[Any]:
return self.value
def __getitem__(self, idx: Any) -> Any:
result = self.value[idx]
# Wrap scalar values back in MockTensor
if isinstance(result, (np.float64, np.int64)):
return MockTensor(np.array([result]))
return MockTensor(result)
# Mock the direct call to return a MockTensor for each input
def model_call(inputs: list[str]) -> list[MockTensor]:
batch_size = len(inputs)
return [MockTensor(predict_output[i : i + 1]) for i in range(batch_size)]
model.side_effect = model_call
# Mock predict_proba to return MockTensor-wrapped numpy array
def predict_proba_call(x: list[str]) -> MockTensor:
batch_size = len(x)
return MockTensor(proba_output[:batch_size])
model.predict_proba.side_effect = predict_proba_call
return model
@patch("model_server.custom_models.get_local_information_content_model")
def test_run_content_classification_inference(
mock_get_model: Mock,
mock_content_model: Mock,
) -> None:
"""
Test the content classification inference function.
Verifies that the function correctly processes text inputs and returns appropriate predictions.
"""
# Setup
mock_get_model.return_value = mock_content_model
test_inputs = [
"Imagine a short text with content",
"Imagine a short text without content",
"x "
* (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH + 1
), # Long input that exceeds maximal length for when the model should be applied
"", # Empty input
]
# Execute
results = run_content_classification_inference(test_inputs)
# Assert
assert len(results) == len(test_inputs)
assert all(isinstance(r, ContentClassificationPrediction) for r in results)
# Check each prediction has expected attributes and ranges
for result_num, result in enumerate(results):
assert hasattr(result, "predicted_label")
assert hasattr(result, "content_boost_factor")
assert isinstance(result.predicted_label, int)
assert isinstance(result.content_boost_factor, float)
assert (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
<= result.content_boost_factor
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
)
if result_num == 2:
assert (
result.content_boost_factor
== INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
)
assert result.predicted_label == 1
elif result_num == 3:
assert (
result.content_boost_factor
== INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
)
assert result.predicted_label == 0
# Verify model handling of long inputs
mock_content_model.predict_proba.reset_mock()
long_input = ["x " * 1000] # Definitely exceeds MAX_LENGTH
results = run_content_classification_inference(long_input)
assert len(results) == 1
assert (
mock_content_model.predict_proba.call_count == 0
) # Should skip model call for too-long input
@patch("model_server.custom_models.get_local_information_content_model")
def test_batch_processing(
mock_get_model: Mock,
mock_content_model: Mock,
) -> None:
"""
Test that the function correctly handles batch processing of inputs.
"""
# Setup
mock_get_model.return_value = mock_content_model
# Create test input larger than batch size
test_inputs = [f"Test input {i}" for i in range(40)] # > BATCH_SIZE (32)
# Execute
results = run_content_classification_inference(test_inputs)
# Assert
assert len(results) == 40
# Verify batching occurred (should have called predict_proba twice)
assert mock_content_model.predict_proba.call_count == 2

View File

@ -1,12 +1,24 @@
from typing import cast
from typing import List
from unittest.mock import Mock
import pytest
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentSource
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.indexing.indexing_pipeline import _get_aggregated_chunk_boost_factor
from onyx.indexing.indexing_pipeline import filter_documents
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import IndexChunk
from onyx.natural_language_processing.search_nlp_models import (
ContentClassificationPrediction,
)
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)
def create_test_document(
@ -123,3 +135,117 @@ def test_filter_documents_multiple_documents() -> None:
def test_filter_documents_empty_batch() -> None:
result = filter_documents([])
assert len(result) == 0
# Tests for get_aggregated_boost_factor
def create_test_chunk(
content: str, chunk_id: int = 0, doc_id: str = "test_doc"
) -> IndexChunk:
doc = Document(
id=doc_id,
semantic_identifier="test doc",
sections=[],
source=DocumentSource.FILE,
metadata={},
)
return IndexChunk(
chunk_id=chunk_id,
content=content,
source_document=doc,
blurb=content[:50], # First 50 chars as blurb
source_links={0: "test_link"},
section_continuation=False,
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_id=None,
large_chunk_reference_ids=[],
embeddings=ChunkEmbedding(full_embedding=[], mini_chunk_embeddings=[]),
title_embedding=None,
image_file_name=None,
)
def test_get_aggregated_boost_factor() -> None:
# Create test chunks - mix of short and long content
chunks = [
create_test_chunk("Short content", 0),
create_test_chunk(
"Long " * (INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH + 1), 1
),
create_test_chunk("Another short chunk", 2),
]
# Mock the classification model
mock_model = Mock()
mock_model.predict.return_value = [
ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.8),
ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.9),
]
# Execute the function
boost_scores = _get_aggregated_chunk_boost_factor(
chunks=chunks, information_content_classification_model=mock_model
)
# Assertions
assert len(boost_scores) == 3
# Check that long content got default boost
assert boost_scores[1] == 1.0
# Check that short content got predicted boosts
assert boost_scores[0] == 0.8
assert boost_scores[2] == 0.9
# Verify model was only called once with the short chunks
mock_model.predict.assert_called_once()
assert len(mock_model.predict.call_args[0][0]) == 2
def test_get_aggregated_boost_factorilure() -> None:
chunks = [
create_test_chunk("Short content 1", 0),
create_test_chunk("Short content 2", 1),
]
# Mock model to fail on batch prediction but succeed on individual predictions
mock_model = Mock()
mock_model.predict.side_effect = [
Exception("Batch prediction failed"), # First call fails
[
ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.7)
], # Individual calls succeed
[ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.8)],
]
# Execute
boost_scores = _get_aggregated_chunk_boost_factor(
chunks=chunks, information_content_classification_model=mock_model
)
# Assertions
assert len(boost_scores) == 2
assert boost_scores == [0.7, 0.8]
def test_get_aggregated_boost_factor_individual_failure() -> None:
chunks = [
create_test_chunk("Short content", 0),
create_test_chunk("Short content", 1),
]
# Mock model to fail on both batch and individual prediction
mock_model = Mock()
mock_model.predict.side_effect = Exception("Prediction failed")
# Execute and verify it raises an exception
with pytest.raises(Exception) as exc_info:
_get_aggregated_chunk_boost_factor(
chunks=chunks, information_content_classification_model=mock_model
)
assert "Failed to predict content classification for chunk" in str(exc_info.value)