Merge branch 'main' of https://github.com/onyx-dot-app/onyx into bugfix/confluence-filters

Richard Kuo (Onyx) committed 2025-03-14 13:53:49 -07:00
commit 1d354b85db
43 changed files with 1012 additions and 149 deletions

View File

@ -31,7 +31,8 @@ RUN python -c "from transformers import AutoTokenizer; \
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from huggingface_hub import snapshot_download; \
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
snapshot_download(repo_id='onyx-dot-app/hybrid-intent-token-classifier'); \
snapshot_download(repo_id='onyx-dot-app/information-content-model'); \
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from sentence_transformers import SentenceTransformer; \

View File

@ -0,0 +1,51 @@
"""add chunk stats table
Revision ID: 3781a5eb12cb
Revises: df46c75b714e
Create Date: 2025-03-10 10:02:30.586666
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "3781a5eb12cb"
down_revision = "df46c75b714e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"chunk_stats",
sa.Column("id", sa.String(), primary_key=True, index=True),
sa.Column(
"document_id",
sa.String(),
sa.ForeignKey("document.id"),
nullable=False,
index=True,
),
sa.Column("chunk_in_doc_id", sa.Integer(), nullable=False),
sa.Column("information_content_boost", sa.Float(), nullable=True),
sa.Column(
"last_modified",
sa.DateTime(timezone=True),
nullable=False,
index=True,
server_default=sa.func.now(),
),
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True, index=True),
sa.UniqueConstraint(
"document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk"
),
)
op.create_index(
"ix_chunk_sync_status", "chunk_stats", ["last_modified", "last_synced"]
)
def downgrade() -> None:
op.drop_index("ix_chunk_sync_status", table_name="chunk_stats")
op.drop_table("chunk_stats")
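For local verification, the revision above can be applied and rolled back through the Alembic API; a minimal sketch, assuming a standard alembic.ini in the backend working directory (the config path is an assumption):

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # assumed config location
command.upgrade(cfg, "3781a5eb12cb")    # creates chunk_stats and ix_chunk_sync_status
command.downgrade(cfg, "df46c75b714e")  # drops them again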

View File

@ -2,6 +2,7 @@
Rules defined here:
https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.html
"""
from collections.abc import Generator
from typing import Any
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
@ -263,13 +264,11 @@ def _fetch_all_page_restrictions(
space_permissions_by_space_key: dict[str, ExternalAccess],
is_cloud: bool,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
) -> Generator[DocExternalAccess, None, None]:
"""
For all pages, if a page has restrictions, then use those restrictions.
Otherwise, use the space's restrictions.
"""
document_restrictions: list[DocExternalAccess] = []
for slim_doc in slim_docs:
if callback:
if callback.should_stop():
@ -286,11 +285,9 @@ def _fetch_all_page_restrictions(
confluence_client=confluence_client,
perm_sync_data=slim_doc.perm_sync_data,
):
document_restrictions.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=restrictions,
)
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=restrictions,
)
# If there are restrictions, then we don't need to use the space's restrictions
continue
@ -324,11 +321,9 @@ def _fetch_all_page_restrictions(
continue
# If there are no restrictions, then use the space's restrictions
document_restrictions.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=space_permissions,
)
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=space_permissions,
)
if (
not space_permissions.is_public
@ -342,13 +337,12 @@ def _fetch_all_page_restrictions(
)
logger.debug("Finished fetching all page restrictions for space")
return document_restrictions
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
) -> Generator[DocExternalAccess, None, None]:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exist in postgres, we create
@ -387,7 +381,7 @@ def confluence_doc_sync(
slim_docs.extend(doc_batch)
logger.debug("Fetching all page restrictions for space")
return _fetch_all_page_restrictions(
yield from _fetch_all_page_restrictions(
confluence_client=confluence_connector.confluence_client,
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,
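This hunk converts the permission-sync path from list accumulation to streaming: each DocExternalAccess is yielded as it is computed instead of being collected into one large list. A minimal sketch of the pattern, with hypothetical process() and docs names:

# Before: peak memory grows with the number of documents.
def sync_all(docs: list) -> list:
    results = []
    for doc in docs:
        results.append(process(doc))  # process() is a hypothetical helper
    return results

# After: each result is yielded immediately; callers consume incrementally.
def sync_all_streaming(docs: list):
    for doc in docs:
        yield process(doc)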

View File

@ -1,3 +1,4 @@
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
@ -34,7 +35,7 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
) -> Generator[DocExternalAccess, None, None]:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exist in postgres, we create
@ -48,7 +49,6 @@ def gmail_doc_sync(
cc_pair, gmail_connector, callback=callback
)
document_external_access: list[DocExternalAccess] = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
@ -60,17 +60,14 @@ def gmail_doc_sync(
if slim_doc.perm_sync_data is None:
logger.warning(f"No permissions found for document {slim_doc.id}")
continue
if user_email := slim_doc.perm_sync_data.get("user_email"):
ext_access = ExternalAccess(
external_user_emails=set([user_email]),
external_user_group_ids=set(),
is_public=False,
)
document_external_access.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=ext_access,
)
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=ext_access,
)
return document_external_access

View File

@ -1,3 +1,4 @@
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
@ -147,7 +148,7 @@ def _get_permissions_from_slim_doc(
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
) -> Generator[DocExternalAccess, None, None]:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exist in postgres, we create
@ -161,7 +162,6 @@ def gdrive_doc_sync(
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)
document_external_accesses = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
@ -174,10 +174,7 @@ def gdrive_doc_sync(
google_drive_connector=google_drive_connector,
slim_doc=slim_doc,
)
document_external_accesses.append(
DocExternalAccess(
external_access=ext_access,
doc_id=slim_doc.id,
)
yield DocExternalAccess(
external_access=ext_access,
doc_id=slim_doc.id,
)
return document_external_accesses

View File

@ -1,3 +1,5 @@
from collections.abc import Generator
from slack_sdk import WebClient
from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
@ -14,35 +16,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def _get_slack_document_ids_and_channels(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> dict[str, list[str]]:
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
channel_doc_map: dict[str, list[str]] = {}
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
if channel_id not in channel_doc_map:
channel_doc_map[channel_id] = []
channel_doc_map[channel_id].append(doc_metadata.id)
if callback:
if callback.should_stop():
raise RuntimeError(
"_get_slack_document_ids_and_channels: Stop signal detected"
)
callback.progress("_get_slack_document_ids_and_channels", 1)
return channel_doc_map
def _fetch_workspace_permissions(
user_id_to_email_map: dict[str, str],
) -> ExternalAccess:
@ -122,10 +95,37 @@ def _fetch_channel_permissions(
return channel_permissions
def _get_slack_document_access(
cc_pair: ConnectorCredentialPair,
channel_permissions: dict[str, ExternalAccess],
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
yield DocExternalAccess(
external_access=channel_permissions[channel_id],
doc_id=doc_metadata.id,
)
if callback:
if callback.should_stop():
raise RuntimeError("_get_slack_document_access: Stop signal detected")
callback.progress("_get_slack_document_access", 1)
def slack_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
) -> Generator[DocExternalAccess, None, None]:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exist in postgres, we create
@ -136,9 +136,12 @@ def slack_doc_sync(
token=cc_pair.credential.credential_json["slack_bot_token"]
)
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
channel_doc_map = _get_slack_document_ids_and_channels(
cc_pair=cc_pair, callback=callback
)
if not user_id_to_email_map:
raise ValueError(
"No user id to email map found. Please check to make sure that "
"your Slack bot token has the `users:read.email` scope"
)
workspace_permissions = _fetch_workspace_permissions(
user_id_to_email_map=user_id_to_email_map,
)
@ -148,18 +151,8 @@ def slack_doc_sync(
user_id_to_email_map=user_id_to_email_map,
)
document_external_accesses = []
for channel_id, ext_access in channel_permissions.items():
doc_ids = channel_doc_map.get(channel_id)
if not doc_ids:
# No documents found for the channel_id
continue
for doc_id in doc_ids:
document_external_accesses.append(
DocExternalAccess(
external_access=ext_access,
doc_id=doc_id,
)
)
return document_external_accesses
yield from _get_slack_document_access(
cc_pair=cc_pair,
channel_permissions=channel_permissions,
callback=callback,
)

View File

@ -1,4 +1,5 @@
from collections.abc import Callable
from collections.abc import Generator
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
@ -23,7 +24,7 @@ DocSyncFuncType = Callable[
ConnectorCredentialPair,
IndexingHeartbeatInterface | None,
],
list[DocExternalAccess],
Generator[DocExternalAccess, None, None],
]
GroupSyncFuncType = Callable[

View File

@ -3,6 +3,7 @@ from shared_configs.enums import EmbedTextType
MODEL_WARM_UP_STRING = "hi " * 512
INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"

View File

@ -1,11 +1,14 @@
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import APIRouter
from huggingface_hub import snapshot_download # type: ignore
from setfit import SetFitModel # type: ignore[import]
from transformers import AutoTokenizer # type: ignore
from transformers import BatchEncoding # type: ignore
from transformers import PreTrainedTokenizer # type: ignore
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.onyx_torch_model import ConnectorClassifier
from model_server.onyx_torch_model import HybridClassifier
@ -13,11 +16,22 @@ from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE,
)
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import INFORMATION_CONTENT_MODEL_TAG
from shared_configs.configs import INFORMATION_CONTENT_MODEL_VERSION
from shared_configs.configs import INTENT_MODEL_TAG
from shared_configs.configs import INTENT_MODEL_VERSION
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import ContentClassificationPrediction
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
@ -31,6 +45,10 @@ _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
_INTENT_TOKENIZER: AutoTokenizer | None = None
_INTENT_MODEL: HybridClassifier | None = None
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = ""  # specific to the model version!
def get_connector_classifier_tokenizer() -> AutoTokenizer:
global _CONNECTOR_CLASSIFIER_TOKENIZER
@ -85,7 +103,7 @@ def get_intent_model_tokenizer() -> AutoTokenizer:
def get_local_intent_model(
model_name_or_path: str = INTENT_MODEL_VERSION,
tag: str = INTENT_MODEL_TAG,
tag: str | None = INTENT_MODEL_TAG,
) -> HybridClassifier:
global _INTENT_MODEL
if _INTENT_MODEL is None:
@ -102,7 +120,9 @@ def get_local_intent_model(
try:
# Attempt to download the model snapshot
logger.notice(f"Downloading model snapshot for {model_name_or_path}")
local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=False
)
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
except Exception as e:
logger.error(
@ -112,6 +132,44 @@ def get_local_intent_model(
return _INTENT_MODEL
def get_local_information_content_model(
model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
) -> SetFitModel:
global _INFORMATION_CONTENT_MODEL
if _INFORMATION_CONTENT_MODEL is None:
try:
# Calculate where the cache should be, then load from local if available
logger.notice(
f"Loading content information model from local cache: {model_name_or_path}"
)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=True
)
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
logger.notice(
f"Loaded content information model from local cache: {local_path}"
)
except Exception as e:
logger.warning(f"Failed to load content information model directly: {e}")
try:
# Attempt to download the model snapshot
logger.notice(
f"Downloading content information model snapshot for {model_name_or_path}"
)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=False
)
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
except Exception as e:
logger.error(
f"Failed to load content information model even after attempted snapshot download: {e}"
)
raise
return _INFORMATION_CONTENT_MODEL
def tokenize_connector_classification_query(
connectors: list[str],
query: str,
@ -195,6 +253,13 @@ def warm_up_intent_model() -> None:
)
def warm_up_information_content_model() -> None:
logger.notice("Warming up Content Model") # TODO: add version if needed
information_content_model = get_local_information_content_model()
information_content_model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING)
@simple_log_function_time()
def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
intent_model = get_local_intent_model()
@ -218,6 +283,117 @@ def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
return intent_probabilities.tolist(), token_positive_probs
@simple_log_function_time()
def run_content_classification_inference(
text_inputs: list[str],
) -> list[ContentClassificationPrediction]:
"""
Assign a score to each text segment. The model returned by get_local_information_content_model()
produces a raw 'model score' based on its training; those scores are then mapped onto a 0.0-1.0 scale.
Outside of the model/inference servers, that score is converted into the actual boost factor.
"""
def _prob_to_score(prob: float) -> float:
"""
Conversion of base score to 0.0 - 1.0 score. Note that the min/max values depend on the model!
"""
_MIN_BASE_SCORE = 0.25
_MAX_BASE_SCORE = 0.75
if prob < _MIN_BASE_SCORE:
raw_score = 0.0
elif prob < _MAX_BASE_SCORE:
raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
else:
raw_score = 1.0
return (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
+ (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
- INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
)
* raw_score
)
_BATCH_SIZE = 32
content_model = get_local_information_content_model()
# Process inputs in batches
all_output_classes: list[int] = []
all_base_output_probabilities: list[float] = []
for i in range(0, len(text_inputs), _BATCH_SIZE):
batch = text_inputs[i : i + _BATCH_SIZE]
batch_with_prefix = []
batch_indices = []
# Pre-allocate results for this batch
batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch)
batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch)
# Pre-process batch to handle long input exceptions
for j, text in enumerate(batch):
if len(text) == 0:
# if no input, treat as non-informative from the model's perspective
batch_output_classes[j] = np.array(0)
batch_probabilities[j] = np.array(0.0)
logger.warning("Input for Content Information Model is empty")
elif (
len(text.split())
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
):
# if input is short, use the model
batch_with_prefix.append(
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text
)
batch_indices.append(j)
else:
# if longer than cutoff, treat as informative (stay with default), but issue warning
logger.warning("Input for Content Information Model too long")
if batch_with_prefix: # Only run model if we have valid inputs
# Get predictions for the batch
model_output_classes = content_model(batch_with_prefix)
model_output_probabilities = content_model.predict_proba(batch_with_prefix)
# Place results in the correct positions
for idx, batch_idx in enumerate(batch_indices):
batch_output_classes[batch_idx] = model_output_classes[idx].numpy()
batch_probabilities[batch_idx] = model_output_probabilities[idx][
1
].numpy() # x[1] is prob of the positive class
all_output_classes.extend([int(x) for x in batch_output_classes])
all_base_output_probabilities.extend([float(x) for x in batch_probabilities])
logits = [
np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100)
for p in all_base_output_probabilities
]
scaled_logits = [
logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE
for logit in logits
]
output_probabilities_with_temp = [
np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
for scaled_logit in scaled_logits
]
prediction_scores = [
_prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
]
content_classification_predictions = [
ContentClassificationPrediction(
predicted_label=predicted_label, content_boost_factor=output_score
)
for predicted_label, output_score in zip(all_output_classes, prediction_scores)
]
return content_classification_predictions
def map_keywords(
input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
) -> list[str]:
@ -362,3 +538,10 @@ async def process_analysis_request(
is_keyword, keywords = run_analysis(intent_request)
return IntentResponse(is_keyword=is_keyword, keywords=keywords)
@router.post("/content-classification")
async def process_content_classification_request(
content_classification_requests: list[str],
) -> list[ContentClassificationPrediction]:
return run_content_classification_inference(content_classification_requests)
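To make the scoring chain above concrete, here is a worked example using the default configuration values defined later in this commit (MIN=0.7, MAX=1.0, temperature=4.0); the base probability is illustrative:

import math

p = 0.9                                  # illustrative positive-class probability
logit = math.log(p / (1 - p))            # ~2.197
scaled = logit / 4.0                     # temperature-scaled -> ~0.549
p_temp = 1 / (1 + math.exp(-scaled))     # ~0.634
# _prob_to_score: clip to [0.25, 0.75], rescale to [0, 1], map to [MIN, MAX]
raw = min(max((p_temp - 0.25) / 0.5, 0.0), 1.0)  # ~0.768
score = 0.7 + (1.0 - 0.7) * raw          # ~0.93 final boost factor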

View File

@ -13,6 +13,7 @@ from sentry_sdk.integrations.starlette import StarletteIntegration
from transformers import logging as transformer_logging # type:ignore
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_information_content_model
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.management_endpoints import router as management_router
@ -74,9 +75,15 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
if not INDEXING_ONLY:
logger.notice(
"The intent model should run on the model server. The information content model should not run here."
)
warm_up_intent_model()
else:
logger.notice("This model server should only run document indexing.")
logger.notice(
"The content information model should run on the indexing model server. The intent model should not run here."
)
warm_up_information_content_model()
yield

View File

@ -20,7 +20,7 @@ class ExternalAccess:
class DocExternalAccess:
"""
This is just a class to wrap the external access and the document ID
together. It's used for syncing document permissions to Redis.
together. It's used for syncing document permissions to Vespa.
"""
external_access: ExternalAccess

View File

@ -105,6 +105,7 @@ from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
@ -593,7 +594,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_tenant_id_for_email",
None,
POSTGRES_DEFAULT_SCHEMA,
)(
email=email,
)

View File

@ -453,23 +453,23 @@ def connector_permission_sync_generator_task(
redis_connector.permissions.set_fence(new_payload)
callback = PermissionSyncCallback(redis_connector, lock, r)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(
cc_pair, callback
)
document_external_accesses = doc_sync_func(cc_pair, callback)
task_logger.info(
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.permissions.generate_tasks(
celery_app=self.app,
lock=lock,
new_permissions=document_external_accesses,
source_string=source_type,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
if tasks_generated is None:
return None
tasks_generated = 0
for doc_external_access in document_external_accesses:
redis_connector.permissions.generate_tasks(
celery_app=self.app,
lock=lock,
new_permissions=[doc_external_access],
source_string=source_type,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
tasks_generated += 1
task_logger.info(
f"RedisConnector.permissions.generate_tasks finished. "

View File

@ -563,6 +563,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
# aggregated_boost_factor=doc.aggregated_boost_factor,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.

View File

@ -53,6 +53,9 @@ from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
@ -348,6 +351,8 @@ def _run_indexing(
callback=callback,
)
information_content_classification_model = InformationContentClassificationModel()
document_index = get_default_document_index(
index_attempt_start.search_settings,
None,
@ -356,6 +361,7 @@ def _run_indexing(
indexing_pipeline = build_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=(
ctx.from_beginning

View File

@ -132,3 +132,10 @@ if _LITELLM_EXTRA_BODY_RAW:
LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
except Exception:
pass
# Whether and how to lower scores for short chunks w/o relevant context
# Evaluated via custom ML model
USE_INFORMATION_CONTENT_CLASSIFICATION = (
os.environ.get("USE_INFORMATION_CONTENT_CLASSIFICATION", "false").lower() == "true"
)

View File

@ -4,6 +4,7 @@ from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any
from urllib.parse import urlparse
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
@ -59,7 +60,7 @@ def _extract_str_list_from_comma_str(string: str | None) -> list[str]:
def _extract_ids_from_urls(urls: list[str]) -> list[str]:
return [url.split("/")[-1] for url in urls]
return [urlparse(url).path.strip("/").split("/")[-1] for url in urls]
def _convert_single_file(
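The switch to urlparse matters for shared links that carry query parameters or fragments; a quick illustration (the URL is hypothetical):

from urllib.parse import urlparse

url = "https://drive.google.com/drive/folders/ABC123?usp=sharing"
url.split("/")[-1]                            # 'ABC123?usp=sharing' (old behavior, wrong id)
urlparse(url).path.strip("/").split("/")[-1]  # 'ABC123' (new behavior)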

View File

@ -220,7 +220,6 @@ def thread_to_doc(
source=DocumentSource.SLACK,
semantic_identifier=doc_sem_id,
doc_updated_at=get_latest_message_time(thread),
title="", # slack docs don't really have a "title"
primary_owners=valid_experts,
metadata={"Channel": channel["name"]},
)
@ -413,8 +412,8 @@ def _get_all_doc_ids(
callback=callback,
)
message_ts_set: set[str] = set()
for message_batch in channel_message_batches:
slim_doc_batch: list[SlimDocument] = []
for message in message_batch:
if msg_filter_func(message):
continue
@ -422,18 +421,17 @@ def _get_all_doc_ids(
# The document id is the channel id and the ts of the first message in the thread
# Since we already have the first message of the thread, we dont have to
# fetch the thread for id retrieval, saving time and API calls
message_ts_set.add(message["ts"])
channel_metadata_list: list[SlimDocument] = []
for message_ts in message_ts_set:
channel_metadata_list.append(
SlimDocument(
id=_build_doc_id(channel_id=channel_id, thread_ts=message_ts),
perm_sync_data={"channel_id": channel_id},
slim_doc_batch.append(
SlimDocument(
id=_build_doc_id(
channel_id=channel_id, thread_ts=message["ts"]
),
perm_sync_data={"channel_id": channel_id},
)
)
)
yield channel_metadata_list
yield slim_doc_batch
def _process_message(

backend/onyx/db/chunk.py (new file, 63 lines)
View File

@ -0,0 +1,63 @@
from datetime import datetime
from datetime import timezone
from sqlalchemy import delete
from sqlalchemy.orm import Session
from onyx.db.models import ChunkStats
from onyx.indexing.models import UpdatableChunkData
def update_chunk_boost_components__no_commit(
chunk_data: list[UpdatableChunkData],
db_session: Session,
) -> None:
"""Updates the chunk_boost_components for chunks in the database.
Args:
chunk_data: list of UpdatableChunkData objects containing chunk_id, document_id, and boost_score
db_session: SQLAlchemy database session
"""
if not chunk_data:
return
for data in chunk_data:
chunk_in_doc_id = int(data.chunk_id)
if chunk_in_doc_id < 0:
raise ValueError(f"Chunk ID is empty for chunk {data}")
chunk_document_id = f"{data.document_id}__{chunk_in_doc_id}"
chunk_stats = (
db_session.query(ChunkStats)
.filter(
ChunkStats.id == chunk_document_id,
)
.first()
)
score = data.boost_score
if chunk_stats:
chunk_stats.information_content_boost = score
chunk_stats.last_modified = datetime.now(timezone.utc)
db_session.add(chunk_stats)
else:
# do not save new chunks with a neutral boost score
if score == 1.0:
continue
# Create new record
chunk_stats = ChunkStats(
document_id=data.document_id,
chunk_in_doc_id=chunk_in_doc_id,
information_content_boost=score,
)
db_session.add(chunk_stats)
def delete_chunk_stats_by_connector_credential_pair__no_commit(
db_session: Session, document_ids: list[str]
) -> None:
"""This deletes just chunk stats in postgres."""
stmt = delete(ChunkStats).where(ChunkStats.document_id.in_(document_ids))
db_session.execute(stmt)
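A minimal usage sketch for the update helper above, assuming an open SQLAlchemy Session named db_session; the caller owns the commit, per the __no_commit suffix:

from onyx.indexing.models import UpdatableChunkData

chunk_data = [
    UpdatableChunkData(chunk_id=0, document_id="doc-1", boost_score=0.85),
    UpdatableChunkData(chunk_id=1, document_id="doc-1", boost_score=1.0),  # neutral; not inserted if new
]
update_chunk_boost_components__no_commit(chunk_data=chunk_data, db_session=db_session)
db_session.commit()  # commit belongs to the caller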

View File

@ -23,6 +23,7 @@ from sqlalchemy.sql.expression import null
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
@ -562,6 +563,18 @@ def delete_documents_complete__no_commit(
db_session: Session, document_ids: list[str]
) -> None:
"""This completely deletes the documents from the db, including all foreign key relationships"""
# Start by deleting the chunk stats for the documents
delete_chunk_stats_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids,
)
delete_documents_by_connector_credential_pair__no_commit(db_session, document_ids)
delete_document_feedback_for_documents__no_commit(
document_ids=document_ids, db_session=db_session

View File

@ -591,6 +591,55 @@ class Document(Base):
)
class ChunkStats(Base):
__tablename__ = "chunk_stats"
# NOTE: if more sensitive data is added here for display, make sure to add user/group permission
# this should correspond to the ID of the document
# (as is passed around in Onyx)
id: Mapped[str] = mapped_column(
NullFilteredString,
primary_key=True,
default=lambda context: (
f"{context.get_current_parameters()['document_id']}"
f"__{context.get_current_parameters()['chunk_in_doc_id']}"
),
index=True,
)
# Reference to parent document
document_id: Mapped[str] = mapped_column(
NullFilteredString, ForeignKey("document.id"), nullable=False, index=True
)
chunk_in_doc_id: Mapped[int] = mapped_column(
Integer,
nullable=False,
)
information_content_boost: Mapped[float | None] = mapped_column(
Float, nullable=True
)
last_modified: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), nullable=False, index=True, default=func.now()
)
last_synced: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True, index=True
)
__table_args__ = (
Index(
"ix_chunk_sync_status",
last_modified,
last_synced,
),
UniqueConstraint(
"document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk"
),
)
class Tag(Base):
__tablename__ = "tag"

View File

@ -101,6 +101,7 @@ class VespaDocumentFields:
document_sets: set[str] | None = None
boost: float | None = None
hidden: bool | None = None
aggregated_chunk_boost_factor: float | None = None
@dataclass

View File

@ -80,6 +80,11 @@ schema DANSWER_CHUNK_NAME {
indexing: summary | attribute
rank: filter
}
# Field to indicate whether a short chunk is a low content chunk
field aggregated_chunk_boost_factor type float {
indexing: attribute
}
# Needs to have a separate Attribute list for efficient filtering
field metadata_list type array<string> {
indexing: summary | attribute
@ -142,6 +147,11 @@ schema DANSWER_CHUNK_NAME {
expression: max(if(isNan(attribute(doc_updated_at)) == 1, 7890000, now() - attribute(doc_updated_at)) / 31536000, 0)
}
function inline aggregated_chunk_boost() {
# Aggregated boost factor, currently only used for information content classification
expression: if(isNan(attribute(aggregated_chunk_boost_factor)) == 1, 1.0, attribute(aggregated_chunk_boost_factor))
}
# Document score decays from 1 to 0.75 as age of last updated time increases
function inline recency_bias() {
expression: max(1 / (1 + query(decay_factor) * document_age), 0.75)
@ -199,6 +209,8 @@ schema DANSWER_CHUNK_NAME {
* document_boost
# Decay factor based on time document was last updated
* recency_bias
# Boost based on aggregated boost calculation
* aggregated_chunk_boost
}
rerank-count: 1000
}
@ -210,6 +222,7 @@ schema DANSWER_CHUNK_NAME {
closeness(field, embeddings)
document_boost
recency_bias
aggregated_chunk_boost
closest(embeddings)
}
}
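Illustrative arithmetic for how the new factor folds into the first-phase score; this mimics the rank expressions above in plain Python rather than Vespa's evaluator, and all input numbers are made up:

import math

def aggregated_chunk_boost(attr: float) -> float:
    # NaN (chunk never scored) falls back to neutral 1.0, per the schema function
    return 1.0 if math.isnan(attr) else attr

closeness = 0.82        # semantic similarity component (illustrative)
document_boost = 1.05   # existing document-level boost (illustrative)
recency_bias = 0.9
score = closeness * document_boost * recency_bias * aggregated_chunk_boost(0.8)
# A low-content short chunk (factor between 0.7 and 1.0) is demoted proportionally.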

View File

@ -22,6 +22,7 @@ from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import AGGREGATED_CHUNK_BOOST_FACTOR
from onyx.document_index.vespa_constants import BLURB
from onyx.document_index.vespa_constants import BOOST
from onyx.document_index.vespa_constants import CHUNK_ID
@ -201,6 +202,7 @@ def _index_vespa_chunk(
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
IMAGE_FILE_NAME: chunk.image_file_name,
BOOST: chunk.boost,
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
}
if multitenant:

View File

@ -72,6 +72,7 @@ METADATA = "metadata"
METADATA_LIST = "metadata_list"
METADATA_SUFFIX = "metadata_suffix"
BOOST = "boost"
AGGREGATED_CHUNK_BOOST_FACTOR = "aggregated_chunk_boost_factor"
DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
PRIMARY_OWNERS = "primary_owners"
SECONDARY_OWNERS = "secondary_owners"
@ -97,6 +98,7 @@ YQL_BASE = (
f"{SECTION_CONTINUATION}, "
f"{IMAGE_FILE_NAME}, "
f"{BOOST}, "
f"{AGGREGATED_CHUNK_BOOST_FACTOR}, "
f"{HIDDEN}, "
f"{DOC_UPDATED_AT}, "
f"{PRIMARY_OWNERS}, "

View File

@ -11,6 +11,7 @@ from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.configs.model_configs import USE_INFORMATION_CONTENT_CLASSIFICATION
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
)
@ -22,6 +23,7 @@ from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.db.chunk import update_chunk_boost_components__no_commit
from onyx.db.document import fetch_chunk_counts_for_documents
from onyx.db.document import get_documents_by_ids
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
@ -52,10 +54,19 @@ from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.llm.factory import get_default_llm_with_vision
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)
logger = setup_logger()
@ -136,6 +147,72 @@ def _upsert_documents_in_db(
)
def _get_aggregated_chunk_boost_factor(
chunks: list[IndexChunk],
information_content_classification_model: InformationContentClassificationModel,
) -> list[float]:
"""Calculates the aggregated boost factor for a chunk based on its content."""
short_chunk_content_dict = {
chunk_num: chunk.content
for chunk_num, chunk in enumerate(chunks)
if len(chunk.content.split())
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
}
short_chunk_contents = list(short_chunk_content_dict.values())
short_chunk_keys = list(short_chunk_content_dict.keys())
try:
predictions = information_content_classification_model.predict(
short_chunk_contents
)
# Create a mapping of chunk positions to their scores
score_map = {
short_chunk_keys[i]: prediction.content_boost_factor
for i, prediction in enumerate(predictions)
}
# Default to 1.0 for longer chunks, use predicted score for short chunks
chunk_content_scores = [score_map.get(i, 1.0) for i in range(len(chunks))]
return chunk_content_scores
except Exception as e:
logger.exception(
f"Error predicting content classification for chunks: {e}. Falling back to individual examples."
)
chunks_with_scores: list[IndexChunk] = []
chunk_content_scores = []
for chunk in chunks:
if (
len(chunk.content.split())
> INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
):
chunk_content_scores.append(1.0)
chunks_with_scores.append(chunk)
continue
try:
chunk_content_scores.append(
information_content_classification_model.predict([chunk.content])[
0
].content_boost_factor
)
chunks_with_scores.append(chunk)
except Exception as e:
logger.exception(
f"Error predicting content classification for chunk: {e}."
)
raise Exception(
f"Failed to predict content classification for chunk {chunk.chunk_id} "
f"from document {chunk.source_document.id}"
) from e
return chunk_content_scores
def get_doc_ids_to_update(
documents: list[Document], db_docs: list[DBDocument]
) -> list[Document]:
@ -165,6 +242,7 @@ def index_doc_batch_with_handler(
*,
chunker: Chunker,
embedder: IndexingEmbedder,
information_content_classification_model: InformationContentClassificationModel,
document_index: DocumentIndex,
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
@ -176,6 +254,7 @@ def index_doc_batch_with_handler(
index_pipeline_result = index_doc_batch(
chunker=chunker,
embedder=embedder,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
document_batch=document_batch,
index_attempt_metadata=index_attempt_metadata,
@ -450,6 +529,7 @@ def index_doc_batch(
document_batch: list[Document],
chunker: Chunker,
embedder: IndexingEmbedder,
information_content_classification_model: InformationContentClassificationModel,
document_index: DocumentIndex,
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
@ -526,7 +606,23 @@ def index_doc_batch(
else ([], [])
)
chunk_content_scores = (
_get_aggregated_chunk_boost_factor(
chunks_with_embeddings, information_content_classification_model
)
if USE_INFORMATION_CONTENT_CLASSIFICATION
else [1.0] * len(chunks_with_embeddings)
)
updatable_ids = [doc.id for doc in ctx.updatable_docs]
updatable_chunk_data = [
UpdatableChunkData(
chunk_id=chunk.chunk_id,
document_id=chunk.source_document.id,
boost_score=score,
)
for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
]
# Acquires a lock on the documents so that no other process can modify them
# NOTE: don't need to acquire till here, since this is when the actual race condition
@ -579,8 +675,9 @@ def index_doc_batch(
else DEFAULT_BOOST
),
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
)
for chunk in chunks_with_embeddings
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
logger.debug(
@ -665,6 +762,11 @@ def index_doc_batch(
db_session=db_session,
)
# save the chunk boost components to postgres
update_chunk_boost_components__no_commit(
chunk_data=updatable_chunk_data, db_session=db_session
)
db_session.commit()
result = IndexingPipelineResult(
@ -680,6 +782,7 @@ def index_doc_batch(
def build_indexing_pipeline(
*,
embedder: IndexingEmbedder,
information_content_classification_model: InformationContentClassificationModel,
document_index: DocumentIndex,
db_session: Session,
tenant_id: str,
@ -703,6 +806,7 @@ def build_indexing_pipeline(
index_doc_batch_with_handler,
chunker=chunker,
embedder=embedder,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=ignore_time_skip,
db_session=db_session,

View File

@ -83,13 +83,16 @@ class DocMetadataAwareIndexChunk(IndexChunk):
document_sets: all document sets the source document for this chunk is a part
of. This is used for filtering / personas.
boost: influences the ranking of this chunk at query time. Positive -> ranked higher,
negative -> ranked lower.
negative -> ranked lower. Not included in aggregated boost calculation
for legacy reasons.
aggregated_chunk_boost_factor: represents the aggregated chunk-level boost (currently: information content)
"""
tenant_id: str
access: "DocumentAccess"
document_sets: set[str]
boost: int
aggregated_chunk_boost_factor: float
@classmethod
def from_index_chunk(
@ -98,6 +101,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access: "DocumentAccess",
document_sets: set[str],
boost: int,
aggregated_chunk_boost_factor: float,
tenant_id: str,
) -> "DocMetadataAwareIndexChunk":
index_chunk_data = index_chunk.model_dump()
@ -106,6 +110,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access=access,
document_sets=document_sets,
boost=boost,
aggregated_chunk_boost_factor=aggregated_chunk_boost_factor,
tenant_id=tenant_id,
)
@ -179,3 +184,9 @@ class IndexingSetting(EmbeddingModelDetail):
class MultipassConfig(BaseModel):
multipass_indexing: bool
enable_large_chunks: bool
class UpdatableChunkData(BaseModel):
chunk_id: int
document_id: str
boost_score: float

View File

@ -56,7 +56,9 @@ BEDROCK_PROVIDER_NAME = "bedrock"
# models
BEDROCK_MODEL_NAMES = [
model
for model in litellm.bedrock_models
# bedrock_converse_models are just extensions of the bedrock_models, not sure why
# litellm has split them into two lists :(
for model in litellm.bedrock_models + litellm.bedrock_converse_models
if "/" not in model and "embed" not in model
][::-1]

View File

@ -29,6 +29,8 @@ from onyx.natural_language_processing.exceptions import (
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content
from onyx.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
@ -36,9 +38,11 @@ from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import ContentClassificationPrediction
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import InformationContentClassificationResponses
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
from shared_configs.model_server_models import RerankRequest
@ -377,6 +381,31 @@ class QueryAnalysisModel:
return response_model.is_keyword, response_model.keywords
class InformationContentClassificationModel:
def __init__(
self,
model_server_host: str = INDEXING_MODEL_SERVER_HOST,
model_server_port: int = INDEXING_MODEL_SERVER_PORT,
) -> None:
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.content_server_endpoint = (
model_server_url + "/custom/content-classification"
)
def predict(
self,
queries: list[str],
) -> list[ContentClassificationPrediction]:
response = requests.post(self.content_server_endpoint, json=queries)
response.raise_for_status()
model_responses = InformationContentClassificationResponses(
information_content_classifications=response.json()
)
return model_responses.information_content_classifications
class ConnectorClassificationModel:
def __init__(
self,

View File

@ -98,6 +98,7 @@ def _create_indexable_chunks(
boost=DEFAULT_BOOST,
large_chunk_id=None,
image_file_name=None,
aggregated_chunk_boost_factor=1.0,
)
chunks.append(chunk)

View File

@ -19,6 +19,9 @@ from onyx.db.search_settings import get_secondary_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.server.onyx_api.models import DocMinimalInfo
from onyx.server.onyx_api.models import IngestionDocument
from onyx.server.onyx_api.models import IngestionResult
@ -102,8 +105,11 @@ def upsert_ingestion_doc(
search_settings=search_settings
)
information_content_classification_model = InformationContentClassificationModel()
indexing_pipeline = build_indexing_pipeline(
embedder=index_embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=curr_doc_index,
ignore_time_skip=True,
db_session=db_session,
@ -138,6 +144,7 @@ def upsert_ingestion_doc(
sec_ind_pipeline = build_indexing_pipeline(
embedder=new_index_embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=sec_doc_index,
ignore_time_skip=True,
db_session=db_session,

View File

@ -25,7 +25,7 @@ google-auth-oauthlib==1.0.0
httpcore==1.0.5
httpx[http2]==0.27.0
httpx-oauth==0.15.1
huggingface-hub==0.20.1
huggingface-hub==0.29.0
inflection==0.5.1
jira==3.5.1
jsonref==1.1.0
@ -38,7 +38,7 @@ langchainhub==0.1.21
langgraph==0.2.72
langgraph-checkpoint==2.0.13
langgraph-sdk==0.1.44
litellm==1.61.16
litellm==1.63.8
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45
@ -47,7 +47,7 @@ msal==1.28.0
nltk==3.8.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==1.61.0
openai==1.66.3
openpyxl==3.1.2
playwright==1.41.2
psutil==5.9.5
@ -71,6 +71,7 @@ requests==2.32.2
requests-oauthlib==1.3.1
retry==0.9.2 # This pulls in py which is in CVE-2022-42969, must remove py from image
rfc3986==1.5.0
setfit==1.1.1
simple-salesforce==1.12.6
slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.15
@ -78,7 +79,7 @@ starlette==0.36.3
supervisor==4.2.5
tiktoken==0.7.0
timeago==1.0.16
transformers==4.39.2
transformers==4.49.0
unstructured==0.15.1
unstructured-client==0.25.4
uvicorn==0.21.1

View File

@ -14,7 +14,7 @@ pytest-asyncio==0.22.0
pytest==7.4.4
reorder-python-imports==3.9.0
ruff==0.0.286
sentence-transformers==2.6.1
sentence-transformers==3.4.1
trafilatura==1.12.2
types-beautifulsoup4==4.12.0.3
types-html5lib==1.1.11.13

View File

@ -7,9 +7,10 @@ openai==1.61.0
pydantic==2.8.2
retry==0.9.2
safetensors==0.4.2
sentence-transformers==2.6.1
sentence-transformers==3.4.1
setfit==1.1.1
torch==2.2.0
transformers==4.39.2
transformers==4.49.0
uvicorn==0.21.1
voyageai==0.2.3
litellm==1.61.16

View File

@ -161,17 +161,21 @@ overview_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/overview",
title=overview_title,
content=overview,
title_embedding=model.encode(f"search_document: {overview_title}"),
content_embedding=model.encode(f"search_document: {overview_title}\n{overview}"),
title_embedding=list(model.encode(f"search_document: {overview_title}")),
content_embedding=list(
model.encode(f"search_document: {overview_title}\n{overview}")
),
)
enterprise_search_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/enterprise_search",
title=enterprise_search_title,
content=enterprise_search_1,
title_embedding=model.encode(f"search_document: {enterprise_search_title}"),
content_embedding=model.encode(
f"search_document: {enterprise_search_title}\n{enterprise_search_1}"
title_embedding=list(model.encode(f"search_document: {enterprise_search_title}")),
content_embedding=list(
model.encode(
f"search_document: {enterprise_search_title}\n{enterprise_search_1}"
)
),
)
@ -179,9 +183,11 @@ enterprise_search_doc_2 = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/enterprise_search",
title=enterprise_search_title,
content=enterprise_search_2,
title_embedding=model.encode(f"search_document: {enterprise_search_title}"),
content_embedding=model.encode(
f"search_document: {enterprise_search_title}\n{enterprise_search_2}"
title_embedding=list(model.encode(f"search_document: {enterprise_search_title}")),
content_embedding=list(
model.encode(
f"search_document: {enterprise_search_title}\n{enterprise_search_2}"
)
),
chunk_ind=1,
)
@ -190,9 +196,9 @@ ai_platform_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/ai_platform",
title=ai_platform_title,
content=ai_platform,
title_embedding=model.encode(f"search_document: {ai_platform_title}"),
content_embedding=model.encode(
f"search_document: {ai_platform_title}\n{ai_platform}"
title_embedding=list(model.encode(f"search_document: {ai_platform_title}")),
content_embedding=list(
model.encode(f"search_document: {ai_platform_title}\n{ai_platform}")
),
)
@ -200,9 +206,9 @@ customer_support_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/support",
title=customer_support_title,
content=customer_support,
title_embedding=model.encode(f"search_document: {customer_support_title}"),
content_embedding=model.encode(
f"search_document: {customer_support_title}\n{customer_support}"
title_embedding=list(model.encode(f"search_document: {customer_support_title}")),
content_embedding=list(
model.encode(f"search_document: {customer_support_title}\n{customer_support}")
),
)
@ -210,17 +216,17 @@ sales_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/sales",
title=sales_title,
content=sales,
title_embedding=model.encode(f"search_document: {sales_title}"),
content_embedding=model.encode(f"search_document: {sales_title}\n{sales}"),
title_embedding=list(model.encode(f"search_document: {sales_title}")),
content_embedding=list(model.encode(f"search_document: {sales_title}\n{sales}")),
)
operations_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/operations",
title=operations_title,
content=operations,
title_embedding=model.encode(f"search_document: {operations_title}"),
content_embedding=model.encode(
f"search_document: {operations_title}\n{operations}"
title_embedding=list(model.encode(f"search_document: {operations_title}")),
content_embedding=list(
model.encode(f"search_document: {operations_title}\n{operations}")
),
)

View File

@ -99,6 +99,7 @@ def generate_dummy_chunk(
),
document_sets={document_set for document_set in document_set_names},
boost=random.randint(-1, 1),
aggregated_chunk_boost_factor=random.random(),
tenant_id=POSTGRES_DEFAULT_SCHEMA,
)

View File

@ -23,9 +23,11 @@ INDEXING_MODEL_SERVER_PORT = int(
# Onyx custom Deep Learning Models
CONNECTOR_CLASSIFIER_MODEL_REPO = "Danswer/filter-extraction-model"
CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
INTENT_MODEL_TAG = "v1.0.3"
INTENT_MODEL_VERSION = "onyx-dot-app/hybrid-intent-token-classifier"
# INTENT_MODEL_TAG = "v1.0.3"
INTENT_MODEL_TAG: str | None = None
INFORMATION_CONTENT_MODEL_VERSION = "onyx-dot-app/information-content-model"
INFORMATION_CONTENT_MODEL_TAG: str | None = None
# Bi-Encoder, other details
DOC_EMBEDDING_CONTEXT_SIZE = 512
@ -277,3 +279,20 @@ SUPPORTED_EMBEDDING_MODELS = [
index_name="danswer_chunk_intfloat_multilingual_e5_small",
),
]
# Maximum (least severe) downgrade factor for chunks above the cutoff
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX = float(
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX") or 1.0
)
# Minimum (most severe) downgrade factor for short chunks below the cutoff if no content
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN = float(
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN") or 0.7
)
# Temperature for the information content classification model
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE = float(
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE") or 4.0
)
# Cutoff below which we start using the information content classification model
# (the cutoff length itself is still considered 'short')
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH = int(
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH") or 10
)

View File

@ -4,6 +4,7 @@ from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
Embedding = list[float]
@ -73,7 +74,20 @@ class IntentResponse(BaseModel):
keywords: list[str]
class InformationContentClassificationRequests(BaseModel):
queries: list[str]
class SupportedEmbeddingModel(BaseModel):
name: str
dim: int
index_name: str
class ContentClassificationPrediction(BaseModel):
predicted_label: int
content_boost_factor: float
class InformationContentClassificationResponses(BaseModel):
information_content_classifications: list[ContentClassificationPrediction]
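End to end, the /content-classification endpoint added earlier in this commit accepts a bare JSON list of strings and returns one prediction per input. A client-side sketch, assuming the indexing model server listens on localhost:9000 (host and port are assumptions):

import requests

resp = requests.post(
    "http://localhost:9000/custom/content-classification",
    json=["short chunk with content", ""],
)
resp.raise_for_status()
for item in resp.json():
    # each item mirrors ContentClassificationPrediction
    print(item["predicted_label"], item["content_boost_factor"])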

View File

@ -0,0 +1,145 @@
from typing import Any
from unittest.mock import Mock
from unittest.mock import patch
import numpy as np
import numpy.typing as npt
import pytest
from model_server.custom_models import run_content_classification_inference
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
from shared_configs.model_server_models import ContentClassificationPrediction
@pytest.fixture
def mock_content_model() -> Mock:
model = Mock()
# Create actual numpy arrays for the mock returns
predict_output = np.array(
[1, 0] * 50, dtype=np.int64
) # Pre-allocate enough elements
proba_output = np.array(
[[0.3, 0.7], [0.7, 0.3]] * 50, dtype=np.float64
) # Pre-allocate enough elements
# Create a mock tensor that has a numpy method and supports indexing
class MockTensor:
def __init__(self, value: npt.NDArray[Any]) -> None:
self.value = value
def numpy(self) -> npt.NDArray[Any]:
return self.value
def __getitem__(self, idx: Any) -> Any:
result = self.value[idx]
# Wrap scalar values back in MockTensor
if isinstance(result, (np.float64, np.int64)):
return MockTensor(np.array([result]))
return MockTensor(result)
# Mock the direct call to return a MockTensor for each input
def model_call(inputs: list[str]) -> list[MockTensor]:
batch_size = len(inputs)
return [MockTensor(predict_output[i : i + 1]) for i in range(batch_size)]
model.side_effect = model_call
# Mock predict_proba to return MockTensor-wrapped numpy array
def predict_proba_call(x: list[str]) -> MockTensor:
batch_size = len(x)
return MockTensor(proba_output[:batch_size])
model.predict_proba.side_effect = predict_proba_call
return model
@patch("model_server.custom_models.get_local_information_content_model")
def test_run_content_classification_inference(
mock_get_model: Mock,
mock_content_model: Mock,
) -> None:
"""
Test the content classification inference function.
Verifies that the function correctly processes text inputs and returns appropriate predictions.
"""
# Setup
mock_get_model.return_value = mock_content_model
test_inputs = [
"Imagine a short text with content",
"Imagine a short text without content",
"x "
* (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH + 1
),  # Long input that exceeds the maximum length at which the model is applied
"", # Empty input
]
# Execute
results = run_content_classification_inference(test_inputs)
# Assert
assert len(results) == len(test_inputs)
assert all(isinstance(r, ContentClassificationPrediction) for r in results)
# Check each prediction has expected attributes and ranges
for result_num, result in enumerate(results):
assert hasattr(result, "predicted_label")
assert hasattr(result, "content_boost_factor")
assert isinstance(result.predicted_label, int)
assert isinstance(result.content_boost_factor, float)
assert (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
<= result.content_boost_factor
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
)
if result_num == 2:
assert (
result.content_boost_factor
== INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
)
assert result.predicted_label == 1
elif result_num == 3:
assert (
result.content_boost_factor
== INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
)
assert result.predicted_label == 0
# Verify model handling of long inputs
mock_content_model.predict_proba.reset_mock()
long_input = ["x " * 1000] # Definitely exceeds MAX_LENGTH
results = run_content_classification_inference(long_input)
assert len(results) == 1
assert (
mock_content_model.predict_proba.call_count == 0
) # Should skip model call for too-long input
@patch("model_server.custom_models.get_local_information_content_model")
def test_batch_processing(
mock_get_model: Mock,
mock_content_model: Mock,
) -> None:
"""
Test that the function correctly handles batch processing of inputs.
"""
# Setup
mock_get_model.return_value = mock_content_model
# Create test input larger than batch size
test_inputs = [f"Test input {i}" for i in range(40)] # > BATCH_SIZE (32)
# Execute
results = run_content_classification_inference(test_inputs)
# Assert
assert len(results) == 40
# Verify batching occurred (should have called predict_proba twice)
assert mock_content_model.predict_proba.call_count == 2

View File

@ -1,12 +1,24 @@
from typing import cast
from typing import List
from unittest.mock import Mock
import pytest
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentSource
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.indexing.indexing_pipeline import _get_aggregated_chunk_boost_factor
from onyx.indexing.indexing_pipeline import filter_documents
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import IndexChunk
from onyx.natural_language_processing.search_nlp_models import (
ContentClassificationPrediction,
)
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)
def create_test_document(
@ -123,3 +135,117 @@ def test_filter_documents_multiple_documents() -> None:
def test_filter_documents_empty_batch() -> None:
result = filter_documents([])
assert len(result) == 0
# Tests for _get_aggregated_chunk_boost_factor
def create_test_chunk(
content: str, chunk_id: int = 0, doc_id: str = "test_doc"
) -> IndexChunk:
doc = Document(
id=doc_id,
semantic_identifier="test doc",
sections=[],
source=DocumentSource.FILE,
metadata={},
)
return IndexChunk(
chunk_id=chunk_id,
content=content,
source_document=doc,
blurb=content[:50], # First 50 chars as blurb
source_links={0: "test_link"},
section_continuation=False,
title_prefix="",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_id=None,
large_chunk_reference_ids=[],
embeddings=ChunkEmbedding(full_embedding=[], mini_chunk_embeddings=[]),
title_embedding=None,
image_file_name=None,
)
def test_get_aggregated_boost_factor() -> None:
# Create test chunks - mix of short and long content
chunks = [
create_test_chunk("Short content", 0),
create_test_chunk(
"Long " * (INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH + 1), 1
),
create_test_chunk("Another short chunk", 2),
]
# Mock the classification model
mock_model = Mock()
mock_model.predict.return_value = [
ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.8),
ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.9),
]
# Execute the function
boost_scores = _get_aggregated_chunk_boost_factor(
chunks=chunks, information_content_classification_model=mock_model
)
# Assertions
assert len(boost_scores) == 3
# Check that long content got default boost
assert boost_scores[1] == 1.0
# Check that short content got predicted boosts
assert boost_scores[0] == 0.8
assert boost_scores[2] == 0.9
# Verify model was only called once with the short chunks
mock_model.predict.assert_called_once()
assert len(mock_model.predict.call_args[0][0]) == 2
def test_get_aggregated_boost_factor_failure() -> None:
chunks = [
create_test_chunk("Short content 1", 0),
create_test_chunk("Short content 2", 1),
]
# Mock model to fail on batch prediction but succeed on individual predictions
mock_model = Mock()
mock_model.predict.side_effect = [
Exception("Batch prediction failed"), # First call fails
[
ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.7)
], # Individual calls succeed
[ContentClassificationPrediction(predicted_label=1, content_boost_factor=0.8)],
]
# Execute
boost_scores = _get_aggregated_chunk_boost_factor(
chunks=chunks, information_content_classification_model=mock_model
)
# Assertions
assert len(boost_scores) == 2
assert boost_scores == [0.7, 0.8]
def test_get_aggregated_boost_factor_individual_failure() -> None:
chunks = [
create_test_chunk("Short content", 0),
create_test_chunk("Short content", 1),
]
# Mock model to fail on both batch and individual prediction
mock_model = Mock()
mock_model.predict.side_effect = Exception("Prediction failed")
# Execute and verify it raises an exception
with pytest.raises(Exception) as exc_info:
_get_aggregated_chunk_boost_factor(
chunks=chunks, information_content_classification_model=mock_model
)
assert "Failed to predict content classification for chunk" in str(exc_info.value)

View File

@ -284,7 +284,9 @@ export function LLMProviderUpdateForm({
subtext="The model to use by default for this provider unless otherwise specified."
label="Default Model"
options={llmProviderDescriptor.llm_names.map((name) => ({
name: getDisplayNameForModel(name),
// don't clean up names here to give admins descriptive names / handle duplicates
// like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0
name: name,
value: name,
}))}
maxHeight="max-h-56"
@ -314,7 +316,9 @@ export function LLMProviderUpdateForm({
the Default Model configured above.`}
label="[Optional] Fast Model"
options={llmProviderDescriptor.llm_names.map((name) => ({
name: getDisplayNameForModel(name),
// don't clean up names here to give admins descriptive names / handle duplicates
// like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0
name: name,
value: name,
}))}
includeDefault
@ -355,7 +359,9 @@ export function LLMProviderUpdateForm({
options={llmProviderDescriptor.llm_names.map(
(name) => ({
value: name,
label: getDisplayNameForModel(name),
// don't clean up names here to give admins descriptive names / handle duplicates
// like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0
label: name,
})
)}
onChange={(selected) =>

View File

@ -16,7 +16,15 @@ export function createValidationSchema(json_values: Record<string, any>) {
const displayName = getDisplayNameForCredentialKey(key);
if (json_values[key] === null) {
if (typeof json_values[key] === "boolean") {
// Ensure false is considered valid
schemaFields[key] = Yup.boolean()
.nullable()
.default(false)
.transform((value, originalValue) =>
originalValue === undefined ? false : value
);
} else if (json_values[key] === null) {
// Field is optional:
schemaFields[key] = Yup.string()
.trim()