diff --git a/backend/Dockerfile.model_server b/backend/Dockerfile.model_server index 4441eac7f..9a140f5e3 100644 --- a/backend/Dockerfile.model_server +++ b/backend/Dockerfile.model_server @@ -31,7 +31,8 @@ RUN python -c "from transformers import AutoTokenizer; \ AutoTokenizer.from_pretrained('distilbert-base-uncased'); \ AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \ from huggingface_hub import snapshot_download; \ -snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \ +snapshot_download(repo_id='onyx-dot-app/hybrid-intent-token-classifier'); \ +snapshot_download(repo_id='onyx-dot-app/information-content-model'); \ snapshot_download('nomic-ai/nomic-embed-text-v1'); \ snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \ from sentence_transformers import SentenceTransformer; \ diff --git a/backend/alembic/versions/3781a5eb12cb_add_chunk_stats_table.py b/backend/alembic/versions/3781a5eb12cb_add_chunk_stats_table.py new file mode 100644 index 000000000..2a7811a6c --- /dev/null +++ b/backend/alembic/versions/3781a5eb12cb_add_chunk_stats_table.py @@ -0,0 +1,51 @@ +"""add chunk stats table + +Revision ID: 3781a5eb12cb +Revises: df46c75b714e +Create Date: 2025-03-10 10:02:30.586666 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "3781a5eb12cb" +down_revision = "df46c75b714e" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "chunk_stats", + sa.Column("id", sa.String(), primary_key=True, index=True), + sa.Column( + "document_id", + sa.String(), + sa.ForeignKey("document.id"), + nullable=False, + index=True, + ), + sa.Column("chunk_in_doc_id", sa.Integer(), nullable=False), + sa.Column("information_content_boost", sa.Float(), nullable=True), + sa.Column( + "last_modified", + sa.DateTime(timezone=True), + nullable=False, + index=True, + server_default=sa.func.now(), + ), + sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True, index=True), + sa.UniqueConstraint( + "document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk" + ), + ) + + op.create_index( + "ix_chunk_sync_status", "chunk_stats", ["last_modified", "last_synced"] + ) + + +def downgrade() -> None: + op.drop_index("ix_chunk_sync_status", table_name="chunk_stats") + op.drop_table("chunk_stats") diff --git a/backend/ee/onyx/external_permissions/confluence/doc_sync.py b/backend/ee/onyx/external_permissions/confluence/doc_sync.py index 8ed076a3c..981aa3acb 100644 --- a/backend/ee/onyx/external_permissions/confluence/doc_sync.py +++ b/backend/ee/onyx/external_permissions/confluence/doc_sync.py @@ -2,6 +2,7 @@ Rules defined here: https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.html """ +from collections.abc import Generator from typing import Any from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC @@ -263,13 +264,11 @@ def _fetch_all_page_restrictions( space_permissions_by_space_key: dict[str, ExternalAccess], is_cloud: bool, callback: IndexingHeartbeatInterface | None, -) -> list[DocExternalAccess]: +) -> Generator[DocExternalAccess, None, None]: """ For all pages, if a page has restrictions, then use those restrictions. Otherwise, use the space's restrictions. 
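+    Results are yielded lazily so callers can stream permissions instead of holding them all in memory.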
""" - document_restrictions: list[DocExternalAccess] = [] - for slim_doc in slim_docs: if callback: if callback.should_stop(): @@ -286,11 +285,9 @@ def _fetch_all_page_restrictions( confluence_client=confluence_client, perm_sync_data=slim_doc.perm_sync_data, ): - document_restrictions.append( - DocExternalAccess( - doc_id=slim_doc.id, - external_access=restrictions, - ) + yield DocExternalAccess( + doc_id=slim_doc.id, + external_access=restrictions, ) # If there are restrictions, then we don't need to use the space's restrictions continue @@ -324,11 +321,9 @@ def _fetch_all_page_restrictions( continue # If there are no restrictions, then use the space's restrictions - document_restrictions.append( - DocExternalAccess( - doc_id=slim_doc.id, - external_access=space_permissions, - ) + yield DocExternalAccess( + doc_id=slim_doc.id, + external_access=space_permissions, ) if ( not space_permissions.is_public @@ -342,13 +337,12 @@ def _fetch_all_page_restrictions( ) logger.debug("Finished fetching all page restrictions for space") - return document_restrictions def confluence_doc_sync( cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None, -) -> list[DocExternalAccess]: +) -> Generator[DocExternalAccess, None, None]: """ Adds the external permissions to the documents in postgres if the document doesn't already exists in postgres, we create @@ -387,7 +381,7 @@ def confluence_doc_sync( slim_docs.extend(doc_batch) logger.debug("Fetching all page restrictions for space") - return _fetch_all_page_restrictions( + yield from _fetch_all_page_restrictions( confluence_client=confluence_connector.confluence_client, slim_docs=slim_docs, space_permissions_by_space_key=space_permissions_by_space_key, diff --git a/backend/ee/onyx/external_permissions/gmail/doc_sync.py b/backend/ee/onyx/external_permissions/gmail/doc_sync.py index 6f1bae674..3a58e5aca 100644 --- a/backend/ee/onyx/external_permissions/gmail/doc_sync.py +++ b/backend/ee/onyx/external_permissions/gmail/doc_sync.py @@ -1,3 +1,4 @@ +from collections.abc import Generator from datetime import datetime from datetime import timezone @@ -34,7 +35,7 @@ def _get_slim_doc_generator( def gmail_doc_sync( cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None, -) -> list[DocExternalAccess]: +) -> Generator[DocExternalAccess, None, None]: """ Adds the external permissions to the documents in postgres if the document doesn't already exists in postgres, we create @@ -48,7 +49,6 @@ def gmail_doc_sync( cc_pair, gmail_connector, callback=callback ) - document_external_access: list[DocExternalAccess] = [] for slim_doc_batch in slim_doc_generator: for slim_doc in slim_doc_batch: if callback: @@ -60,17 +60,14 @@ def gmail_doc_sync( if slim_doc.perm_sync_data is None: logger.warning(f"No permissions found for document {slim_doc.id}") continue + if user_email := slim_doc.perm_sync_data.get("user_email"): ext_access = ExternalAccess( external_user_emails=set([user_email]), external_user_group_ids=set(), is_public=False, ) - document_external_access.append( - DocExternalAccess( - doc_id=slim_doc.id, - external_access=ext_access, - ) + yield DocExternalAccess( + doc_id=slim_doc.id, + external_access=ext_access, ) - - return document_external_access diff --git a/backend/ee/onyx/external_permissions/google_drive/doc_sync.py b/backend/ee/onyx/external_permissions/google_drive/doc_sync.py index 8d3df7fa8..6b7aed3ec 100644 --- a/backend/ee/onyx/external_permissions/google_drive/doc_sync.py +++ 
b/backend/ee/onyx/external_permissions/google_drive/doc_sync.py @@ -1,3 +1,4 @@ +from collections.abc import Generator from datetime import datetime from datetime import timezone from typing import Any @@ -147,7 +148,7 @@ def _get_permissions_from_slim_doc( def gdrive_doc_sync( cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None, -) -> list[DocExternalAccess]: +) -> Generator[DocExternalAccess, None, None]: """ Adds the external permissions to the documents in postgres if the document doesn't already exist in postgres, we create @@ -161,7 +162,6 @@ def gdrive_doc_sync( slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector) - document_external_accesses = [] for slim_doc_batch in slim_doc_generator: for slim_doc in slim_doc_batch: if callback: @@ -174,10 +174,7 @@ def gdrive_doc_sync( google_drive_connector=google_drive_connector, slim_doc=slim_doc, ) - document_external_accesses.append( - DocExternalAccess( - external_access=ext_access, - doc_id=slim_doc.id, - ) + yield DocExternalAccess( + external_access=ext_access, + doc_id=slim_doc.id, ) - return document_external_accesses diff --git a/backend/ee/onyx/external_permissions/slack/doc_sync.py b/backend/ee/onyx/external_permissions/slack/doc_sync.py index 0ae9b58cc..ce8b883a2 100644 --- a/backend/ee/onyx/external_permissions/slack/doc_sync.py +++ b/backend/ee/onyx/external_permissions/slack/doc_sync.py @@ -1,3 +1,5 @@ +from collections.abc import Generator + from slack_sdk import WebClient from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map @@ -14,35 +16,6 @@ from onyx.utils.logger import setup_logger logger = setup_logger() -def _get_slack_document_ids_and_channels( - cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None -) -> dict[str, list[str]]: - slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config) - slack_connector.load_credentials(cc_pair.credential.credential_json) - - slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback) - - channel_doc_map: dict[str, list[str]] = {} - for doc_metadata_batch in slim_doc_generator: - for doc_metadata in doc_metadata_batch: - if doc_metadata.perm_sync_data is None: - continue - channel_id = doc_metadata.perm_sync_data["channel_id"] - if channel_id not in channel_doc_map: - channel_doc_map[channel_id] = [] - channel_doc_map[channel_id].append(doc_metadata.id) - - if callback: - if callback.should_stop(): - raise RuntimeError( - "_get_slack_document_ids_and_channels: Stop signal detected" - ) - - callback.progress("_get_slack_document_ids_and_channels", 1) - - return channel_doc_map - - def _fetch_workspace_permissions( user_id_to_email_map: dict[str, str], ) -> ExternalAccess: @@ -122,10 +95,37 @@ def _fetch_channel_permissions( return channel_permissions +def _get_slack_document_access( + cc_pair: ConnectorCredentialPair, + channel_permissions: dict[str, ExternalAccess], + callback: IndexingHeartbeatInterface | None, +) -> Generator[DocExternalAccess, None, None]: + slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config) + slack_connector.load_credentials(cc_pair.credential.credential_json) + + slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback) + + for doc_metadata_batch in slim_doc_generator: + for doc_metadata in doc_metadata_batch: + if doc_metadata.perm_sync_data is None: + continue + channel_id = doc_metadata.perm_sync_data["channel_id"] + yield DocExternalAccess( 
external_access=channel_permissions[channel_id], + doc_id=doc_metadata.id, + ) + + if callback: + if callback.should_stop(): + raise RuntimeError("_get_slack_document_access: Stop signal detected") + + callback.progress("_get_slack_document_access", 1) + + def slack_doc_sync( cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None, -) -> list[DocExternalAccess]: +) -> Generator[DocExternalAccess, None, None]: """ Adds the external permissions to the documents in postgres if the document doesn't already exist in postgres, we create @@ -136,9 +136,12 @@ def slack_doc_sync( token=cc_pair.credential.credential_json["slack_bot_token"] ) user_id_to_email_map = fetch_user_id_to_email_map(slack_client) - channel_doc_map = _get_slack_document_ids_and_channels( - cc_pair=cc_pair, callback=callback - ) + if not user_id_to_email_map: + raise ValueError( + "No user id to email map found. Please check to make sure that " + "your Slack bot token has the `users:read.email` scope" + ) + workspace_permissions = _fetch_workspace_permissions( user_id_to_email_map=user_id_to_email_map, ) @@ -148,18 +151,8 @@ def slack_doc_sync( user_id_to_email_map=user_id_to_email_map, ) - document_external_accesses = [] - for channel_id, ext_access in channel_permissions.items(): - doc_ids = channel_doc_map.get(channel_id) - if not doc_ids: - # No documents found for channel the channel_id - continue - - for doc_id in doc_ids: - document_external_accesses.append( - DocExternalAccess( - external_access=ext_access, - doc_id=doc_id, - ) - ) - return document_external_accesses + yield from _get_slack_document_access( + cc_pair=cc_pair, + channel_permissions=channel_permissions, + callback=callback, + ) diff --git a/backend/ee/onyx/external_permissions/sync_params.py b/backend/ee/onyx/external_permissions/sync_params.py index 9f8ed9681..28e27652c 100644 --- a/backend/ee/onyx/external_permissions/sync_params.py +++ b/backend/ee/onyx/external_permissions/sync_params.py @@ -1,4 +1,5 @@ from collections.abc import Callable +from collections.abc import Generator from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY @@ -23,7 +24,7 @@ DocSyncFuncType = Callable[ ConnectorCredentialPair, IndexingHeartbeatInterface | None, ], - list[DocExternalAccess], + Generator[DocExternalAccess, None, None], ] GroupSyncFuncType = Callable[ diff --git a/backend/model_server/constants.py b/backend/model_server/constants.py index d026d4a76..83d1c33d8 100644 --- a/backend/model_server/constants.py +++ b/backend/model_server/constants.py @@ -3,6 +3,7 @@ from shared_configs.enums import EmbedTextType MODEL_WARM_UP_STRING = "hi " * 512 +INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi " * 16 DEFAULT_OPENAI_MODEL = "text-embedding-3-small" DEFAULT_COHERE_MODEL = "embed-english-light-v3.0" DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct" diff --git a/backend/model_server/custom_models.py b/backend/model_server/custom_models.py index db8ba5d0c..ea15721ae 100644 --- a/backend/model_server/custom_models.py +++ b/backend/model_server/custom_models.py @@ -1,11 +1,14 @@ +import numpy as np import torch import torch.nn.functional as F from fastapi import APIRouter from huggingface_hub import snapshot_download # type: ignore +from setfit import SetFitModel # type: ignore[import] from transformers import AutoTokenizer # type: ignore from transformers import BatchEncoding # type: ignore from transformers import PreTrainedTokenizer # type: 
ignore +from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING from model_server.constants import MODEL_WARM_UP_STRING from model_server.onyx_torch_model import ConnectorClassifier from model_server.onyx_torch_model import HybridClassifier @@ -13,11 +16,22 @@ from model_server.utils import simple_log_function_time from onyx.utils.logger import setup_logger from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG +from shared_configs.configs import ( + INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH, +) +from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX +from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN +from shared_configs.configs import ( + INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE, +) from shared_configs.configs import INDEXING_ONLY +from shared_configs.configs import INFORMATION_CONTENT_MODEL_TAG +from shared_configs.configs import INFORMATION_CONTENT_MODEL_VERSION from shared_configs.configs import INTENT_MODEL_TAG from shared_configs.configs import INTENT_MODEL_VERSION from shared_configs.model_server_models import ConnectorClassificationRequest from shared_configs.model_server_models import ConnectorClassificationResponse +from shared_configs.model_server_models import ContentClassificationPrediction from shared_configs.model_server_models import IntentRequest from shared_configs.model_server_models import IntentResponse @@ -31,6 +45,10 @@ _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None _INTENT_TOKENIZER: AutoTokenizer | None = None _INTENT_MODEL: HybridClassifier | None = None +_INFORMATION_CONTENT_MODEL: SetFitModel | None = None + +_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # specific to the model version! 
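+# NOTE: this prefix is prepended to each input text before scoring in run_content_classification_inference; it is empty for the current model version.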
+ def get_connector_classifier_tokenizer() -> AutoTokenizer: global _CONNECTOR_CLASSIFIER_TOKENIZER @@ -85,7 +103,7 @@ def get_intent_model_tokenizer() -> AutoTokenizer: def get_local_intent_model( model_name_or_path: str = INTENT_MODEL_VERSION, - tag: str = INTENT_MODEL_TAG, + tag: str | None = INTENT_MODEL_TAG, ) -> HybridClassifier: global _INTENT_MODEL if _INTENT_MODEL is None: @@ -102,7 +120,9 @@ def get_local_intent_model( try: # Attempt to download the model snapshot logger.notice(f"Downloading model snapshot for {model_name_or_path}") - local_path = snapshot_download(repo_id=model_name_or_path, revision=tag) + local_path = snapshot_download( + repo_id=model_name_or_path, revision=tag, local_files_only=False + ) _INTENT_MODEL = HybridClassifier.from_pretrained(local_path) except Exception as e: logger.error( @@ -112,6 +132,44 @@ def get_local_intent_model( return _INTENT_MODEL +def get_local_information_content_model( + model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION, + tag: str | None = INFORMATION_CONTENT_MODEL_TAG, +) -> SetFitModel: + global _INFORMATION_CONTENT_MODEL + if _INFORMATION_CONTENT_MODEL is None: + try: + # Calculate where the cache should be, then load from local if available + logger.notice( + f"Loading information content model from local cache: {model_name_or_path}" + ) + local_path = snapshot_download( + repo_id=model_name_or_path, revision=tag, local_files_only=True + ) + _INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path) + logger.notice( + f"Loaded information content model from local cache: {local_path}" + ) + except Exception as e: + logger.warning(f"Failed to load information content model directly: {e}") + try: + # Attempt to download the model snapshot + logger.notice( + f"Downloading information content model snapshot for {model_name_or_path}" + ) + local_path = snapshot_download( + repo_id=model_name_or_path, revision=tag, local_files_only=False + ) + _INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path) + except Exception as e: + logger.error( + f"Failed to load information content model even after attempted snapshot download: {e}" + ) + raise + + return _INFORMATION_CONTENT_MODEL + + def tokenize_connector_classification_query( connectors: list[str], query: str, @@ -195,6 +253,13 @@ def warm_up_intent_model() -> None: ) +def warm_up_information_content_model() -> None: + logger.notice("Warming up Information Content Model") # TODO: add version if needed + + information_content_model = get_local_information_content_model() + information_content_model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING) + + @simple_log_function_time() def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]: intent_model = get_local_intent_model() @@ -218,6 +283,117 @@ def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]: return intent_probabilities.tolist(), token_positive_probs +@simple_log_function_time() +def run_content_classification_inference( + text_inputs: list[str], +) -> list[ContentClassificationPrediction]: + """ + Assign a score to each input text. The model returned by get_local_information_content_model() + produces a raw 'model score' based on its training; these scores are then converted to a 0.0-1.0 scale. + Outside of the model server, that score is converted into the actual boost factor. + """ + + def _prob_to_score(prob: float) -> float: + """ + Converts the base probability into a 0.0 - 1.0 score. Note that the min/max values depend on the model! 
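+        Probabilities at or below _MIN_BASE_SCORE map to a raw score of 0.0, those at or above _MAX_BASE_SCORE map to 1.0, and values in between are interpolated linearly before being rescaled into the [INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN, INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX] range.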
+ """ + _MIN_BASE_SCORE = 0.25 + _MAX_BASE_SCORE = 0.75 + if prob < _MIN_BASE_SCORE: + raw_score = 0.0 + elif prob < _MAX_BASE_SCORE: + raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE) + else: + raw_score = 1.0 + return ( + INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN + + ( + INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX + - INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN + ) + * raw_score + ) + + _BATCH_SIZE = 32 + content_model = get_local_information_content_model() + + # Process inputs in batches + all_output_classes: list[int] = [] + all_base_output_probabilities: list[float] = [] + + for i in range(0, len(text_inputs), _BATCH_SIZE): + batch = text_inputs[i : i + _BATCH_SIZE] + batch_with_prefix = [] + batch_indices = [] + + # Pre-allocate results for this batch + batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch) + batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch) + + # Pre-process batch to handle long input exceptions + for j, text in enumerate(batch): + if len(text) == 0: + # if no input, treat as non-informative from the model's perspective + batch_output_classes[j] = np.array(0) + batch_probabilities[j] = np.array(0.0) + logger.warning("Input for Content Information Model is empty") + + elif ( + len(text.split()) + <= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH + ): + # if input is short, use the model + batch_with_prefix.append( + _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text + ) + batch_indices.append(j) + else: + # if longer than cutoff, treat as informative (stay with default), but issue warning + logger.warning("Input for Content Information Model too long") + + if batch_with_prefix: # Only run model if we have valid inputs + # Get predictions for the batch + model_output_classes = content_model(batch_with_prefix) + model_output_probabilities = content_model.predict_proba(batch_with_prefix) + + # Place results in the correct positions + for idx, batch_idx in enumerate(batch_indices): + batch_output_classes[batch_idx] = model_output_classes[idx].numpy() + batch_probabilities[batch_idx] = model_output_probabilities[idx][ + 1 + ].numpy() # x[1] is prob of the positive class + + all_output_classes.extend([int(x) for x in batch_output_classes]) + all_base_output_probabilities.extend([float(x) for x in batch_probabilities]) + + logits = [ + np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100) + for p in all_base_output_probabilities + ] + scaled_logits = [ + logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE + for logit in logits + ] + output_probabilities_with_temp = [ + np.exp(scaled_logit) / (1 + np.exp(scaled_logit)) + for scaled_logit in scaled_logits + ] + + prediction_scores = [ + _prob_to_score(p_temp) for p_temp in output_probabilities_with_temp + ] + + content_classification_predictions = [ + ContentClassificationPrediction( + predicted_label=predicted_label, content_boost_factor=output_score + ) + for predicted_label, output_score in zip(all_output_classes, prediction_scores) + ] + + return content_classification_predictions + + def map_keywords( input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool] ) -> list[str]: @@ -362,3 +538,10 @@ async def process_analysis_request( is_keyword, keywords = run_analysis(intent_request) return IntentResponse(is_keyword=is_keyword, keywords=keywords) + + +@router.post("/content-classification") +async def process_content_classification_request( + content_classification_requests: list[str], +) -> 
list[ContentClassificationPrediction]: return run_content_classification_inference(content_classification_requests) diff --git a/backend/model_server/main.py b/backend/model_server/main.py index 0a6b56be1..3a6a56297 100644 --- a/backend/model_server/main.py +++ b/backend/model_server/main.py @@ -13,6 +13,7 @@ from sentry_sdk.integrations.starlette import StarletteIntegration from transformers import logging as transformer_logging # type:ignore from model_server.custom_models import router as custom_models_router +from model_server.custom_models import warm_up_information_content_model from model_server.custom_models import warm_up_intent_model from model_server.encoders import router as encoders_router from model_server.management_endpoints import router as management_router @@ -74,9 +75,15 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: logger.notice(f"Torch Threads: {torch.get_num_threads()}") if not INDEXING_ONLY: + logger.notice( + "The intent model should run on the model server. The information content model should not run here." + ) warm_up_intent_model() else: - logger.notice("This model server should only run document indexing.") + logger.notice( + "The information content model should run on the indexing model server. The intent model should not run here." + ) + warm_up_information_content_model() yield diff --git a/backend/onyx/access/models.py b/backend/onyx/access/models.py index e2364fcf7..411e53a03 100644 --- a/backend/onyx/access/models.py +++ b/backend/onyx/access/models.py @@ -20,7 +20,7 @@ class ExternalAccess: class DocExternalAccess: """ This is just a class to wrap the external access and the document ID - together. It's used for syncing document permissions to Redis. + together. It's used for syncing document permissions to Vespa. """ external_access: ExternalAccess diff --git a/backend/onyx/auth/users.py b/backend/onyx/auth/users.py index 05a49ccab..5bb060a63 100644 --- a/backend/onyx/auth/users.py +++ b/backend/onyx/auth/users.py @@ -105,6 +105,7 @@ from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop from onyx.utils.variable_functionality import fetch_versioned_implementation from shared_configs.configs import async_return_default_schema from shared_configs.configs import MULTI_TENANT +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.contextvars import get_current_tenant_id @@ -593,7 +594,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): tenant_id = fetch_ee_implementation_or_noop( "onyx.server.tenants.provisioning", "get_tenant_id_for_email", - None, + POSTGRES_DEFAULT_SCHEMA, )( email=email, ) diff --git a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py index b308e5a18..ba2b68aa1 100644 --- a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py +++ b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py @@ -453,23 +453,23 @@ def connector_permission_sync_generator_task( redis_connector.permissions.set_fence(new_payload) callback = PermissionSyncCallback(redis_connector, lock, r) - document_external_accesses: list[DocExternalAccess] = doc_sync_func( - cc_pair, callback - ) + document_external_accesses = doc_sync_func(cc_pair, callback) task_logger.info( f"RedisConnector.permissions.generate_tasks starting. 
cc_pair={cc_pair_id}" ) - tasks_generated = redis_connector.permissions.generate_tasks( - celery_app=self.app, - lock=lock, - new_permissions=document_external_accesses, - source_string=source_type, - connector_id=cc_pair.connector.id, - credential_id=cc_pair.credential.id, - ) - if tasks_generated is None: - return None + + tasks_generated = 0 + for doc_external_access in document_external_accesses: + redis_connector.permissions.generate_tasks( + celery_app=self.app, + lock=lock, + new_permissions=[doc_external_access], + source_string=source_type, + connector_id=cc_pair.connector.id, + credential_id=cc_pair.credential.id, + ) + tasks_generated += 1 task_logger.info( f"RedisConnector.permissions.generate_tasks finished. " diff --git a/backend/onyx/background/celery/tasks/vespa/tasks.py b/backend/onyx/background/celery/tasks/vespa/tasks.py index a3c84acec..f319d271e 100644 --- a/backend/onyx/background/celery/tasks/vespa/tasks.py +++ b/backend/onyx/background/celery/tasks/vespa/tasks.py @@ -563,6 +563,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) -> access=doc_access, boost=doc.boost, hidden=doc.hidden, + # aggregated_boost_factor=doc.aggregated_boost_factor, ) # update Vespa. OK if doc doesn't exist. Raises exception otherwise. diff --git a/backend/onyx/background/indexing/run_indexing.py b/backend/onyx/background/indexing/run_indexing.py index 612cacc3e..12b485b87 100644 --- a/backend/onyx/background/indexing/run_indexing.py +++ b/backend/onyx/background/indexing/run_indexing.py @@ -53,6 +53,9 @@ from onyx.httpx.httpx_pool import HttpxPool from onyx.indexing.embedder import DefaultIndexingEmbedder from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.indexing.indexing_pipeline import build_indexing_pipeline +from onyx.natural_language_processing.search_nlp_models import ( + InformationContentClassificationModel, +) from onyx.utils.logger import setup_logger from onyx.utils.logger import TaskAttemptSingleton from onyx.utils.telemetry import create_milestone_and_report @@ -348,6 +351,8 @@ def _run_indexing( callback=callback, ) + information_content_classification_model = InformationContentClassificationModel() + document_index = get_default_document_index( index_attempt_start.search_settings, None, @@ -356,6 +361,7 @@ def _run_indexing( indexing_pipeline = build_indexing_pipeline( embedder=embedding_model, + information_content_classification_model=information_content_classification_model, document_index=document_index, ignore_time_skip=( ctx.from_beginning diff --git a/backend/onyx/configs/model_configs.py b/backend/onyx/configs/model_configs.py index 0c85661d6..b59857b51 100644 --- a/backend/onyx/configs/model_configs.py +++ b/backend/onyx/configs/model_configs.py @@ -132,3 +132,10 @@ if _LITELLM_EXTRA_BODY_RAW: LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW) except Exception: pass + +# Whether and how to lower scores for short chunks w/o relevant context +# Evaluated via custom ML model + +USE_INFORMATION_CONTENT_CLASSIFICATION = ( + os.environ.get("USE_INFORMATION_CONTENT_CLASSIFICATION", "false").lower() == "true" +) diff --git a/backend/onyx/connectors/google_drive/connector.py b/backend/onyx/connectors/google_drive/connector.py index c3a085b06..f7a27a826 100644 --- a/backend/onyx/connectors/google_drive/connector.py +++ b/backend/onyx/connectors/google_drive/connector.py @@ -4,6 +4,7 @@ from concurrent.futures import as_completed from concurrent.futures import ThreadPoolExecutor from functools import partial 
from typing import Any +from urllib.parse import urlparse from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore @@ -59,7 +60,7 @@ def _extract_str_list_from_comma_str(string: str | None) -> list[str]: def _extract_ids_from_urls(urls: list[str]) -> list[str]: - return [url.split("/")[-1] for url in urls] + return [urlparse(url).path.strip("/").split("/")[-1] for url in urls] def _convert_single_file( diff --git a/backend/onyx/connectors/slack/connector.py b/backend/onyx/connectors/slack/connector.py index 2ef36b954..138c3b422 100644 --- a/backend/onyx/connectors/slack/connector.py +++ b/backend/onyx/connectors/slack/connector.py @@ -220,7 +220,6 @@ def thread_to_doc( source=DocumentSource.SLACK, semantic_identifier=doc_sem_id, doc_updated_at=get_latest_message_time(thread), - title="", # slack docs don't really have a "title" primary_owners=valid_experts, metadata={"Channel": channel["name"]}, ) @@ -413,8 +412,8 @@ def _get_all_doc_ids( callback=callback, ) - message_ts_set: set[str] = set() for message_batch in channel_message_batches: + slim_doc_batch: list[SlimDocument] = [] for message in message_batch: if msg_filter_func(message): continue @@ -422,18 +421,17 @@ def _get_all_doc_ids( # The document id is the channel id and the ts of the first message in the thread # Since we already have the first message of the thread, we don't have to # fetch the thread for id retrieval, saving time and API calls - message_ts_set.add(message["ts"]) - channel_metadata_list: list[SlimDocument] = [] - for message_ts in message_ts_set: - channel_metadata_list.append( - SlimDocument( - id=_build_doc_id(channel_id=channel_id, thread_ts=message_ts), - perm_sync_data={"channel_id": channel_id}, + slim_doc_batch.append( + SlimDocument( + id=_build_doc_id( + channel_id=channel_id, thread_ts=message["ts"] + ), + perm_sync_data={"channel_id": channel_id}, + ) ) - ) - yield channel_metadata_list + yield slim_doc_batch def _process_message( diff --git a/backend/onyx/db/chunk.py b/backend/onyx/db/chunk.py new file mode 100644 index 000000000..4eed6430c --- /dev/null +++ b/backend/onyx/db/chunk.py @@ -0,0 +1,63 @@ +from datetime import datetime +from datetime import timezone + +from sqlalchemy import delete +from sqlalchemy.orm import Session + +from onyx.db.models import ChunkStats +from onyx.indexing.models import UpdatableChunkData + + +def update_chunk_boost_components__no_commit( + chunk_data: list[UpdatableChunkData], + db_session: Session, +) -> None: + """Updates the chunk_boost_components for chunks in the database. 
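+    Rows are keyed on "{document_id}__{chunk_in_doc_id}"; existing rows are updated in place, while new rows with a neutral boost score of 1.0 are skipped so that only meaningful boosts are stored.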
+ + Args: + chunk_data: List of UpdatableChunkData objects containing chunk_id, document_id, and boost_score + db_session: SQLAlchemy database session + """ + if not chunk_data: + return + + for data in chunk_data: + chunk_in_doc_id = int(data.chunk_id) + if chunk_in_doc_id < 0: + raise ValueError(f"Chunk ID is negative for chunk {data}") + + chunk_document_id = f"{data.document_id}" f"__{chunk_in_doc_id}" + chunk_stats = ( + db_session.query(ChunkStats) + .filter( + ChunkStats.id == chunk_document_id, + ) + .first() + ) + + score = data.boost_score + + if chunk_stats: + chunk_stats.information_content_boost = score + chunk_stats.last_modified = datetime.now(timezone.utc) + db_session.add(chunk_stats) + else: + # do not save new chunks with a neutral boost score + if score == 1.0: + continue + # Create new record + chunk_stats = ChunkStats( + document_id=data.document_id, + chunk_in_doc_id=chunk_in_doc_id, + information_content_boost=score, + ) + db_session.add(chunk_stats) + + +def delete_chunk_stats_by_connector_credential_pair__no_commit( + db_session: Session, document_ids: list[str] ) -> None: + """This deletes just chunk stats in postgres.""" + stmt = delete(ChunkStats).where(ChunkStats.document_id.in_(document_ids)) + + db_session.execute(stmt) diff --git a/backend/onyx/db/document.py b/backend/onyx/db/document.py index e40844766..33d106380 100644 --- a/backend/onyx/db/document.py +++ b/backend/onyx/db/document.py @@ -23,6 +23,7 @@ from sqlalchemy.sql.expression import null from onyx.configs.constants import DEFAULT_BOOST from onyx.configs.constants import DocumentSource +from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.engine import get_session_context_manager from onyx.db.enums import AccessType @@ -562,6 +563,13 @@ def delete_documents_complete__no_commit( db_session: Session, document_ids: list[str] ) -> None: """This completely deletes the documents from the db, including all foreign key relationships""" + + # Start by deleting the chunk stats for the documents + delete_chunk_stats_by_connector_credential_pair__no_commit( + db_session=db_session, + document_ids=document_ids, + ) + delete_documents_by_connector_credential_pair__no_commit(db_session, document_ids) delete_document_feedback_for_documents__no_commit( document_ids=document_ids, db_session=db_session ) diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index c0865e523..cbe3b1be3 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -591,6 +591,55 @@ class Document(Base): ) +class ChunkStats(Base): + __tablename__ = "chunk_stats" + # NOTE: if more sensitive data is added here for display, make sure to add user/group permission + + # this should correspond to the ID of the document + # (as is passed around in Onyx) + id: Mapped[str] = mapped_column( + NullFilteredString, + primary_key=True, + default=lambda context: ( + f"{context.get_current_parameters()['document_id']}" + f"__{context.get_current_parameters()['chunk_in_doc_id']}" + ), + index=True, + ) + + # Reference to parent document + document_id: Mapped[str] = mapped_column( + NullFilteredString, ForeignKey("document.id"), nullable=False, index=True + ) + + chunk_in_doc_id: Mapped[int] = mapped_column( + Integer, + nullable=False, + ) + + information_content_boost: Mapped[float | None] = mapped_column( 
+ Float, nullable=True + ) + + last_modified: Mapped[datetime.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=False, index=True, default=func.now() + ) + last_synced: Mapped[datetime.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True, index=True + ) + + __table_args__ = ( + Index( + "ix_chunk_sync_status", + last_modified, + last_synced, + ), + UniqueConstraint( + "document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk" + ), + ) + + class Tag(Base): __tablename__ = "tag" diff --git a/backend/onyx/document_index/interfaces.py b/backend/onyx/document_index/interfaces.py index 463abbc95..e34cbc9eb 100644 --- a/backend/onyx/document_index/interfaces.py +++ b/backend/onyx/document_index/interfaces.py @@ -101,6 +101,7 @@ class VespaDocumentFields: document_sets: set[str] | None = None boost: float | None = None hidden: bool | None = None + aggregated_chunk_boost_factor: float | None = None @dataclass diff --git a/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd b/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd index d5d5220f8..70980852a 100644 --- a/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd +++ b/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd @@ -80,6 +80,11 @@ schema DANSWER_CHUNK_NAME { indexing: summary | attribute rank: filter } + # Aggregated chunk-level boost factor, currently used to down-rank short, low-information chunks + field aggregated_chunk_boost_factor type float { + indexing: attribute + } + # Needs to have a separate Attribute list for efficient filtering field metadata_list type array { indexing: summary | attribute @@ -142,6 +147,11 @@ schema DANSWER_CHUNK_NAME { expression: max(if(isNan(attribute(doc_updated_at)) == 1, 7890000, now() - attribute(doc_updated_at)) / 31536000, 0) } + function inline aggregated_chunk_boost() { + # Aggregated boost factor, currently only used for information content classification + expression: if(isNan(attribute(aggregated_chunk_boost_factor)) == 1, 1.0, attribute(aggregated_chunk_boost_factor)) + } + # Document score decays from 1 to 0.75 as age of last updated time increases function inline recency_bias() { expression: max(1 / (1 + query(decay_factor) * document_age), 0.75) @@ -199,6 +209,8 @@ schema DANSWER_CHUNK_NAME { * document_boost # Decay factor based on time document was last updated * recency_bias + # Boost based on aggregated boost calculation + * aggregated_chunk_boost } rerank-count: 1000 } @@ -210,6 +222,7 @@ schema DANSWER_CHUNK_NAME { closeness(field, embeddings) document_boost recency_bias + aggregated_chunk_boost closest(embeddings) } } diff --git a/backend/onyx/document_index/vespa/indexing_utils.py b/backend/onyx/document_index/vespa/indexing_utils.py index 81fc2a0d4..ab08dc4d1 100644 --- a/backend/onyx/document_index/vespa/indexing_utils.py +++ b/backend/onyx/document_index/vespa/indexing_utils.py @@ -22,6 +22,7 @@ from onyx.document_index.vespa.shared_utils.utils import ( replace_invalid_doc_id_characters, ) from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST +from onyx.document_index.vespa_constants import AGGREGATED_CHUNK_BOOST_FACTOR from onyx.document_index.vespa_constants import BLURB from onyx.document_index.vespa_constants import BOOST from onyx.document_index.vespa_constants import CHUNK_ID @@ -201,6 +202,7 @@ def _index_vespa_chunk( DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets}, IMAGE_FILE_NAME: chunk.image_file_name, BOOST: chunk.boost, + 
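# chunk-level boost factor from information content classification (1.0 = neutral) +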
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor, } if multitenant: diff --git a/backend/onyx/document_index/vespa_constants.py b/backend/onyx/document_index/vespa_constants.py index 15f889f3c..8a32fb721 100644 --- a/backend/onyx/document_index/vespa_constants.py +++ b/backend/onyx/document_index/vespa_constants.py @@ -72,6 +72,7 @@ METADATA = "metadata" METADATA_LIST = "metadata_list" METADATA_SUFFIX = "metadata_suffix" BOOST = "boost" +AGGREGATED_CHUNK_BOOST_FACTOR = "aggregated_chunk_boost_factor" DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch PRIMARY_OWNERS = "primary_owners" SECONDARY_OWNERS = "secondary_owners" @@ -97,6 +98,7 @@ YQL_BASE = ( f"{SECTION_CONTINUATION}, " f"{IMAGE_FILE_NAME}, " f"{BOOST}, " + f"{AGGREGATED_CHUNK_BOOST_FACTOR}, " f"{HIDDEN}, " f"{DOC_UPDATED_AT}, " f"{PRIMARY_OWNERS}, " diff --git a/backend/onyx/indexing/content_classification.py b/backend/onyx/indexing/content_classification.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/onyx/indexing/indexing_pipeline.py b/backend/onyx/indexing/indexing_pipeline.py index f4a6e0075..3967b7f7c 100644 --- a/backend/onyx/indexing/indexing_pipeline.py +++ b/backend/onyx/indexing/indexing_pipeline.py @@ -11,6 +11,7 @@ from onyx.access.models import DocumentAccess from onyx.configs.app_configs import MAX_DOCUMENT_CHARS from onyx.configs.constants import DEFAULT_BOOST from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled +from onyx.configs.model_configs import USE_INFORMATION_CONTENT_CLASSIFICATION from onyx.connectors.cross_connector_utils.miscellaneous_utils import ( get_experts_stores_representations, ) @@ -22,6 +23,7 @@ from onyx.connectors.models import IndexAttemptMetadata from onyx.connectors.models import IndexingDocument from onyx.connectors.models import Section from onyx.connectors.models import TextSection +from onyx.db.chunk import update_chunk_boost_components__no_commit from onyx.db.document import fetch_chunk_counts_for_documents from onyx.db.document import get_documents_by_ids from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit @@ -52,10 +54,19 @@ from onyx.indexing.embedder import IndexingEmbedder from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.indexing.models import DocAwareChunk from onyx.indexing.models import DocMetadataAwareIndexChunk +from onyx.indexing.models import IndexChunk +from onyx.indexing.models import UpdatableChunkData from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff from onyx.llm.factory import get_default_llm_with_vision +from onyx.natural_language_processing.search_nlp_models import ( + InformationContentClassificationModel, +) from onyx.utils.logger import setup_logger from onyx.utils.timing import log_function_time +from shared_configs.configs import ( + INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH, +) + logger = setup_logger() @@ -136,6 +147,72 @@ def _upsert_documents_in_db( ) +def _get_aggregated_chunk_boost_factor( + chunks: list[IndexChunk], + information_content_classification_model: InformationContentClassificationModel, +) -> list[float]: + """Calculates the aggregated boost factor for each chunk based on its content.""" + + short_chunk_content_dict = { + chunk_num: chunk.content + for chunk_num, chunk in enumerate(chunks) + if len(chunk.content.split()) + <= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH + } + short_chunk_contents = list(short_chunk_content_dict.values()) + 
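# keys and values come from the same dict, so their order matches and model outputs can be mapped back to chunk positions +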
short_chunk_keys = list(short_chunk_content_dict.keys()) + + try: + predictions = information_content_classification_model.predict( + short_chunk_contents + ) + # Create a mapping of chunk positions to their scores + score_map = { + short_chunk_keys[i]: prediction.content_boost_factor + for i, prediction in enumerate(predictions) + } + # Default to 1.0 for longer chunks, use predicted score for short chunks + chunk_content_scores = [score_map.get(i, 1.0) for i in range(len(chunks))] + + return chunk_content_scores + + except Exception as e: + logger.exception( + f"Error predicting content classification for chunks: {e}. Falling back to individual examples." + ) + + chunks_with_scores: list[IndexChunk] = [] + chunk_content_scores = [] + + for chunk in chunks: + if ( + len(chunk.content.split()) + > INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH + ): + chunk_content_scores.append(1.0) + chunks_with_scores.append(chunk) + continue + + try: + chunk_content_scores.append( + information_content_classification_model.predict([chunk.content])[ + 0 + ].content_boost_factor + ) + chunks_with_scores.append(chunk) + except Exception as e: + logger.exception( + f"Error predicting content classification for chunk: {e}." + ) + + raise Exception( + f"Failed to predict content classification for chunk {chunk.chunk_id} " + f"from document {chunk.source_document.id}" + ) from e + + return chunk_content_scores + + def get_doc_ids_to_update( documents: list[Document], db_docs: list[DBDocument] ) -> list[Document]: @@ -165,6 +242,7 @@ def index_doc_batch_with_handler( *, chunker: Chunker, embedder: IndexingEmbedder, + information_content_classification_model: InformationContentClassificationModel, document_index: DocumentIndex, document_batch: list[Document], index_attempt_metadata: IndexAttemptMetadata, @@ -176,6 +254,7 @@ def index_doc_batch_with_handler( index_pipeline_result = index_doc_batch( chunker=chunker, embedder=embedder, + information_content_classification_model=information_content_classification_model, document_index=document_index, document_batch=document_batch, index_attempt_metadata=index_attempt_metadata, @@ -450,6 +529,7 @@ def index_doc_batch( document_batch: list[Document], chunker: Chunker, embedder: IndexingEmbedder, + information_content_classification_model: InformationContentClassificationModel, document_index: DocumentIndex, index_attempt_metadata: IndexAttemptMetadata, db_session: Session, @@ -526,7 +606,23 @@ def index_doc_batch( else ([], []) ) + chunk_content_scores = ( + _get_aggregated_chunk_boost_factor( + chunks_with_embeddings, information_content_classification_model + ) + if USE_INFORMATION_CONTENT_CLASSIFICATION + else [1.0] * len(chunks_with_embeddings) + ) + updatable_ids = [doc.id for doc in ctx.updatable_docs] + updatable_chunk_data = [ + UpdatableChunkData( + chunk_id=chunk.chunk_id, + document_id=chunk.source_document.id, + boost_score=score, + ) + for chunk, score in zip(chunks_with_embeddings, chunk_content_scores) + ] # Acquires a lock on the documents so that no other process can modify them # NOTE: don't need to acquire till here, since this is when the actual race condition @@ -579,8 +675,9 @@ def index_doc_batch( else DEFAULT_BOOST ), tenant_id=tenant_id, + aggregated_chunk_boost_factor=chunk_content_scores[chunk_num], ) - for chunk in chunks_with_embeddings + for chunk_num, chunk in enumerate(chunks_with_embeddings) ] logger.debug( @@ -665,6 +762,11 @@ def index_doc_batch( db_session=db_session, ) + # save the chunk boost components to postgres + 
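# (update_chunk_boost_components__no_commit skips creating rows for chunks with a neutral 1.0 score) +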
update_chunk_boost_components__no_commit( + chunk_data=updatable_chunk_data, db_session=db_session + ) + db_session.commit() result = IndexingPipelineResult( @@ -680,6 +782,7 @@ def index_doc_batch( def build_indexing_pipeline( *, embedder: IndexingEmbedder, + information_content_classification_model: InformationContentClassificationModel, document_index: DocumentIndex, db_session: Session, tenant_id: str, @@ -703,6 +806,7 @@ def build_indexing_pipeline( index_doc_batch_with_handler, chunker=chunker, embedder=embedder, + information_content_classification_model=information_content_classification_model, document_index=document_index, ignore_time_skip=ignore_time_skip, db_session=db_session, diff --git a/backend/onyx/indexing/models.py b/backend/onyx/indexing/models.py index 5dffe1b08..686ac2942 100644 --- a/backend/onyx/indexing/models.py +++ b/backend/onyx/indexing/models.py @@ -83,13 +83,16 @@ class DocMetadataAwareIndexChunk(IndexChunk): document_sets: all document sets the source document for this chunk is a part of. This is used for filtering / personas. boost: influences the ranking of this chunk at query time. Positive -> ranked higher, - negative -> ranked lower. + negative -> ranked lower. Not included in aggregated boost calculation + for legacy reasons. + aggregated_chunk_boost_factor: represents the aggregated chunk-level boost (currently: information content) """ tenant_id: str access: "DocumentAccess" document_sets: set[str] boost: int + aggregated_chunk_boost_factor: float @classmethod def from_index_chunk( @@ -98,6 +101,7 @@ class DocMetadataAwareIndexChunk(IndexChunk): access: "DocumentAccess", document_sets: set[str], boost: int, + aggregated_chunk_boost_factor: float, tenant_id: str, ) -> "DocMetadataAwareIndexChunk": index_chunk_data = index_chunk.model_dump() @@ -106,6 +110,7 @@ class DocMetadataAwareIndexChunk(IndexChunk): access=access, document_sets=document_sets, boost=boost, + aggregated_chunk_boost_factor=aggregated_chunk_boost_factor, tenant_id=tenant_id, ) @@ -179,3 +184,9 @@ class IndexingSetting(EmbeddingModelDetail): class MultipassConfig(BaseModel): multipass_indexing: bool enable_large_chunks: bool + + +class UpdatableChunkData(BaseModel): + chunk_id: int + document_id: str + boost_score: float diff --git a/backend/onyx/llm/llm_provider_options.py b/backend/onyx/llm/llm_provider_options.py index 2ac8a3b9f..84ad6eca2 100644 --- a/backend/onyx/llm/llm_provider_options.py +++ b/backend/onyx/llm/llm_provider_options.py @@ -56,7 +56,9 @@ BEDROCK_PROVIDER_NAME = "bedrock" # models BEDROCK_MODEL_NAMES = [ model - for model in litellm.bedrock_models + # bedrock_converse_models are just extensions of the bedrock_models, not sure why + # litellm has split them into two lists :( + for model in litellm.bedrock_models + litellm.bedrock_converse_models if "/" not in model and "embed" not in model ][::-1] diff --git a/backend/onyx/natural_language_processing/search_nlp_models.py b/backend/onyx/natural_language_processing/search_nlp_models.py index 3a7fcdf6f..ad8ac25f9 100644 --- a/backend/onyx/natural_language_processing/search_nlp_models.py +++ b/backend/onyx/natural_language_processing/search_nlp_models.py @@ -29,6 +29,8 @@ from onyx.natural_language_processing.exceptions import ( from onyx.natural_language_processing.utils import get_tokenizer from onyx.natural_language_processing.utils import tokenizer_trim_content from onyx.utils.logger import setup_logger +from shared_configs.configs import INDEXING_MODEL_SERVER_HOST +from shared_configs.configs import 
INDEXING_MODEL_SERVER_PORT from shared_configs.configs import MODEL_SERVER_HOST from shared_configs.configs import MODEL_SERVER_PORT from shared_configs.enums import EmbeddingProvider @@ -36,9 +38,11 @@ from shared_configs.enums import EmbedTextType from shared_configs.enums import RerankerProvider from shared_configs.model_server_models import ConnectorClassificationRequest from shared_configs.model_server_models import ConnectorClassificationResponse +from shared_configs.model_server_models import ContentClassificationPrediction from shared_configs.model_server_models import Embedding from shared_configs.model_server_models import EmbedRequest from shared_configs.model_server_models import EmbedResponse +from shared_configs.model_server_models import InformationContentClassificationResponses from shared_configs.model_server_models import IntentRequest from shared_configs.model_server_models import IntentResponse from shared_configs.model_server_models import RerankRequest @@ -377,6 +381,31 @@ class QueryAnalysisModel: return response_model.is_keyword, response_model.keywords +class InformationContentClassificationModel: + def __init__( + self, + model_server_host: str = INDEXING_MODEL_SERVER_HOST, + model_server_port: int = INDEXING_MODEL_SERVER_PORT, + ) -> None: + model_server_url = build_model_server_url(model_server_host, model_server_port) + self.content_server_endpoint = ( + model_server_url + "/custom/content-classification" + ) + + def predict( + self, + queries: list[str], + ) -> list[ContentClassificationPrediction]: + response = requests.post(self.content_server_endpoint, json=queries) + response.raise_for_status() + + model_responses = InformationContentClassificationResponses( + information_content_classifications=response.json() + ) + + return model_responses.information_content_classifications + + class ConnectorClassificationModel: def __init__( self, diff --git a/backend/onyx/seeding/load_docs.py b/backend/onyx/seeding/load_docs.py index 44a6a6112..a3aa99ea6 100644 --- a/backend/onyx/seeding/load_docs.py +++ b/backend/onyx/seeding/load_docs.py @@ -98,6 +98,7 @@ def _create_indexable_chunks( boost=DEFAULT_BOOST, large_chunk_id=None, image_file_name=None, + aggregated_chunk_boost_factor=1.0, ) chunks.append(chunk) diff --git a/backend/onyx/server/onyx_api/ingestion.py b/backend/onyx/server/onyx_api/ingestion.py index ec4eeac8d..aa0f72ad1 100644 --- a/backend/onyx/server/onyx_api/ingestion.py +++ b/backend/onyx/server/onyx_api/ingestion.py @@ -19,6 +19,9 @@ from onyx.db.search_settings import get_secondary_search_settings from onyx.document_index.factory import get_default_document_index from onyx.indexing.embedder import DefaultIndexingEmbedder from onyx.indexing.indexing_pipeline import build_indexing_pipeline +from onyx.natural_language_processing.search_nlp_models import ( + InformationContentClassificationModel, +) from onyx.server.onyx_api.models import DocMinimalInfo from onyx.server.onyx_api.models import IngestionDocument from onyx.server.onyx_api.models import IngestionResult @@ -102,8 +105,11 @@ def upsert_ingestion_doc( search_settings=search_settings ) + information_content_classification_model = InformationContentClassificationModel() + indexing_pipeline = build_indexing_pipeline( embedder=index_embedding_model, + information_content_classification_model=information_content_classification_model, document_index=curr_doc_index, ignore_time_skip=True, db_session=db_session, @@ -138,6 +144,7 @@ def upsert_ingestion_doc( sec_ind_pipeline = build_indexing_pipeline( 
embedder=new_index_embedding_model, + information_content_classification_model=information_content_classification_model, document_index=sec_doc_index, ignore_time_skip=True, db_session=db_session, diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 016d14c23..9bae3b12e 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -25,7 +25,7 @@ google-auth-oauthlib==1.0.0 httpcore==1.0.5 httpx[http2]==0.27.0 httpx-oauth==0.15.1 -huggingface-hub==0.20.1 +huggingface-hub==0.29.0 inflection==0.5.1 jira==3.5.1 jsonref==1.1.0 @@ -38,7 +38,7 @@ langchainhub==0.1.21 langgraph==0.2.72 langgraph-checkpoint==2.0.13 langgraph-sdk==0.1.44 -litellm==1.61.16 +litellm==1.63.8 lxml==5.3.0 lxml_html_clean==0.2.2 llama-index==0.9.45 @@ -47,7 +47,7 @@ msal==1.28.0 nltk==3.8.1 Office365-REST-Python-Client==2.5.9 oauthlib==3.2.2 -openai==1.61.0 +openai==1.66.3 openpyxl==3.1.2 playwright==1.41.2 psutil==5.9.5 @@ -71,6 +71,7 @@ requests==2.32.2 requests-oauthlib==1.3.1 retry==0.9.2 # This pulls in py which is in CVE-2022-42969, must remove py from image rfc3986==1.5.0 +setfit==1.1.1 simple-salesforce==1.12.6 slack-sdk==3.20.2 SQLAlchemy[mypy]==2.0.15 @@ -78,7 +79,7 @@ starlette==0.36.3 supervisor==4.2.5 tiktoken==0.7.0 timeago==1.0.16 -transformers==4.39.2 +transformers==4.49.0 unstructured==0.15.1 unstructured-client==0.25.4 uvicorn==0.21.1 diff --git a/backend/requirements/dev.txt b/backend/requirements/dev.txt index 73972d638..4dcf44a6d 100644 --- a/backend/requirements/dev.txt +++ b/backend/requirements/dev.txt @@ -14,7 +14,7 @@ pytest-asyncio==0.22.0 pytest==7.4.4 reorder-python-imports==3.9.0 ruff==0.0.286 -sentence-transformers==2.6.1 +sentence-transformers==3.4.1 trafilatura==1.12.2 types-beautifulsoup4==4.12.0.3 types-html5lib==1.1.11.13 diff --git a/backend/requirements/model_server.txt b/backend/requirements/model_server.txt index b4d4a9f06..bcae01052 100644 --- a/backend/requirements/model_server.txt +++ b/backend/requirements/model_server.txt @@ -7,9 +7,10 @@ openai==1.61.0 pydantic==2.8.2 retry==0.9.2 safetensors==0.4.2 -sentence-transformers==2.6.1 +sentence-transformers==3.4.1 +setfit==1.1.1 torch==2.2.0 -transformers==4.39.2 +transformers==4.49.0 uvicorn==0.21.1 voyageai==0.2.3 litellm==1.61.16 diff --git a/backend/scripts/document_seeding_prep.py b/backend/scripts/document_seeding_prep.py index 4b643ef4e..c4d9637cb 100644 --- a/backend/scripts/document_seeding_prep.py +++ b/backend/scripts/document_seeding_prep.py @@ -161,17 +161,21 @@ overview_doc = SeedPresaveDocument( url="https://docs.onyx.app/more/use_cases/overview", title=overview_title, content=overview, - title_embedding=model.encode(f"search_document: {overview_title}"), - content_embedding=model.encode(f"search_document: {overview_title}\n{overview}"), + title_embedding=list(model.encode(f"search_document: {overview_title}")), + content_embedding=list( + model.encode(f"search_document: {overview_title}\n{overview}") + ), ) enterprise_search_doc = SeedPresaveDocument( url="https://docs.onyx.app/more/use_cases/enterprise_search", title=enterprise_search_title, content=enterprise_search_1, - title_embedding=model.encode(f"search_document: {enterprise_search_title}"), - content_embedding=model.encode( - f"search_document: {enterprise_search_title}\n{enterprise_search_1}" + title_embedding=list(model.encode(f"search_document: {enterprise_search_title}")), + content_embedding=list( + model.encode( + f"search_document: {enterprise_search_title}\n{enterprise_search_1}" + ) ), ) @@ 
diff --git a/backend/scripts/query_time_check/seed_dummy_docs.py b/backend/scripts/query_time_check/seed_dummy_docs.py
index b94b413e2..da1edd8db 100644
--- a/backend/scripts/query_time_check/seed_dummy_docs.py
+++ b/backend/scripts/query_time_check/seed_dummy_docs.py
@@ -99,6 +99,7 @@ def generate_dummy_chunk(
         ),
         document_sets={document_set for document_set in document_set_names},
         boost=random.randint(-1, 1),
+        aggregated_chunk_boost_factor=random.random(),
         tenant_id=POSTGRES_DEFAULT_SCHEMA,
     )
diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py
index b21c53d69..13e7ba03e 100644
--- a/backend/shared_configs/configs.py
+++ b/backend/shared_configs/configs.py
@@ -23,9 +23,11 @@ INDEXING_MODEL_SERVER_PORT = int(
 # Onyx custom Deep Learning Models
 CONNECTOR_CLASSIFIER_MODEL_REPO = "Danswer/filter-extraction-model"
 CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
-INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
-INTENT_MODEL_TAG = "v1.0.3"
-
+INTENT_MODEL_VERSION = "onyx-dot-app/hybrid-intent-token-classifier"
+# INTENT_MODEL_TAG = "v1.0.3"
+INTENT_MODEL_TAG: str | None = None
+INFORMATION_CONTENT_MODEL_VERSION = "onyx-dot-app/information-content-model"
+INFORMATION_CONTENT_MODEL_TAG: str | None = None
 
 # Bi-Encoder, other details
 DOC_EMBEDDING_CONTEXT_SIZE = 512
@@ -277,3 +279,20 @@ SUPPORTED_EMBEDDING_MODELS = [
         index_name="danswer_chunk_intfloat_multilingual_e5_small",
     ),
 ]
+# Maximum (least severe) downgrade factor for chunks above the cutoff
+INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX = float(
+    os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX") or 1.0
+)
+# Minimum (most severe) downgrade factor for short chunks below the cutoff that have no content
+INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN = float(
+    os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN") or 0.7
+)
+# Temperature for the information content classification model
+INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE = float(
+    os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE") or 4.0
+)
+# Cutoff below which we start using the information content classification model
+# (the cutoff length itself is still considered 'short')
+INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH = int(
+    os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH") or 10
+)
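Note on the four `INDEXING_INFORMATION_CONTENT_CLASSIFICATION_*` knobs above: this diff only defines them; the consuming logic lives in `model_server/custom_models.py`, which is not shown here. A hypothetical sketch of how they could combine, with the temperature flattening the classifier's softmax and the result mapped into the `[MIN, MAX]` band; treat every detail as an assumption, not the shipped implementation:

```python
import numpy as np


def boost_factor(
    logits: np.ndarray,  # raw [no-content, content] scores from the classifier
    num_tokens: int,
    temperature: float = 4.0,  # ..._TEMPERATURE
    cutoff: int = 10,  # ..._CUTOFF_LENGTH
    min_boost: float = 0.7,  # ..._MIN
    max_boost: float = 1.0,  # ..._MAX
) -> float:
    if num_tokens > cutoff:
        return max_boost  # chunks above the cutoff are never downgraded
    # Temperature-scaled softmax: a higher temperature flattens the
    # distribution, pulling the boost toward the middle of the band.
    scaled = np.exp(logits / temperature)
    p_content = float(scaled[1] / scaled.sum())
    return min_boost + (max_boost - min_boost) * p_content


print(boost_factor(np.array([0.2, 2.6]), num_tokens=5))  # short chunk, ~0.89
```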
diff --git a/backend/shared_configs/model_server_models.py b/backend/shared_configs/model_server_models.py
index 644f315fa..4c9c1be1e 100644
--- a/backend/shared_configs/model_server_models.py
+++ b/backend/shared_configs/model_server_models.py
@@ -4,6 +4,7 @@ from shared_configs.enums import EmbeddingProvider
 from shared_configs.enums import EmbedTextType
 from shared_configs.enums import RerankerProvider
 
+
 Embedding = list[float]
 
 
@@ -73,7 +74,20 @@ class IntentResponse(BaseModel):
     keywords: list[str]
 
 
+class InformationContentClassificationRequests(BaseModel):
+    queries: list[str]
+
+
 class SupportedEmbeddingModel(BaseModel):
     name: str
     dim: int
     index_name: str
+
+
+class ContentClassificationPrediction(BaseModel):
+    predicted_label: int
+    content_boost_factor: float
+
+
+class InformationContentClassificationResponses(BaseModel):
+    information_content_classifications: list[ContentClassificationPrediction]
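The new request/response models pair up one prediction per query. A minimal construction check using only the models defined in this diff (the endpoint that actually serves them is wired up elsewhere in the PR):

```python
from shared_configs.model_server_models import (
    ContentClassificationPrediction,
    InformationContentClassificationRequests,
    InformationContentClassificationResponses,
)

# One prediction is expected per query, in order.
request = InformationContentClassificationRequests(
    queries=["short chunk", "another short chunk"]
)
response = InformationContentClassificationResponses(
    information_content_classifications=[
        ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.9),
        ContentClassificationPrediction(predicted_label=0, content_boost_factor=0.7),
    ]
)
assert len(response.information_content_classifications) == len(request.queries)
```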
diff --git a/backend/tests/unit/model_server/test_custom_models.py b/backend/tests/unit/model_server/test_custom_models.py
new file mode 100644
index 000000000..b7eacffdd
--- /dev/null
+++ b/backend/tests/unit/model_server/test_custom_models.py
@@ -0,0 +1,145 @@
+from typing import Any
+from unittest.mock import Mock
+from unittest.mock import patch
+
+import numpy as np
+import numpy.typing as npt
+import pytest
+
+from model_server.custom_models import run_content_classification_inference
+from shared_configs.configs import (
+    INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
+)
+from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
+from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
+from shared_configs.model_server_models import ContentClassificationPrediction
+
+
+@pytest.fixture
+def mock_content_model() -> Mock:
+    model = Mock()
+
+    # Create actual numpy arrays for the mock returns
+    predict_output = np.array(
+        [1, 0] * 50, dtype=np.int64
+    )  # Pre-allocate enough elements
+    proba_output = np.array(
+        [[0.3, 0.7], [0.7, 0.3]] * 50, dtype=np.float64
+    )  # Pre-allocate enough elements
+
+    # Create a mock tensor that has a numpy method and supports indexing
+    class MockTensor:
+        def __init__(self, value: npt.NDArray[Any]) -> None:
+            self.value = value
+
+        def numpy(self) -> npt.NDArray[Any]:
+            return self.value
+
+        def __getitem__(self, idx: Any) -> Any:
+            result = self.value[idx]
+            # Wrap scalar values back in MockTensor
+            if isinstance(result, (np.float64, np.int64)):
+                return MockTensor(np.array([result]))
+            return MockTensor(result)
+
+    # Mock the direct call to return a MockTensor for each input
+    def model_call(inputs: list[str]) -> list[MockTensor]:
+        batch_size = len(inputs)
+        return [MockTensor(predict_output[i : i + 1]) for i in range(batch_size)]
+
+    model.side_effect = model_call
+
+    # Mock predict_proba to return a MockTensor-wrapped numpy array
+    def predict_proba_call(x: list[str]) -> MockTensor:
+        batch_size = len(x)
+        return MockTensor(proba_output[:batch_size])
+
+    model.predict_proba.side_effect = predict_proba_call
+
+    return model
+
+
+@patch("model_server.custom_models.get_local_information_content_model")
+def test_run_content_classification_inference(
+    mock_get_model: Mock,
+    mock_content_model: Mock,
+) -> None:
+    """
+    Test the content classification inference function.
+    Verifies that the function correctly processes text inputs and returns appropriate predictions.
+    """
+    # Setup
+    mock_get_model.return_value = mock_content_model
+
+    test_inputs = [
+        "Imagine a short text with content",
+        "Imagine a short text without content",
+        "x "
+        * (
+            INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH + 1
+        ),  # Long input that exceeds the maximal length up to which the model is applied
+        "",  # Empty input
+    ]
+
+    # Execute
+    results = run_content_classification_inference(test_inputs)
+
+    # Assert
+    assert len(results) == len(test_inputs)
+    assert all(isinstance(r, ContentClassificationPrediction) for r in results)
+
+    # Check each prediction has the expected attributes and ranges
+    for result_num, result in enumerate(results):
+        assert hasattr(result, "predicted_label")
+        assert hasattr(result, "content_boost_factor")
+        assert isinstance(result.predicted_label, int)
+        assert isinstance(result.content_boost_factor, float)
+        assert (
+            INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
+            <= result.content_boost_factor
+            <= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
+        )
+        if result_num == 2:
+            assert (
+                result.content_boost_factor
+                == INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
+            )
+            assert result.predicted_label == 1
+        elif result_num == 3:
+            assert (
+                result.content_boost_factor
+                == INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
+            )
+            assert result.predicted_label == 0
+
+    # Verify model handling of long inputs
+    mock_content_model.predict_proba.reset_mock()
+    long_input = ["x " * 1000]  # Definitely exceeds MAX_LENGTH
+    results = run_content_classification_inference(long_input)
+    assert len(results) == 1
+    assert (
+        mock_content_model.predict_proba.call_count == 0
+    )  # Should skip the model call for a too-long input
+
+
+@patch("model_server.custom_models.get_local_information_content_model")
+def test_batch_processing(
+    mock_get_model: Mock,
+    mock_content_model: Mock,
+) -> None:
+    """
+    Test that the function correctly handles batch processing of inputs.
+    """
+    # Setup
+    mock_get_model.return_value = mock_content_model
+
+    # Create test input larger than batch size
+    test_inputs = [f"Test input {i}" for i in range(40)]  # > BATCH_SIZE (32)
+
+    # Execute
+    results = run_content_classification_inference(test_inputs)
+
+    # Assert
+    assert len(results) == 40
+    # Verify batching occurred (should have called predict_proba twice)
+    assert mock_content_model.predict_proba.call_count == 2
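The assertions in `test_run_content_classification_inference` pin down a contract that can be restated compactly: long inputs and empty inputs never reach the model. A hypothetical sketch reconstructed from those assertions alone (`run_content_classification_inference` itself is not part of this diff):

```python
CUTOFF_LENGTH = 10  # INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
MIN_BOOST = 0.7  # INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
MAX_BOOST = 1.0  # INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX


def classify(text: str, model_predict) -> tuple[int, float]:
    words = text.split()
    if len(words) > CUTOFF_LENGTH:
        return 1, MAX_BOOST  # long chunk: assumed informative, model call skipped
    if not words:
        return 0, MIN_BOOST  # empty chunk: nothing worth classifying
    return model_predict(text)  # only short, non-empty chunks reach the model


# A long input short-circuits to MAX_BOOST without touching the model.
assert classify("x " * 11, lambda t: (1, 0.9)) == (1, MAX_BOOST)
```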
diff --git a/backend/tests/unit/onyx/indexing/test_indexing_pipeline.py b/backend/tests/unit/onyx/indexing/test_indexing_pipeline.py
index 1a4ab701d..a46d455a2 100644
--- a/backend/tests/unit/onyx/indexing/test_indexing_pipeline.py
+++ b/backend/tests/unit/onyx/indexing/test_indexing_pipeline.py
@@ -1,12 +1,24 @@
 from typing import cast
 from typing import List
+from unittest.mock import Mock
+
+import pytest
 
 from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
 from onyx.connectors.models import Document
 from onyx.connectors.models import DocumentSource
 from onyx.connectors.models import ImageSection
 from onyx.connectors.models import TextSection
+from onyx.indexing.indexing_pipeline import _get_aggregated_chunk_boost_factor
 from onyx.indexing.indexing_pipeline import filter_documents
+from onyx.indexing.models import ChunkEmbedding
+from onyx.indexing.models import IndexChunk
+from onyx.natural_language_processing.search_nlp_models import (
+    ContentClassificationPrediction,
+)
+from shared_configs.configs import (
+    INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
+)
 
 
 def create_test_document(
@@ -123,3 +135,117 @@ def test_filter_documents_multiple_documents() -> None:
 def test_filter_documents_empty_batch() -> None:
     result = filter_documents([])
     assert len(result) == 0
+
+
+# Tests for _get_aggregated_chunk_boost_factor
+
+
+def create_test_chunk(
+    content: str, chunk_id: int = 0, doc_id: str = "test_doc"
+) -> IndexChunk:
+    doc = Document(
+        id=doc_id,
+        semantic_identifier="test doc",
+        sections=[],
+        source=DocumentSource.FILE,
+        metadata={},
+    )
+    return IndexChunk(
+        chunk_id=chunk_id,
+        content=content,
+        source_document=doc,
+        blurb=content[:50],  # First 50 chars as blurb
+        source_links={0: "test_link"},
+        section_continuation=False,
+        title_prefix="",
+        metadata_suffix_semantic="",
+        metadata_suffix_keyword="",
+        mini_chunk_texts=None,
+        large_chunk_id=None,
+        large_chunk_reference_ids=[],
+        embeddings=ChunkEmbedding(full_embedding=[], mini_chunk_embeddings=[]),
+        title_embedding=None,
+        image_file_name=None,
+    )
+
+
+def test_get_aggregated_boost_factor() -> None:
+    # Create test chunks - a mix of short and long content
+    chunks = [
+        create_test_chunk("Short content", 0),
+        create_test_chunk(
+            "Long " * (INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH + 1), 1
+        ),
+        create_test_chunk("Another short chunk", 2),
+    ]
+
+    # Mock the classification model
+    mock_model = Mock()
+    mock_model.predict.return_value = [
+        ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.8),
+        ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.9),
+    ]
+
+    # Execute the function
+    boost_scores = _get_aggregated_chunk_boost_factor(
+        chunks=chunks, information_content_classification_model=mock_model
+    )
+
+    # Assertions
+    assert len(boost_scores) == 3
+
+    # Check that long content got the default boost
+    assert boost_scores[1] == 1.0
+
+    # Check that short content got the predicted boosts
+    assert boost_scores[0] == 0.8
+    assert boost_scores[2] == 0.9
+
+    # Verify the model was only called once, with the two short chunks
+    mock_model.predict.assert_called_once()
+    assert len(mock_model.predict.call_args[0][0]) == 2
+
+
+def test_get_aggregated_boost_factor_batch_failure() -> None:
+    chunks = [
+        create_test_chunk("Short content 1", 0),
+        create_test_chunk("Short content 2", 1),
+    ]
+
+    # Mock model to fail on batch prediction but succeed on individual predictions
+    mock_model = Mock()
+    mock_model.predict.side_effect = [
+        Exception("Batch prediction failed"),  # First call fails
+        [
+            ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.7)
+        ],  # Individual calls succeed
+        [ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.8)],
+    ]
+
+    # Execute
+    boost_scores = _get_aggregated_chunk_boost_factor(
+        chunks=chunks, information_content_classification_model=mock_model
+    )
+
+    # Assertions
+    assert len(boost_scores) == 2
+    assert boost_scores == [0.7, 0.8]
+
+
+def test_get_aggregated_boost_factor_individual_failure() -> None:
+    chunks = [
+        create_test_chunk("Short content", 0),
+        create_test_chunk("Short content", 1),
+    ]
+
+    # Mock model to fail on both batch and individual prediction
+    mock_model = Mock()
+    mock_model.predict.side_effect = Exception("Prediction failed")
+
+    # Execute and verify it raises an exception
+    with pytest.raises(Exception) as exc_info:
+        _get_aggregated_chunk_boost_factor(
+            chunks=chunks, information_content_classification_model=mock_model
+        )
+
+    assert "Failed to predict content classification for chunk" in str(exc_info.value)
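The two failure tests above imply a batch-then-individual fallback inside `_get_aggregated_chunk_boost_factor`. A hypothetical sketch of that pattern, reconstructed from the tests rather than copied from the implementation:

```python
def predict_with_fallback(model, texts: list[str]) -> list[float]:
    try:
        # Fast path: classify the whole batch in one call.
        return [p.content_boost_factor for p in model.predict(texts)]
    except Exception:
        # Slow path: retry chunk by chunk so one bad input cannot sink the batch.
        scores: list[float] = []
        for text in texts:
            try:
                scores.extend(p.content_boost_factor for p in model.predict([text]))
            except Exception as err:
                # Matches the message the last test asserts on.
                raise RuntimeError(
                    "Failed to predict content classification for chunk"
                ) from err
        return scores
```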
diff --git a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx
index 3a5565013..f8cb5b6cb 100644
--- a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx
+++ b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx
@@ -284,7 +284,9 @@ export function LLMProviderUpdateForm({
           subtext="The model to use by default for this provider unless otherwise specified."
           label="Default Model"
           options={llmProviderDescriptor.llm_names.map((name) => ({
-            name: getDisplayNameForModel(name),
+            // don't clean up the names here, so admins see the full model identifiers and duplicates
+            // like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0 stay distinguishable
+            name: name,
             value: name,
           }))}
           maxHeight="max-h-56"
@@ -314,7 +316,9 @@ export function LLMProviderUpdateForm({
           the Default Model configured above.`}
           label="[Optional] Fast Model"
           options={llmProviderDescriptor.llm_names.map((name) => ({
-            name: getDisplayNameForModel(name),
+            // don't clean up the names here, so admins see the full model identifiers and duplicates
+            // like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0 stay distinguishable
+            name: name,
             value: name,
           }))}
           includeDefault
@@ -355,7 +359,9 @@ export function LLMProviderUpdateForm({
                 options={llmProviderDescriptor.llm_names.map(
                   (name) => ({
                     value: name,
-                    label: getDisplayNameForModel(name),
+                    // don't clean up the names here, so admins see the full model identifiers and duplicates
+                    // like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0 stay distinguishable
+                    label: name,
                   })
                 )}
                 onChange={(selected) =>
diff --git a/web/src/components/credentials/lib.ts b/web/src/components/credentials/lib.ts
index a9ab3ca19..d7ef8e1ea 100644
--- a/web/src/components/credentials/lib.ts
+++ b/web/src/components/credentials/lib.ts
@@ -16,7 +16,15 @@ export function createValidationSchema(json_values: Record<string, any>) {
     const displayName = getDisplayNameForCredentialKey(key);
 
-    if (json_values[key] === null) {
+    if (typeof json_values[key] === "boolean") {
+      // Ensure false is considered valid
+      schemaFields[key] = Yup.boolean()
+        .nullable()
+        .default(false)
+        .transform((value, originalValue) =>
+          originalValue === undefined ? false : value
+        );
+    } else if (json_values[key] === null) {
       // Field is optional:
       schemaFields[key] = Yup.string()
         .trim()