mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-17 13:22:42 +01:00
Reduce ranking scores for short chunks without actual information (#4098)
* remove title for slack * initial working code * simplification * improvements * name change to information_content_model * avoid boost_score > 1.0 * nit * EL comments and improvements Improvements: - proper import of information content model from cache or HF - warm up for information content model Other: - EL PR review comments * nit * requirements version update * fixed docker file * new home for model_server configs * default off * small updates * YS comments - pt 1 * renaming to chunk_boost & chunk table def * saving and deleting chunk stats in new table * saving and updating chunk stats * improved dict score update * create columns for individual boost factors * RK comments * Update migration * manual import reordering
This commit is contained in:
parent
ba82888e1e
commit
463340b8a1
@ -31,7 +31,8 @@ RUN python -c "from transformers import AutoTokenizer; \
|
||||
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
|
||||
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
from huggingface_hub import snapshot_download; \
|
||||
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
|
||||
snapshot_download(repo_id='onyx-dot-app/hybrid-intent-token-classifier'); \
|
||||
snapshot_download(repo_id='onyx-dot-app/information-content-model'); \
|
||||
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
|
||||
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
from sentence_transformers import SentenceTransformer; \
|
||||
|
@ -0,0 +1,51 @@
|
||||
"""add chunk stats table
|
||||
|
||||
Revision ID: 3781a5eb12cb
|
||||
Revises: df46c75b714e
|
||||
Create Date: 2025-03-10 10:02:30.586666
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3781a5eb12cb"
|
||||
down_revision = "df46c75b714e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"chunk_stats",
|
||||
sa.Column("id", sa.String(), primary_key=True, index=True),
|
||||
sa.Column(
|
||||
"document_id",
|
||||
sa.String(),
|
||||
sa.ForeignKey("document.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
),
|
||||
sa.Column("chunk_in_doc_id", sa.Integer(), nullable=False),
|
||||
sa.Column("information_content_boost", sa.Float(), nullable=True),
|
||||
sa.Column(
|
||||
"last_modified",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
index=True,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True, index=True),
|
||||
sa.UniqueConstraint(
|
||||
"document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk"
|
||||
),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_chunk_sync_status", "chunk_stats", ["last_modified", "last_synced"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_chunk_sync_status", table_name="chunk_stats")
|
||||
op.drop_table("chunk_stats")
|
@ -3,6 +3,7 @@ from shared_configs.enums import EmbedTextType
|
||||
|
||||
|
||||
MODEL_WARM_UP_STRING = "hi " * 512
|
||||
INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
|
||||
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
|
||||
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
|
||||
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
|
||||
|
@ -1,11 +1,14 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fastapi import APIRouter
|
||||
from huggingface_hub import snapshot_download # type: ignore
|
||||
from setfit import SetFitModel # type: ignore[import]
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
from transformers import BatchEncoding # type: ignore
|
||||
from transformers import PreTrainedTokenizer # type: ignore
|
||||
|
||||
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
|
||||
from model_server.constants import MODEL_WARM_UP_STRING
|
||||
from model_server.onyx_torch_model import ConnectorClassifier
|
||||
from model_server.onyx_torch_model import HybridClassifier
|
||||
@ -13,11 +16,22 @@ from model_server.utils import simple_log_function_time
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
|
||||
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
|
||||
from shared_configs.configs import (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
|
||||
)
|
||||
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
|
||||
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
from shared_configs.configs import (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE,
|
||||
)
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.configs import INFORMATION_CONTENT_MODEL_TAG
|
||||
from shared_configs.configs import INFORMATION_CONTENT_MODEL_VERSION
|
||||
from shared_configs.configs import INTENT_MODEL_TAG
|
||||
from shared_configs.configs import INTENT_MODEL_VERSION
|
||||
from shared_configs.model_server_models import ConnectorClassificationRequest
|
||||
from shared_configs.model_server_models import ConnectorClassificationResponse
|
||||
from shared_configs.model_server_models import ContentClassificationPrediction
|
||||
from shared_configs.model_server_models import IntentRequest
|
||||
from shared_configs.model_server_models import IntentResponse
|
||||
|
||||
@ -31,6 +45,10 @@ _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
|
||||
_INTENT_TOKENIZER: AutoTokenizer | None = None
|
||||
_INTENT_MODEL: HybridClassifier | None = None
|
||||
|
||||
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
|
||||
|
||||
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
|
||||
|
||||
|
||||
def get_connector_classifier_tokenizer() -> AutoTokenizer:
|
||||
global _CONNECTOR_CLASSIFIER_TOKENIZER
|
||||
@ -85,7 +103,7 @@ def get_intent_model_tokenizer() -> AutoTokenizer:
|
||||
|
||||
def get_local_intent_model(
|
||||
model_name_or_path: str = INTENT_MODEL_VERSION,
|
||||
tag: str = INTENT_MODEL_TAG,
|
||||
tag: str | None = INTENT_MODEL_TAG,
|
||||
) -> HybridClassifier:
|
||||
global _INTENT_MODEL
|
||||
if _INTENT_MODEL is None:
|
||||
@ -102,7 +120,9 @@ def get_local_intent_model(
|
||||
try:
|
||||
# Attempt to download the model snapshot
|
||||
logger.notice(f"Downloading model snapshot for {model_name_or_path}")
|
||||
local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=False
|
||||
)
|
||||
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@ -112,6 +132,44 @@ def get_local_intent_model(
|
||||
return _INTENT_MODEL
|
||||
|
||||
|
||||
def get_local_information_content_model(
|
||||
model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
|
||||
tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
|
||||
) -> SetFitModel:
|
||||
global _INFORMATION_CONTENT_MODEL
|
||||
if _INFORMATION_CONTENT_MODEL is None:
|
||||
try:
|
||||
# Calculate where the cache should be, then load from local if available
|
||||
logger.notice(
|
||||
f"Loading content information model from local cache: {model_name_or_path}"
|
||||
)
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
)
|
||||
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
|
||||
logger.notice(
|
||||
f"Loaded content information model from local cache: {local_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load content information model directly: {e}")
|
||||
try:
|
||||
# Attempt to download the model snapshot
|
||||
logger.notice(
|
||||
f"Downloading content information model snapshot for {model_name_or_path}"
|
||||
)
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=False
|
||||
)
|
||||
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to load content information model even after attempted snapshot download: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
return _INFORMATION_CONTENT_MODEL
|
||||
|
||||
|
||||
def tokenize_connector_classification_query(
|
||||
connectors: list[str],
|
||||
query: str,
|
||||
@ -195,6 +253,13 @@ def warm_up_intent_model() -> None:
|
||||
)
|
||||
|
||||
|
||||
def warm_up_information_content_model() -> None:
|
||||
logger.notice("Warming up Content Model") # TODO: add version if needed
|
||||
|
||||
information_content_model = get_local_information_content_model()
|
||||
information_content_model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING)
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
|
||||
intent_model = get_local_intent_model()
|
||||
@ -218,6 +283,117 @@ def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
|
||||
return intent_probabilities.tolist(), token_positive_probs
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def run_content_classification_inference(
|
||||
text_inputs: list[str],
|
||||
) -> list[ContentClassificationPrediction]:
|
||||
"""
|
||||
Assign a score to the segments in question. The model stored in get_local_information_content_model()
|
||||
creates the 'model score' based on its training, and the scores are then converted to a 0.0-1.0 scale.
|
||||
In the code outside of the model/inference model servers that score will be converted into the actual
|
||||
boost factor.
|
||||
"""
|
||||
|
||||
def _prob_to_score(prob: float) -> float:
|
||||
"""
|
||||
Conversion of base score to 0.0 - 1.0 score. Note that the min/max values depend on the model!
|
||||
"""
|
||||
_MIN_BASE_SCORE = 0.25
|
||||
_MAX_BASE_SCORE = 0.75
|
||||
if prob < _MIN_BASE_SCORE:
|
||||
raw_score = 0.0
|
||||
elif prob < _MAX_BASE_SCORE:
|
||||
raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
|
||||
else:
|
||||
raw_score = 1.0
|
||||
return (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
+ (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
|
||||
- INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
)
|
||||
* raw_score
|
||||
)
|
||||
|
||||
_BATCH_SIZE = 32
|
||||
content_model = get_local_information_content_model()
|
||||
|
||||
# Process inputs in batches
|
||||
all_output_classes: list[int] = []
|
||||
all_base_output_probabilities: list[float] = []
|
||||
|
||||
for i in range(0, len(text_inputs), _BATCH_SIZE):
|
||||
batch = text_inputs[i : i + _BATCH_SIZE]
|
||||
batch_with_prefix = []
|
||||
batch_indices = []
|
||||
|
||||
# Pre-allocate results for this batch
|
||||
batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch)
|
||||
batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch)
|
||||
|
||||
# Pre-process batch to handle long input exceptions
|
||||
for j, text in enumerate(batch):
|
||||
if len(text) == 0:
|
||||
# if no input, treat as non-informative from the model's perspective
|
||||
batch_output_classes[j] = np.array(0)
|
||||
batch_probabilities[j] = np.array(0.0)
|
||||
logger.warning("Input for Content Information Model is empty")
|
||||
|
||||
elif (
|
||||
len(text.split())
|
||||
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
|
||||
):
|
||||
# if input is short, use the model
|
||||
batch_with_prefix.append(
|
||||
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text
|
||||
)
|
||||
batch_indices.append(j)
|
||||
else:
|
||||
# if longer than cutoff, treat as informative (stay with default), but issue warning
|
||||
logger.warning("Input for Content Information Model too long")
|
||||
|
||||
if batch_with_prefix: # Only run model if we have valid inputs
|
||||
# Get predictions for the batch
|
||||
model_output_classes = content_model(batch_with_prefix)
|
||||
model_output_probabilities = content_model.predict_proba(batch_with_prefix)
|
||||
|
||||
# Place results in the correct positions
|
||||
for idx, batch_idx in enumerate(batch_indices):
|
||||
batch_output_classes[batch_idx] = model_output_classes[idx].numpy()
|
||||
batch_probabilities[batch_idx] = model_output_probabilities[idx][
|
||||
1
|
||||
].numpy() # x[1] is prob of the positive class
|
||||
|
||||
all_output_classes.extend([int(x) for x in batch_output_classes])
|
||||
all_base_output_probabilities.extend([float(x) for x in batch_probabilities])
|
||||
|
||||
logits = [
|
||||
np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100)
|
||||
for p in all_base_output_probabilities
|
||||
]
|
||||
scaled_logits = [
|
||||
logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE
|
||||
for logit in logits
|
||||
]
|
||||
output_probabilities_with_temp = [
|
||||
np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
|
||||
for scaled_logit in scaled_logits
|
||||
]
|
||||
|
||||
prediction_scores = [
|
||||
_prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
|
||||
]
|
||||
|
||||
content_classification_predictions = [
|
||||
ContentClassificationPrediction(
|
||||
predicted_label=predicted_label, content_boost_factor=output_score
|
||||
)
|
||||
for predicted_label, output_score in zip(all_output_classes, prediction_scores)
|
||||
]
|
||||
|
||||
return content_classification_predictions
|
||||
|
||||
|
||||
def map_keywords(
|
||||
input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
|
||||
) -> list[str]:
|
||||
@ -362,3 +538,10 @@ async def process_analysis_request(
|
||||
|
||||
is_keyword, keywords = run_analysis(intent_request)
|
||||
return IntentResponse(is_keyword=is_keyword, keywords=keywords)
|
||||
|
||||
|
||||
@router.post("/content-classification")
|
||||
async def process_content_classification_request(
|
||||
content_classification_requests: list[str],
|
||||
) -> list[ContentClassificationPrediction]:
|
||||
return run_content_classification_inference(content_classification_requests)
|
||||
|
@ -13,6 +13,7 @@ from sentry_sdk.integrations.starlette import StarletteIntegration
|
||||
from transformers import logging as transformer_logging # type:ignore
|
||||
|
||||
from model_server.custom_models import router as custom_models_router
|
||||
from model_server.custom_models import warm_up_information_content_model
|
||||
from model_server.custom_models import warm_up_intent_model
|
||||
from model_server.encoders import router as encoders_router
|
||||
from model_server.management_endpoints import router as management_router
|
||||
@ -74,9 +75,15 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
|
||||
|
||||
if not INDEXING_ONLY:
|
||||
logger.notice(
|
||||
"The intent model should run on the model server. The information content model should not run here."
|
||||
)
|
||||
warm_up_intent_model()
|
||||
else:
|
||||
logger.notice("This model server should only run document indexing.")
|
||||
logger.notice(
|
||||
"The content information model should run on the indexing model server. The intent model should not run here."
|
||||
)
|
||||
warm_up_information_content_model()
|
||||
|
||||
yield
|
||||
|
||||
|
@ -563,6 +563,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
|
||||
access=doc_access,
|
||||
boost=doc.boost,
|
||||
hidden=doc.hidden,
|
||||
# aggregated_boost_factor=doc.aggregated_boost_factor,
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
|
@ -53,6 +53,9 @@ from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.logger import TaskAttemptSingleton
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
@ -348,6 +351,8 @@ def _run_indexing(
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
information_content_classification_model = InformationContentClassificationModel()
|
||||
|
||||
document_index = get_default_document_index(
|
||||
index_attempt_start.search_settings,
|
||||
None,
|
||||
@ -356,6 +361,7 @@ def _run_indexing(
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=(
|
||||
ctx.from_beginning
|
||||
|
@ -132,3 +132,10 @@ if _LITELLM_EXTRA_BODY_RAW:
|
||||
LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Whether and how to lower scores for short chunks w/o relevant context
|
||||
# Evaluated via custom ML model
|
||||
|
||||
USE_INFORMATION_CONTENT_CLASSIFICATION = (
|
||||
os.environ.get("USE_INFORMATION_CONTENT_CLASSIFICATION", "false").lower() == "true"
|
||||
)
|
||||
|
@ -220,7 +220,6 @@ def thread_to_doc(
|
||||
source=DocumentSource.SLACK,
|
||||
semantic_identifier=doc_sem_id,
|
||||
doc_updated_at=get_latest_message_time(thread),
|
||||
title="", # slack docs don't really have a "title"
|
||||
primary_owners=valid_experts,
|
||||
metadata={"Channel": channel["name"]},
|
||||
)
|
||||
|
63
backend/onyx/db/chunk.py
Normal file
63
backend/onyx/db/chunk.py
Normal file
@ -0,0 +1,63 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import ChunkStats
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
|
||||
|
||||
def update_chunk_boost_components__no_commit(
|
||||
chunk_data: list[UpdatableChunkData],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Updates the chunk_boost_components for chunks in the database.
|
||||
|
||||
Args:
|
||||
chunk_data: List of dicts containing chunk_id, document_id, and boost_score
|
||||
db_session: SQLAlchemy database session
|
||||
"""
|
||||
if not chunk_data:
|
||||
return
|
||||
|
||||
for data in chunk_data:
|
||||
chunk_in_doc_id = int(data.chunk_id)
|
||||
if chunk_in_doc_id < 0:
|
||||
raise ValueError(f"Chunk ID is empty for chunk {data}")
|
||||
|
||||
chunk_document_id = f"{data.document_id}" f"__{chunk_in_doc_id}"
|
||||
chunk_stats = (
|
||||
db_session.query(ChunkStats)
|
||||
.filter(
|
||||
ChunkStats.id == chunk_document_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
score = data.boost_score
|
||||
|
||||
if chunk_stats:
|
||||
chunk_stats.information_content_boost = score
|
||||
chunk_stats.last_modified = datetime.now(timezone.utc)
|
||||
db_session.add(chunk_stats)
|
||||
else:
|
||||
# do not save new chunks with a neutral boost score
|
||||
if score == 1.0:
|
||||
continue
|
||||
# Create new record
|
||||
chunk_stats = ChunkStats(
|
||||
document_id=data.document_id,
|
||||
chunk_in_doc_id=chunk_in_doc_id,
|
||||
information_content_boost=score,
|
||||
)
|
||||
db_session.add(chunk_stats)
|
||||
|
||||
|
||||
def delete_chunk_stats_by_connector_credential_pair__no_commit(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> None:
|
||||
"""This deletes just chunk stats in postgres."""
|
||||
stmt = delete(ChunkStats).where(ChunkStats.document_id.in_(document_ids))
|
||||
|
||||
db_session.execute(stmt)
|
@ -23,6 +23,7 @@ from sqlalchemy.sql.expression import null
|
||||
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.enums import AccessType
|
||||
@ -562,6 +563,18 @@ def delete_documents_complete__no_commit(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> None:
|
||||
"""This completely deletes the documents from the db, including all foreign key relationships"""
|
||||
|
||||
# Start by deleting the chunk stats for the documents
|
||||
delete_chunk_stats_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
delete_chunk_stats_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
delete_documents_by_connector_credential_pair__no_commit(db_session, document_ids)
|
||||
delete_document_feedback_for_documents__no_commit(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
|
@ -591,6 +591,55 @@ class Document(Base):
|
||||
)
|
||||
|
||||
|
||||
class ChunkStats(Base):
|
||||
__tablename__ = "chunk_stats"
|
||||
# NOTE: if more sensitive data is added here for display, make sure to add user/group permission
|
||||
|
||||
# this should correspond to the ID of the document
|
||||
# (as is passed around in Onyx)
|
||||
id: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
primary_key=True,
|
||||
default=lambda context: (
|
||||
f"{context.get_current_parameters()['document_id']}"
|
||||
f"__{context.get_current_parameters()['chunk_in_doc_id']}"
|
||||
),
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Reference to parent document
|
||||
document_id: Mapped[str] = mapped_column(
|
||||
NullFilteredString, ForeignKey("document.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
chunk_in_doc_id: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
information_content_boost: Mapped[float | None] = mapped_column(
|
||||
Float, nullable=True
|
||||
)
|
||||
|
||||
last_modified: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, index=True, default=func.now()
|
||||
)
|
||||
last_synced: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True, index=True
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_chunk_sync_status",
|
||||
last_modified,
|
||||
last_synced,
|
||||
),
|
||||
UniqueConstraint(
|
||||
"document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tag"
|
||||
|
||||
|
@ -101,6 +101,7 @@ class VespaDocumentFields:
|
||||
document_sets: set[str] | None = None
|
||||
boost: float | None = None
|
||||
hidden: bool | None = None
|
||||
aggregated_chunk_boost_factor: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -80,6 +80,11 @@ schema DANSWER_CHUNK_NAME {
|
||||
indexing: summary | attribute
|
||||
rank: filter
|
||||
}
|
||||
# Field to indicate whether a short chunk is a low content chunk
|
||||
field aggregated_chunk_boost_factor type float {
|
||||
indexing: attribute
|
||||
}
|
||||
|
||||
# Needs to have a separate Attribute list for efficient filtering
|
||||
field metadata_list type array<string> {
|
||||
indexing: summary | attribute
|
||||
@ -142,6 +147,11 @@ schema DANSWER_CHUNK_NAME {
|
||||
expression: max(if(isNan(attribute(doc_updated_at)) == 1, 7890000, now() - attribute(doc_updated_at)) / 31536000, 0)
|
||||
}
|
||||
|
||||
function inline aggregated_chunk_boost() {
|
||||
# Aggregated boost factor, currently only used for information content classification
|
||||
expression: if(isNan(attribute(aggregated_chunk_boost_factor)) == 1, 1.0, attribute(aggregated_chunk_boost_factor))
|
||||
}
|
||||
|
||||
# Document score decays from 1 to 0.75 as age of last updated time increases
|
||||
function inline recency_bias() {
|
||||
expression: max(1 / (1 + query(decay_factor) * document_age), 0.75)
|
||||
@ -199,6 +209,8 @@ schema DANSWER_CHUNK_NAME {
|
||||
* document_boost
|
||||
# Decay factor based on time document was last updated
|
||||
* recency_bias
|
||||
# Boost based on aggregated boost calculation
|
||||
* aggregated_chunk_boost
|
||||
}
|
||||
rerank-count: 1000
|
||||
}
|
||||
@ -210,6 +222,7 @@ schema DANSWER_CHUNK_NAME {
|
||||
closeness(field, embeddings)
|
||||
document_boost
|
||||
recency_bias
|
||||
aggregated_chunk_boost
|
||||
closest(embeddings)
|
||||
}
|
||||
}
|
||||
|
@ -22,6 +22,7 @@ from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
)
|
||||
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
|
||||
from onyx.document_index.vespa_constants import AGGREGATED_CHUNK_BOOST_FACTOR
|
||||
from onyx.document_index.vespa_constants import BLURB
|
||||
from onyx.document_index.vespa_constants import BOOST
|
||||
from onyx.document_index.vespa_constants import CHUNK_ID
|
||||
@ -201,6 +202,7 @@ def _index_vespa_chunk(
|
||||
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
|
||||
IMAGE_FILE_NAME: chunk.image_file_name,
|
||||
BOOST: chunk.boost,
|
||||
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
|
||||
}
|
||||
|
||||
if multitenant:
|
||||
|
@ -72,6 +72,7 @@ METADATA = "metadata"
|
||||
METADATA_LIST = "metadata_list"
|
||||
METADATA_SUFFIX = "metadata_suffix"
|
||||
BOOST = "boost"
|
||||
AGGREGATED_CHUNK_BOOST_FACTOR = "aggregated_chunk_boost_factor"
|
||||
DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
|
||||
PRIMARY_OWNERS = "primary_owners"
|
||||
SECONDARY_OWNERS = "secondary_owners"
|
||||
@ -97,6 +98,7 @@ YQL_BASE = (
|
||||
f"{SECTION_CONTINUATION}, "
|
||||
f"{IMAGE_FILE_NAME}, "
|
||||
f"{BOOST}, "
|
||||
f"{AGGREGATED_CHUNK_BOOST_FACTOR}, "
|
||||
f"{HIDDEN}, "
|
||||
f"{DOC_UPDATED_AT}, "
|
||||
f"{PRIMARY_OWNERS}, "
|
||||
|
0
backend/onyx/indexing/content_classification.py
Normal file
0
backend/onyx/indexing/content_classification.py
Normal file
@ -11,6 +11,7 @@ from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.configs.model_configs import USE_INFORMATION_CONTENT_CLASSIFICATION
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
@ -22,6 +23,7 @@ from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.chunk import update_chunk_boost_components__no_commit
|
||||
from onyx.db.document import fetch_chunk_counts_for_documents
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
|
||||
@ -52,10 +54,19 @@ from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
|
||||
from onyx.llm.factory import get_default_llm_with_vision
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
|
||||
)
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@ -136,6 +147,72 @@ def _upsert_documents_in_db(
|
||||
)
|
||||
|
||||
|
||||
def _get_aggregated_chunk_boost_factor(
|
||||
chunks: list[IndexChunk],
|
||||
information_content_classification_model: InformationContentClassificationModel,
|
||||
) -> list[float]:
|
||||
"""Calculates the aggregated boost factor for a chunk based on its content."""
|
||||
|
||||
short_chunk_content_dict = {
|
||||
chunk_num: chunk.content
|
||||
for chunk_num, chunk in enumerate(chunks)
|
||||
if len(chunk.content.split())
|
||||
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
|
||||
}
|
||||
short_chunk_contents = list(short_chunk_content_dict.values())
|
||||
short_chunk_keys = list(short_chunk_content_dict.keys())
|
||||
|
||||
try:
|
||||
predictions = information_content_classification_model.predict(
|
||||
short_chunk_contents
|
||||
)
|
||||
# Create a mapping of chunk positions to their scores
|
||||
score_map = {
|
||||
short_chunk_keys[i]: prediction.content_boost_factor
|
||||
for i, prediction in enumerate(predictions)
|
||||
}
|
||||
# Default to 1.0 for longer chunks, use predicted score for short chunks
|
||||
chunk_content_scores = [score_map.get(i, 1.0) for i in range(len(chunks))]
|
||||
|
||||
return chunk_content_scores
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error predicting content classification for chunks: {e}. Falling back to individual examples."
|
||||
)
|
||||
|
||||
chunks_with_scores: list[IndexChunk] = []
|
||||
chunk_content_scores = []
|
||||
|
||||
for chunk in chunks:
|
||||
if (
|
||||
len(chunk.content.split())
|
||||
> INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
|
||||
):
|
||||
chunk_content_scores.append(1.0)
|
||||
chunks_with_scores.append(chunk)
|
||||
continue
|
||||
|
||||
try:
|
||||
chunk_content_scores.append(
|
||||
information_content_classification_model.predict([chunk.content])[
|
||||
0
|
||||
].content_boost_factor
|
||||
)
|
||||
chunks_with_scores.append(chunk)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error predicting content classification for chunk: {e}."
|
||||
)
|
||||
|
||||
raise Exception(
|
||||
f"Failed to predict content classification for chunk {chunk.chunk_id} "
|
||||
f"from document {chunk.source_document.id}"
|
||||
) from e
|
||||
|
||||
return chunk_content_scores
|
||||
|
||||
|
||||
def get_doc_ids_to_update(
|
||||
documents: list[Document], db_docs: list[DBDocument]
|
||||
) -> list[Document]:
|
||||
@ -165,6 +242,7 @@ def index_doc_batch_with_handler(
|
||||
*,
|
||||
chunker: Chunker,
|
||||
embedder: IndexingEmbedder,
|
||||
information_content_classification_model: InformationContentClassificationModel,
|
||||
document_index: DocumentIndex,
|
||||
document_batch: list[Document],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
@ -176,6 +254,7 @@ def index_doc_batch_with_handler(
|
||||
index_pipeline_result = index_doc_batch(
|
||||
chunker=chunker,
|
||||
embedder=embedder,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=document_index,
|
||||
document_batch=document_batch,
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
@ -450,6 +529,7 @@ def index_doc_batch(
|
||||
document_batch: list[Document],
|
||||
chunker: Chunker,
|
||||
embedder: IndexingEmbedder,
|
||||
information_content_classification_model: InformationContentClassificationModel,
|
||||
document_index: DocumentIndex,
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
@ -526,7 +606,23 @@ def index_doc_batch(
|
||||
else ([], [])
|
||||
)
|
||||
|
||||
chunk_content_scores = (
|
||||
_get_aggregated_chunk_boost_factor(
|
||||
chunks_with_embeddings, information_content_classification_model
|
||||
)
|
||||
if USE_INFORMATION_CONTENT_CLASSIFICATION
|
||||
else [1.0] * len(chunks_with_embeddings)
|
||||
)
|
||||
|
||||
updatable_ids = [doc.id for doc in ctx.updatable_docs]
|
||||
updatable_chunk_data = [
|
||||
UpdatableChunkData(
|
||||
chunk_id=chunk.chunk_id,
|
||||
document_id=chunk.source_document.id,
|
||||
boost_score=score,
|
||||
)
|
||||
for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
|
||||
]
|
||||
|
||||
# Acquires a lock on the documents so that no other process can modify them
|
||||
# NOTE: don't need to acquire till here, since this is when the actual race condition
|
||||
@ -579,8 +675,9 @@ def index_doc_batch(
|
||||
else DEFAULT_BOOST
|
||||
),
|
||||
tenant_id=tenant_id,
|
||||
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
|
||||
)
|
||||
for chunk in chunks_with_embeddings
|
||||
for chunk_num, chunk in enumerate(chunks_with_embeddings)
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
@ -665,6 +762,11 @@ def index_doc_batch(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# save the chunk boost components to postgres
|
||||
update_chunk_boost_components__no_commit(
|
||||
chunk_data=updatable_chunk_data, db_session=db_session
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
result = IndexingPipelineResult(
|
||||
@ -680,6 +782,7 @@ def index_doc_batch(
|
||||
def build_indexing_pipeline(
|
||||
*,
|
||||
embedder: IndexingEmbedder,
|
||||
information_content_classification_model: InformationContentClassificationModel,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
@ -703,6 +806,7 @@ def build_indexing_pipeline(
|
||||
index_doc_batch_with_handler,
|
||||
chunker=chunker,
|
||||
embedder=embedder,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=ignore_time_skip,
|
||||
db_session=db_session,
|
||||
|
@ -83,13 +83,16 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
document_sets: all document sets the source document for this chunk is a part
|
||||
of. This is used for filtering / personas.
|
||||
boost: influences the ranking of this chunk at query time. Positive -> ranked higher,
|
||||
negative -> ranked lower.
|
||||
negative -> ranked lower. Not included in aggregated boost calculation
|
||||
for legacy reasons.
|
||||
aggregated_chunk_boost_factor: represents the aggregated chunk-level boost (currently: information content)
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
access: "DocumentAccess"
|
||||
document_sets: set[str]
|
||||
boost: int
|
||||
aggregated_chunk_boost_factor: float
|
||||
|
||||
@classmethod
|
||||
def from_index_chunk(
|
||||
@ -98,6 +101,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access: "DocumentAccess",
|
||||
document_sets: set[str],
|
||||
boost: int,
|
||||
aggregated_chunk_boost_factor: float,
|
||||
tenant_id: str,
|
||||
) -> "DocMetadataAwareIndexChunk":
|
||||
index_chunk_data = index_chunk.model_dump()
|
||||
@ -106,6 +110,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access=access,
|
||||
document_sets=document_sets,
|
||||
boost=boost,
|
||||
aggregated_chunk_boost_factor=aggregated_chunk_boost_factor,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
@ -179,3 +184,9 @@ class IndexingSetting(EmbeddingModelDetail):
|
||||
class MultipassConfig(BaseModel):
|
||||
multipass_indexing: bool
|
||||
enable_large_chunks: bool
|
||||
|
||||
|
||||
class UpdatableChunkData(BaseModel):
|
||||
chunk_id: int
|
||||
document_id: str
|
||||
boost_score: float
|
||||
|
@ -29,6 +29,8 @@ from onyx.natural_language_processing.exceptions import (
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.natural_language_processing.utils import tokenizer_trim_content
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
@ -36,9 +38,11 @@ from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from shared_configs.model_server_models import ConnectorClassificationRequest
|
||||
from shared_configs.model_server_models import ConnectorClassificationResponse
|
||||
from shared_configs.model_server_models import ContentClassificationPrediction
|
||||
from shared_configs.model_server_models import Embedding
|
||||
from shared_configs.model_server_models import EmbedRequest
|
||||
from shared_configs.model_server_models import EmbedResponse
|
||||
from shared_configs.model_server_models import InformationContentClassificationResponses
|
||||
from shared_configs.model_server_models import IntentRequest
|
||||
from shared_configs.model_server_models import IntentResponse
|
||||
from shared_configs.model_server_models import RerankRequest
|
||||
@ -377,6 +381,31 @@ class QueryAnalysisModel:
|
||||
return response_model.is_keyword, response_model.keywords
|
||||
|
||||
|
||||
class InformationContentClassificationModel:
|
||||
def __init__(
|
||||
self,
|
||||
model_server_host: str = INDEXING_MODEL_SERVER_HOST,
|
||||
model_server_port: int = INDEXING_MODEL_SERVER_PORT,
|
||||
) -> None:
|
||||
model_server_url = build_model_server_url(model_server_host, model_server_port)
|
||||
self.content_server_endpoint = (
|
||||
model_server_url + "/custom/content-classification"
|
||||
)
|
||||
|
||||
def predict(
|
||||
self,
|
||||
queries: list[str],
|
||||
) -> list[ContentClassificationPrediction]:
|
||||
response = requests.post(self.content_server_endpoint, json=queries)
|
||||
response.raise_for_status()
|
||||
|
||||
model_responses = InformationContentClassificationResponses(
|
||||
information_content_classifications=response.json()
|
||||
)
|
||||
|
||||
return model_responses.information_content_classifications
|
||||
|
||||
|
||||
class ConnectorClassificationModel:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -98,6 +98,7 @@ def _create_indexable_chunks(
|
||||
boost=DEFAULT_BOOST,
|
||||
large_chunk_id=None,
|
||||
image_file_name=None,
|
||||
aggregated_chunk_boost_factor=1.0,
|
||||
)
|
||||
|
||||
chunks.append(chunk)
|
||||
|
@ -19,6 +19,9 @@ from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.server.onyx_api.models import DocMinimalInfo
|
||||
from onyx.server.onyx_api.models import IngestionDocument
|
||||
from onyx.server.onyx_api.models import IngestionResult
|
||||
@ -102,8 +105,11 @@ def upsert_ingestion_doc(
|
||||
search_settings=search_settings
|
||||
)
|
||||
|
||||
information_content_classification_model = InformationContentClassificationModel()
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
embedder=index_embedding_model,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=curr_doc_index,
|
||||
ignore_time_skip=True,
|
||||
db_session=db_session,
|
||||
@ -138,6 +144,7 @@ def upsert_ingestion_doc(
|
||||
|
||||
sec_ind_pipeline = build_indexing_pipeline(
|
||||
embedder=new_index_embedding_model,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=sec_doc_index,
|
||||
ignore_time_skip=True,
|
||||
db_session=db_session,
|
||||
|
@ -25,7 +25,7 @@ google-auth-oauthlib==1.0.0
|
||||
httpcore==1.0.5
|
||||
httpx[http2]==0.27.0
|
||||
httpx-oauth==0.15.1
|
||||
huggingface-hub==0.20.1
|
||||
huggingface-hub==0.29.0
|
||||
inflection==0.5.1
|
||||
jira==3.5.1
|
||||
jsonref==1.1.0
|
||||
@ -71,6 +71,7 @@ requests==2.32.2
|
||||
requests-oauthlib==1.3.1
|
||||
retry==0.9.2 # This pulls in py which is in CVE-2022-42969, must remove py from image
|
||||
rfc3986==1.5.0
|
||||
setfit==1.1.1
|
||||
simple-salesforce==1.12.6
|
||||
slack-sdk==3.20.2
|
||||
SQLAlchemy[mypy]==2.0.15
|
||||
@ -78,7 +79,7 @@ starlette==0.36.3
|
||||
supervisor==4.2.5
|
||||
tiktoken==0.7.0
|
||||
timeago==1.0.16
|
||||
transformers==4.39.2
|
||||
transformers==4.49.0
|
||||
unstructured==0.15.1
|
||||
unstructured-client==0.25.4
|
||||
uvicorn==0.21.1
|
||||
|
@ -14,7 +14,7 @@ pytest-asyncio==0.22.0
|
||||
pytest==7.4.4
|
||||
reorder-python-imports==3.9.0
|
||||
ruff==0.0.286
|
||||
sentence-transformers==2.6.1
|
||||
sentence-transformers==3.4.1
|
||||
trafilatura==1.12.2
|
||||
types-beautifulsoup4==4.12.0.3
|
||||
types-html5lib==1.1.11.13
|
||||
|
@ -7,9 +7,10 @@ openai==1.61.0
|
||||
pydantic==2.8.2
|
||||
retry==0.9.2
|
||||
safetensors==0.4.2
|
||||
sentence-transformers==2.6.1
|
||||
sentence-transformers==3.4.1
|
||||
setfit==1.1.1
|
||||
torch==2.2.0
|
||||
transformers==4.39.2
|
||||
transformers==4.49.0
|
||||
uvicorn==0.21.1
|
||||
voyageai==0.2.3
|
||||
litellm==1.61.16
|
||||
|
@ -161,17 +161,21 @@ overview_doc = SeedPresaveDocument(
|
||||
url="https://docs.onyx.app/more/use_cases/overview",
|
||||
title=overview_title,
|
||||
content=overview,
|
||||
title_embedding=model.encode(f"search_document: {overview_title}"),
|
||||
content_embedding=model.encode(f"search_document: {overview_title}\n{overview}"),
|
||||
title_embedding=list(model.encode(f"search_document: {overview_title}")),
|
||||
content_embedding=list(
|
||||
model.encode(f"search_document: {overview_title}\n{overview}")
|
||||
),
|
||||
)
|
||||
|
||||
enterprise_search_doc = SeedPresaveDocument(
|
||||
url="https://docs.onyx.app/more/use_cases/enterprise_search",
|
||||
title=enterprise_search_title,
|
||||
content=enterprise_search_1,
|
||||
title_embedding=model.encode(f"search_document: {enterprise_search_title}"),
|
||||
content_embedding=model.encode(
|
||||
f"search_document: {enterprise_search_title}\n{enterprise_search_1}"
|
||||
title_embedding=list(model.encode(f"search_document: {enterprise_search_title}")),
|
||||
content_embedding=list(
|
||||
model.encode(
|
||||
f"search_document: {enterprise_search_title}\n{enterprise_search_1}"
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@ -179,9 +183,11 @@ enterprise_search_doc_2 = SeedPresaveDocument(
|
||||
url="https://docs.onyx.app/more/use_cases/enterprise_search",
|
||||
title=enterprise_search_title,
|
||||
content=enterprise_search_2,
|
||||
title_embedding=model.encode(f"search_document: {enterprise_search_title}"),
|
||||
content_embedding=model.encode(
|
||||
f"search_document: {enterprise_search_title}\n{enterprise_search_2}"
|
||||
title_embedding=list(model.encode(f"search_document: {enterprise_search_title}")),
|
||||
content_embedding=list(
|
||||
model.encode(
|
||||
f"search_document: {enterprise_search_title}\n{enterprise_search_2}"
|
||||
)
|
||||
),
|
||||
chunk_ind=1,
|
||||
)
|
||||
@ -190,9 +196,9 @@ ai_platform_doc = SeedPresaveDocument(
|
||||
url="https://docs.onyx.app/more/use_cases/ai_platform",
|
||||
title=ai_platform_title,
|
||||
content=ai_platform,
|
||||
title_embedding=model.encode(f"search_document: {ai_platform_title}"),
|
||||
content_embedding=model.encode(
|
||||
f"search_document: {ai_platform_title}\n{ai_platform}"
|
||||
title_embedding=list(model.encode(f"search_document: {ai_platform_title}")),
|
||||
content_embedding=list(
|
||||
model.encode(f"search_document: {ai_platform_title}\n{ai_platform}")
|
||||
),
|
||||
)
|
||||
|
||||
@ -200,9 +206,9 @@ customer_support_doc = SeedPresaveDocument(
|
||||
url="https://docs.onyx.app/more/use_cases/support",
|
||||
title=customer_support_title,
|
||||
content=customer_support,
|
||||
title_embedding=model.encode(f"search_document: {customer_support_title}"),
|
||||
content_embedding=model.encode(
|
||||
f"search_document: {customer_support_title}\n{customer_support}"
|
||||
title_embedding=list(model.encode(f"search_document: {customer_support_title}")),
|
||||
content_embedding=list(
|
||||
model.encode(f"search_document: {customer_support_title}\n{customer_support}")
|
||||
),
|
||||
)
|
||||
|
||||
@ -210,17 +216,17 @@ sales_doc = SeedPresaveDocument(
|
||||
url="https://docs.onyx.app/more/use_cases/sales",
|
||||
title=sales_title,
|
||||
content=sales,
|
||||
title_embedding=model.encode(f"search_document: {sales_title}"),
|
||||
content_embedding=model.encode(f"search_document: {sales_title}\n{sales}"),
|
||||
title_embedding=list(model.encode(f"search_document: {sales_title}")),
|
||||
content_embedding=list(model.encode(f"search_document: {sales_title}\n{sales}")),
|
||||
)
|
||||
|
||||
operations_doc = SeedPresaveDocument(
|
||||
url="https://docs.onyx.app/more/use_cases/operations",
|
||||
title=operations_title,
|
||||
content=operations,
|
||||
title_embedding=model.encode(f"search_document: {operations_title}"),
|
||||
content_embedding=model.encode(
|
||||
f"search_document: {operations_title}\n{operations}"
|
||||
title_embedding=list(model.encode(f"search_document: {operations_title}")),
|
||||
content_embedding=list(
|
||||
model.encode(f"search_document: {operations_title}\n{operations}")
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -99,6 +99,7 @@ def generate_dummy_chunk(
|
||||
),
|
||||
document_sets={document_set for document_set in document_set_names},
|
||||
boost=random.randint(-1, 1),
|
||||
aggregated_chunk_boost_factor=random.random(),
|
||||
tenant_id=POSTGRES_DEFAULT_SCHEMA,
|
||||
)
|
||||
|
||||
|
@ -23,9 +23,11 @@ INDEXING_MODEL_SERVER_PORT = int(
|
||||
# Onyx custom Deep Learning Models
|
||||
CONNECTOR_CLASSIFIER_MODEL_REPO = "Danswer/filter-extraction-model"
|
||||
CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
|
||||
INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
|
||||
INTENT_MODEL_TAG = "v1.0.3"
|
||||
|
||||
INTENT_MODEL_VERSION = "onyx-dot-app/hybrid-intent-token-classifier"
|
||||
# INTENT_MODEL_TAG = "v1.0.3"
|
||||
INTENT_MODEL_TAG: str | None = None
|
||||
INFORMATION_CONTENT_MODEL_VERSION = "onyx-dot-app/information-content-model"
|
||||
INFORMATION_CONTENT_MODEL_TAG: str | None = None
|
||||
|
||||
# Bi-Encoder, other details
|
||||
DOC_EMBEDDING_CONTEXT_SIZE = 512
|
||||
@ -277,3 +279,20 @@ SUPPORTED_EMBEDDING_MODELS = [
|
||||
index_name="danswer_chunk_intfloat_multilingual_e5_small",
|
||||
),
|
||||
]
|
||||
# Maximum (least severe) downgrade factor for chunks above the cutoff
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX = float(
|
||||
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX") or 1.0
|
||||
)
|
||||
# Minimum (most severe) downgrade factor for short chunks below the cutoff if no content
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN = float(
|
||||
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN") or 0.7
|
||||
)
|
||||
# Temperature for the information content classification model
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE = float(
|
||||
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE") or 4.0
|
||||
)
|
||||
# Cutoff below which we start using the information content classification model
|
||||
# (cutoff length number itself is still considered 'short'))
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH = int(
|
||||
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH") or 10
|
||||
)
|
||||
|
@ -4,6 +4,7 @@ from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.enums import RerankerProvider
|
||||
|
||||
|
||||
Embedding = list[float]
|
||||
|
||||
|
||||
@ -73,7 +74,20 @@ class IntentResponse(BaseModel):
|
||||
keywords: list[str]
|
||||
|
||||
|
||||
class InformationContentClassificationRequests(BaseModel):
|
||||
queries: list[str]
|
||||
|
||||
|
||||
class SupportedEmbeddingModel(BaseModel):
|
||||
name: str
|
||||
dim: int
|
||||
index_name: str
|
||||
|
||||
|
||||
class ContentClassificationPrediction(BaseModel):
|
||||
predicted_label: int
|
||||
content_boost_factor: float
|
||||
|
||||
|
||||
class InformationContentClassificationResponses(BaseModel):
|
||||
information_content_classifications: list[ContentClassificationPrediction]
|
||||
|
145
backend/tests/unit/model_server/test_custom_models.py
Normal file
145
backend/tests/unit/model_server/test_custom_models.py
Normal file
@ -0,0 +1,145 @@
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import pytest
|
||||
|
||||
from model_server.custom_models import run_content_classification_inference
|
||||
from shared_configs.configs import (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
|
||||
)
|
||||
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
|
||||
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
from shared_configs.model_server_models import ContentClassificationPrediction
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_content_model() -> Mock:
|
||||
model = Mock()
|
||||
|
||||
# Create actual numpy arrays for the mock returns
|
||||
predict_output = np.array(
|
||||
[1, 0] * 50, dtype=np.int64
|
||||
) # Pre-allocate enough elements
|
||||
proba_output = np.array(
|
||||
[[0.3, 0.7], [0.7, 0.3]] * 50, dtype=np.float64
|
||||
) # Pre-allocate enough elements
|
||||
|
||||
# Create a mock tensor that has a numpy method and supports indexing
|
||||
class MockTensor:
|
||||
def __init__(self, value: npt.NDArray[Any]) -> None:
|
||||
self.value = value
|
||||
|
||||
def numpy(self) -> npt.NDArray[Any]:
|
||||
return self.value
|
||||
|
||||
def __getitem__(self, idx: Any) -> Any:
|
||||
result = self.value[idx]
|
||||
# Wrap scalar values back in MockTensor
|
||||
if isinstance(result, (np.float64, np.int64)):
|
||||
return MockTensor(np.array([result]))
|
||||
return MockTensor(result)
|
||||
|
||||
# Mock the direct call to return a MockTensor for each input
|
||||
def model_call(inputs: list[str]) -> list[MockTensor]:
|
||||
batch_size = len(inputs)
|
||||
return [MockTensor(predict_output[i : i + 1]) for i in range(batch_size)]
|
||||
|
||||
model.side_effect = model_call
|
||||
|
||||
# Mock predict_proba to return MockTensor-wrapped numpy array
|
||||
def predict_proba_call(x: list[str]) -> MockTensor:
|
||||
batch_size = len(x)
|
||||
return MockTensor(proba_output[:batch_size])
|
||||
|
||||
model.predict_proba.side_effect = predict_proba_call
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@patch("model_server.custom_models.get_local_information_content_model")
|
||||
def test_run_content_classification_inference(
|
||||
mock_get_model: Mock,
|
||||
mock_content_model: Mock,
|
||||
) -> None:
|
||||
"""
|
||||
Test the content classification inference function.
|
||||
Verifies that the function correctly processes text inputs and returns appropriate predictions.
|
||||
"""
|
||||
# Setup
|
||||
mock_get_model.return_value = mock_content_model
|
||||
|
||||
test_inputs = [
|
||||
"Imagine a short text with content",
|
||||
"Imagine a short text without content",
|
||||
"x "
|
||||
* (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH + 1
|
||||
), # Long input that exceeds maximal length for when the model should be applied
|
||||
"", # Empty input
|
||||
]
|
||||
|
||||
# Execute
|
||||
results = run_content_classification_inference(test_inputs)
|
||||
|
||||
# Assert
|
||||
assert len(results) == len(test_inputs)
|
||||
assert all(isinstance(r, ContentClassificationPrediction) for r in results)
|
||||
|
||||
# Check each prediction has expected attributes and ranges
|
||||
for result_num, result in enumerate(results):
|
||||
assert hasattr(result, "predicted_label")
|
||||
assert hasattr(result, "content_boost_factor")
|
||||
assert isinstance(result.predicted_label, int)
|
||||
assert isinstance(result.content_boost_factor, float)
|
||||
assert (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
<= result.content_boost_factor
|
||||
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
|
||||
)
|
||||
if result_num == 2:
|
||||
assert (
|
||||
result.content_boost_factor
|
||||
== INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
|
||||
)
|
||||
assert result.predicted_label == 1
|
||||
elif result_num == 3:
|
||||
assert (
|
||||
result.content_boost_factor
|
||||
== INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
)
|
||||
assert result.predicted_label == 0
|
||||
|
||||
# Verify model handling of long inputs
|
||||
mock_content_model.predict_proba.reset_mock()
|
||||
long_input = ["x " * 1000] # Definitely exceeds MAX_LENGTH
|
||||
results = run_content_classification_inference(long_input)
|
||||
assert len(results) == 1
|
||||
assert (
|
||||
mock_content_model.predict_proba.call_count == 0
|
||||
) # Should skip model call for too-long input
|
||||
|
||||
|
||||
@patch("model_server.custom_models.get_local_information_content_model")
|
||||
def test_batch_processing(
|
||||
mock_get_model: Mock,
|
||||
mock_content_model: Mock,
|
||||
) -> None:
|
||||
"""
|
||||
Test that the function correctly handles batch processing of inputs.
|
||||
"""
|
||||
# Setup
|
||||
mock_get_model.return_value = mock_content_model
|
||||
|
||||
# Create test input larger than batch size
|
||||
test_inputs = [f"Test input {i}" for i in range(40)] # > BATCH_SIZE (32)
|
||||
|
||||
# Execute
|
||||
results = run_content_classification_inference(test_inputs)
|
||||
|
||||
# Assert
|
||||
assert len(results) == 40
|
||||
# Verify batching occurred (should have called predict_proba twice)
|
||||
assert mock_content_model.predict_proba.call_count == 2
|
@ -1,12 +1,24 @@
|
||||
from typing import cast
|
||||
from typing import List
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentSource
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.indexing.indexing_pipeline import _get_aggregated_chunk_boost_factor
|
||||
from onyx.indexing.indexing_pipeline import filter_documents
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
ContentClassificationPrediction,
|
||||
)
|
||||
from shared_configs.configs import (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
|
||||
)
|
||||
|
||||
|
||||
def create_test_document(
|
||||
@ -123,3 +135,117 @@ def test_filter_documents_multiple_documents() -> None:
|
||||
def test_filter_documents_empty_batch() -> None:
|
||||
result = filter_documents([])
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
# Tests for get_aggregated_boost_factor
|
||||
|
||||
|
||||
def create_test_chunk(
|
||||
content: str, chunk_id: int = 0, doc_id: str = "test_doc"
|
||||
) -> IndexChunk:
|
||||
doc = Document(
|
||||
id=doc_id,
|
||||
semantic_identifier="test doc",
|
||||
sections=[],
|
||||
source=DocumentSource.FILE,
|
||||
metadata={},
|
||||
)
|
||||
return IndexChunk(
|
||||
chunk_id=chunk_id,
|
||||
content=content,
|
||||
source_document=doc,
|
||||
blurb=content[:50], # First 50 chars as blurb
|
||||
source_links={0: "test_link"},
|
||||
section_continuation=False,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
large_chunk_reference_ids=[],
|
||||
embeddings=ChunkEmbedding(full_embedding=[], mini_chunk_embeddings=[]),
|
||||
title_embedding=None,
|
||||
image_file_name=None,
|
||||
)
|
||||
|
||||
|
||||
def test_get_aggregated_boost_factor() -> None:
|
||||
# Create test chunks - mix of short and long content
|
||||
chunks = [
|
||||
create_test_chunk("Short content", 0),
|
||||
create_test_chunk(
|
||||
"Long " * (INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH + 1), 1
|
||||
),
|
||||
create_test_chunk("Another short chunk", 2),
|
||||
]
|
||||
|
||||
# Mock the classification model
|
||||
mock_model = Mock()
|
||||
mock_model.predict.return_value = [
|
||||
ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.8),
|
||||
ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.9),
|
||||
]
|
||||
|
||||
# Execute the function
|
||||
boost_scores = _get_aggregated_chunk_boost_factor(
|
||||
chunks=chunks, information_content_classification_model=mock_model
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert len(boost_scores) == 3
|
||||
|
||||
# Check that long content got default boost
|
||||
assert boost_scores[1] == 1.0
|
||||
|
||||
# Check that short content got predicted boosts
|
||||
assert boost_scores[0] == 0.8
|
||||
assert boost_scores[2] == 0.9
|
||||
|
||||
# Verify model was only called once with the short chunks
|
||||
mock_model.predict.assert_called_once()
|
||||
assert len(mock_model.predict.call_args[0][0]) == 2
|
||||
|
||||
|
||||
def test_get_aggregated_boost_factorilure() -> None:
|
||||
chunks = [
|
||||
create_test_chunk("Short content 1", 0),
|
||||
create_test_chunk("Short content 2", 1),
|
||||
]
|
||||
|
||||
# Mock model to fail on batch prediction but succeed on individual predictions
|
||||
mock_model = Mock()
|
||||
mock_model.predict.side_effect = [
|
||||
Exception("Batch prediction failed"), # First call fails
|
||||
[
|
||||
ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.7)
|
||||
], # Individual calls succeed
|
||||
[ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.8)],
|
||||
]
|
||||
|
||||
# Execute
|
||||
boost_scores = _get_aggregated_chunk_boost_factor(
|
||||
chunks=chunks, information_content_classification_model=mock_model
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert len(boost_scores) == 2
|
||||
assert boost_scores == [0.7, 0.8]
|
||||
|
||||
|
||||
def test_get_aggregated_boost_factor_individual_failure() -> None:
|
||||
chunks = [
|
||||
create_test_chunk("Short content", 0),
|
||||
create_test_chunk("Short content", 1),
|
||||
]
|
||||
|
||||
# Mock model to fail on both batch and individual prediction
|
||||
mock_model = Mock()
|
||||
mock_model.predict.side_effect = Exception("Prediction failed")
|
||||
|
||||
# Execute and verify it raises an exception
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
_get_aggregated_chunk_boost_factor(
|
||||
chunks=chunks, information_content_classification_model=mock_model
|
||||
)
|
||||
|
||||
assert "Failed to predict content classification for chunk" in str(exc_info.value)
|
||||
|
Loading…
x
Reference in New Issue
Block a user