faster indexing status at scale plus minor cleanups (#4081)

* faster indexing status at scale plus minor cleanups

* mypy

* address chris comments

* remove extra prints
This commit is contained in:
evan-danswer 2025-02-25 13:22:26 -08:00 committed by GitHub
parent 07b0b57b31
commit 6ce810e957
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 249 additions and 112 deletions

View File

@ -1,4 +1,5 @@
from datetime import datetime
from typing import TypeVarTuple
from fastapi import HTTPException
from sqlalchemy import delete
@ -8,15 +9,18 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
@ -31,10 +35,12 @@ from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
R = TypeVarTuple("R")
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
stmt: Select[tuple[*R]], user: User | None, get_editable: bool = True
) -> Select[tuple[*R]]:
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
@ -98,17 +104,52 @@ def get_connector_credential_pairs_for_user(
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
eager_load_credential: bool = False,
eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
if eager_load_user:
assert (
eager_load_credential
), "eager_load_credential must be True if eager_load_user is True"
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
if eager_load_credential:
load_opts = selectinload(ConnectorCredentialPair.credential)
if eager_load_user:
load_opts = load_opts.joinedload(Credential.user)
stmt = stmt.options(load_opts)
stmt = _add_user_filters(stmt, user, get_editable)
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
return list(db_session.scalars(stmt).all())
return list(db_session.scalars(stmt).unique().all())
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_connector_credential_pairs_for_user_parallel(
user: User | None,
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
eager_load_credential: bool = False,
eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
with get_session_context_manager() as db_session:
return get_connector_credential_pairs_for_user(
db_session,
user,
get_editable,
ids,
eager_load_connector,
eager_load_credential,
eager_load_user,
)
def get_connector_credential_pairs(
@ -151,6 +192,16 @@ def get_cc_pair_groups_for_ids(
return list(db_session.scalars(stmt).all())
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_cc_pair_groups_for_ids_parallel(
cc_pair_ids: list[int],
) -> list[UserGroup__ConnectorCredentialPair]:
with get_session_context_manager() as db_session:
return get_cc_pair_groups_for_ids(db_session, cc_pair_ids)
def get_connector_credential_pair_for_user(
db_session: Session,
connector_id: int,

View File

@ -24,6 +24,7 @@ from sqlalchemy.sql.expression import null
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.feedback import delete_document_feedback_for_documents__no_commit
@ -229,12 +230,12 @@ def get_document_connector_counts(
def get_document_counts_for_cc_pairs(
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
db_session: Session, cc_pairs: list[ConnectorCredentialPairIdentifier]
) -> Sequence[tuple[int, int, int]]:
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
# Prepare a list of (connector_id, credential_id) tuples
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pair_identifiers]
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pairs]
stmt = (
select(
@ -260,6 +261,16 @@ def get_document_counts_for_cc_pairs(
return db_session.execute(stmt).all() # type: ignore
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_document_counts_for_cc_pairs_parallel(
cc_pairs: list[ConnectorCredentialPairIdentifier],
) -> Sequence[tuple[int, int, int]]:
with get_session_context_manager() as db_session:
return get_document_counts_for_cc_pairs(db_session, cc_pairs)
def get_access_info_for_document(
db_session: Session,
document_id: str,

View File

@ -218,6 +218,7 @@ class SqlEngine:
final_engine_kwargs.update(engine_kwargs)
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
# echo=True here for inspecting all emitted db queries
engine = create_engine(connection_string, **final_engine_kwargs)
if USE_IAM_AUTH:

View File

@ -2,6 +2,7 @@ from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import TypeVarTuple
from sqlalchemy import and_
from sqlalchemy import delete
@ -9,9 +10,13 @@ from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from onyx.connectors.models import ConnectorFailure
from onyx.db.engine import get_session_context_manager
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
@ -368,19 +373,33 @@ def get_latest_index_attempts_by_status(
return db_session.execute(stmt).scalars().all()
T = TypeVarTuple("T")
def _add_only_finished_clause(stmt: Select[tuple[*T]]) -> Select[tuple[*T]]:
return stmt.where(
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
def get_latest_index_attempts(
secondary_index: bool,
db_session: Session,
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
ids_stmt = select(
IndexAttempt.connector_credential_pair_id,
func.max(IndexAttempt.id).label("max_id"),
).join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
if secondary_index:
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.FUTURE)
else:
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.PRESENT)
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
ids_stmt = ids_stmt.where(SearchSettings.status == status)
if only_finished:
ids_stmt = _add_only_finished_clause(ids_stmt)
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id)
ids_subquery = ids_stmt.subquery()
@ -395,7 +414,53 @@ def get_latest_index_attempts(
.where(IndexAttempt.id == ids_subquery.c.max_id)
)
return db_session.execute(stmt).scalars().all()
if only_finished:
stmt = _add_only_finished_clause(stmt)
if eager_load_cc_pair:
stmt = stmt.options(
joinedload(IndexAttempt.connector_credential_pair),
joinedload(IndexAttempt.error_rows),
)
return db_session.execute(stmt).scalars().unique().all()
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_latest_index_attempts_parallel(
secondary_index: bool,
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
with get_session_context_manager() as db_session:
return get_latest_index_attempts(
secondary_index,
db_session,
eager_load_cc_pair,
only_finished,
)
def get_latest_index_attempt_for_cc_pair_id(
db_session: Session,
connector_credential_pair_id: int,
secondary_index: bool,
only_finished: bool = True,
) -> IndexAttempt | None:
stmt = select(IndexAttempt)
stmt = stmt.where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
)
if only_finished:
stmt = _add_only_finished_clause(stmt)
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
stmt = stmt.join(SearchSettings).where(SearchSettings.status == status)
stmt = stmt.order_by(desc(IndexAttempt.time_created))
stmt = stmt.limit(1)
return db_session.execute(stmt).scalar_one_or_none()
def count_index_attempts_for_connector(
@ -453,37 +518,12 @@ def get_paginated_index_attempts_for_cc_pair_id(
# Apply pagination
stmt = stmt.offset(page * page_size).limit(page_size)
return list(db_session.execute(stmt).scalars().all())
def get_latest_index_attempt_for_cc_pair_id(
db_session: Session,
connector_credential_pair_id: int,
secondary_index: bool,
only_finished: bool = True,
) -> IndexAttempt | None:
stmt = select(IndexAttempt)
stmt = stmt.where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
stmt = stmt.options(
contains_eager(IndexAttempt.connector_credential_pair),
joinedload(IndexAttempt.error_rows),
)
if only_finished:
stmt = stmt.where(
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
if secondary_index:
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.FUTURE
)
else:
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.PRESENT
)
stmt = stmt.order_by(desc(IndexAttempt.time_created))
stmt = stmt.limit(1)
return db_session.execute(stmt).scalar_one_or_none()
return list(db_session.execute(stmt).scalars().unique().all())
def get_index_attempts_for_cc_pair(

View File

@ -93,10 +93,7 @@ class RedisConnectorIndex:
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
return bool(self.redis.exists(self.fence_key))
@property
def payload(self) -> RedisConnectorIndexPayload | None:
@ -106,9 +103,7 @@ class RedisConnectorIndex:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
return payload
return RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
def set_fence(
self,
@ -123,10 +118,7 @@ class RedisConnectorIndex:
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
def terminating(self, celery_task_id: str) -> bool:
if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"):
return True
return False
return bool(self.redis.exists(f"{self.terminate_key}_{celery_task_id}"))
def set_terminate(self, celery_task_id: str) -> None:
"""This sets a signal. It does not block!"""
@ -146,10 +138,7 @@ class RedisConnectorIndex:
def watchdog_signaled(self) -> bool:
"""Check the state of the watchdog."""
if self.redis.exists(self.watchdog_key):
return True
return False
return bool(self.redis.exists(self.watchdog_key))
def set_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
@ -160,10 +149,7 @@ class RedisConnectorIndex:
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
return bool(self.redis.exists(self.active_key))
def set_connector_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
@ -180,10 +166,7 @@ class RedisConnectorIndex:
return False
def generator_locked(self) -> bool:
if self.redis.exists(self.generator_lock_key):
return True
return False
return bool(self.redis.exists(self.generator_lock_key))
def set_generator_complete(self, payload: int | None) -> None:
if not payload:

View File

@ -123,15 +123,15 @@ def get_cc_pair_full_info(
)
is_editable_for_current_user = editable_cc_pair is not None
cc_pair_identifier = ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
document_count_info_list = list(
get_document_counts_for_cc_pairs(
db_session=db_session,
cc_pair_identifiers=[cc_pair_identifier],
cc_pairs=[
ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
],
)
)
documents_indexed = (

View File

@ -72,25 +72,31 @@ from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector import update_connector
from onyx.db.connector_credential_pair import add_credential_to_connector
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids_parallel
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user
from onyx.db.connector_credential_pair import (
get_connector_credential_pairs_for_user_parallel,
)
from onyx.db.credentials import cleanup_gmail_credentials
from onyx.db.credentials import cleanup_google_drive_credentials
from onyx.db.credentials import create_credential
from onyx.db.credentials import delete_service_account_credentials
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.deletion_attempt import check_deletion_attempt_is_allowed
from onyx.db.document import get_document_counts_for_cc_pairs
from onyx.db.document import get_document_counts_for_cc_pairs_parallel
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import IndexingMode
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from onyx.db.index_attempt import get_latest_index_attempts
from onyx.db.index_attempt import get_latest_index_attempts_by_status
from onyx.db.index_attempt import get_latest_index_attempts_parallel
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.file_processing.extract_file_text import convert_docx_to_txt
@ -119,8 +125,8 @@ from onyx.server.documents.models import RunConnectorRequest
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@ -578,6 +584,8 @@ def get_connector_status(
cc_pairs = get_connector_credential_pairs_for_user(
db_session=db_session,
user=user,
eager_load_connector=True,
eager_load_credential=True,
)
group_cc_pair_relationships = get_cc_pair_groups_for_ids(
@ -632,23 +640,35 @@ def get_connector_indexing_status(
# Additional checks are done to make sure the connector and credential still exist.
# TODO: make this one query ... possibly eager load or wrap in a read transaction
# to avoid the complexity of trying to error check throughout the function
cc_pairs = get_connector_credential_pairs_for_user(
db_session=db_session,
user=user,
get_editable=get_editable,
)
cc_pair_identifiers = [
ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
)
for cc_pair in cc_pairs
]
latest_index_attempts = get_latest_index_attempts(
secondary_index=secondary_index,
db_session=db_session,
# see https://stackoverflow.com/questions/75758327/
# sqlalchemy-method-connection-for-bind-is-already-in-progress
# for why we can't pass in the current db_session to these functions
(
cc_pairs,
latest_index_attempts,
latest_finished_index_attempts,
) = run_functions_tuples_in_parallel(
[
(
# Gets the connector/credential pairs for the user
get_connector_credential_pairs_for_user_parallel,
(user, get_editable, None, True, True, True),
),
(
# Gets the most recent index attempt for each connector/credential pair
get_latest_index_attempts_parallel,
(secondary_index, True, False),
),
(
# Gets the most recent FINISHED index attempt for each connector/credential pair
get_latest_index_attempts_parallel,
(secondary_index, True, True),
),
]
)
cc_pairs = cast(list[ConnectorCredentialPair], cc_pairs)
latest_index_attempts = cast(list[IndexAttempt], latest_index_attempts)
cc_pair_to_latest_index_attempt = {
(
@ -658,31 +678,60 @@ def get_connector_indexing_status(
for index_attempt in latest_index_attempts
}
document_count_info = get_document_counts_for_cc_pairs(
db_session=db_session,
cc_pair_identifiers=cc_pair_identifiers,
cc_pair_to_latest_finished_index_attempt = {
(
index_attempt.connector_credential_pair.connector_id,
index_attempt.connector_credential_pair.credential_id,
): index_attempt
for index_attempt in latest_finished_index_attempts
}
document_count_info, group_cc_pair_relationships = run_functions_tuples_in_parallel(
[
(
get_document_counts_for_cc_pairs_parallel,
(
[
ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
for cc_pair in cc_pairs
],
),
),
(
get_cc_pair_groups_for_ids_parallel,
([cc_pair.id for cc_pair in cc_pairs],),
),
]
)
document_count_info = cast(list[tuple[int, int, int]], document_count_info)
group_cc_pair_relationships = cast(
list[UserGroup__ConnectorCredentialPair], group_cc_pair_relationships
)
cc_pair_to_document_cnt = {
(connector_id, credential_id): cnt
for connector_id, credential_id, cnt in document_count_info
}
group_cc_pair_relationships = get_cc_pair_groups_for_ids(
db_session=db_session,
cc_pair_ids=[cc_pair.id for cc_pair in cc_pairs],
)
group_cc_pair_relationships_dict: dict[int, list[int]] = {}
for relationship in group_cc_pair_relationships:
group_cc_pair_relationships_dict.setdefault(relationship.cc_pair_id, []).append(
relationship.user_group_id
)
search_settings: SearchSettings | None = None
if not secondary_index:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_secondary_search_settings(db_session)
connector_to_cc_pair_ids: dict[int, list[int]] = {}
for cc_pair in cc_pairs:
connector_to_cc_pair_ids.setdefault(cc_pair.connector_id, []).append(cc_pair.id)
get_search_settings = (
get_secondary_search_settings
if secondary_index
else get_current_search_settings
)
search_settings = get_search_settings(db_session)
for cc_pair in cc_pairs:
# TODO remove this to enable ingestion API
if cc_pair.name == "DefaultCCPair":
@ -705,11 +754,8 @@ def get_connector_indexing_status(
(connector.id, credential.id)
)
latest_finished_attempt = get_latest_index_attempt_for_cc_pair_id(
db_session=db_session,
connector_credential_pair_id=cc_pair.id,
secondary_index=secondary_index,
only_finished=True,
latest_finished_attempt = cc_pair_to_latest_finished_index_attempt.get(
(connector.id, credential.id)
)
indexing_statuses.append(
@ -718,7 +764,9 @@ def get_connector_indexing_status(
name=cc_pair.name,
in_progress=in_progress,
cc_pair_status=cc_pair.status,
connector=ConnectorSnapshot.from_connector_db_model(connector),
connector=ConnectorSnapshot.from_connector_db_model(
connector, connector_to_cc_pair_ids.get(connector.id, [])
),
credential=CredentialSnapshot.from_credential_db_model(credential),
access_type=cc_pair.access_type,
owner=credential.user.email if credential.user else "",

View File

@ -83,7 +83,9 @@ class ConnectorSnapshot(ConnectorBase):
source: DocumentSource
@classmethod
def from_connector_db_model(cls, connector: Connector) -> "ConnectorSnapshot":
def from_connector_db_model(
cls, connector: Connector, credential_ids: list[int] | None = None
) -> "ConnectorSnapshot":
return ConnectorSnapshot(
id=connector.id,
name=connector.name,
@ -92,9 +94,10 @@ class ConnectorSnapshot(ConnectorBase):
connector_specific_config=connector.connector_specific_config,
refresh_freq=connector.refresh_freq,
prune_freq=connector.prune_freq,
credential_ids=[
association.credential.id for association in connector.credentials
],
credential_ids=(
credential_ids
or [association.credential.id for association in connector.credentials]
),
indexing_start=connector.indexing_start,
time_created=connector.time_created,
time_updated=connector.time_updated,