danswer/backend/onyx/db/connector_credential_pair.py
evan-danswer 6ce810e957
faster indexing status at scale plus minor cleanups (#4081)
* faster indexing status at scale plus minor cleanups

* mypy

* address chris comments

* remove extra prints
2025-02-25 21:22:26 +00:00

637 lines
20 KiB
Python

from datetime import datetime
from typing import TypeVarTuple
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import exists
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.models import UserRole
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
# Module-level logger obtained from the project's logging helper.
logger = setup_logger()

# Variadic type variable: lets `_add_user_filters` take and return a
# `Select[tuple[*R]]` without narrowing or widening the statement's row type.
R = TypeVarTuple("R")
def _add_user_filters(
stmt: Select[tuple[*R]], user: User | None, get_editable: bool = True
) -> Select[tuple[*R]]:
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
stmt = stmt.distinct()
UG__CCpair = aliased(UserGroup__ConnectorCredentialPair)
User__UG = aliased(User__UserGroup)
"""
Here we select cc_pairs by relation:
User -> User__UserGroup -> UserGroup__ConnectorCredentialPair ->
ConnectorCredentialPair
"""
stmt = stmt.outerjoin(UG__CCpair).outerjoin(
User__UG,
User__UG.user_group_id == UG__CCpair.user_group_id,
)
"""
Filter cc_pairs by:
- if the user is in the user_group that owns the cc_pair
- if the user is not a global_curator, they must also have a curator relationship
to the user_group
- if editing is being done, we also filter out cc_pairs that are owned by groups
that the user isn't a curator for
- if we are not editing, we show all cc_pairs in the groups the user is a curator
for (as well as public cc_pairs)
"""
# If user is None, this is an anonymous user and we should only show public cc_pairs
if user is None:
where_clause = ConnectorCredentialPair.access_type == AccessType.PUBLIC
return stmt.where(where_clause)
where_clause = User__UG.user_id == user.id
if user.role == UserRole.CURATOR and get_editable:
where_clause &= User__UG.is_curator == True # noqa: E712
if get_editable:
user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id)
if user.role == UserRole.CURATOR:
user_groups = user_groups.where(
User__UserGroup.is_curator == True # noqa: E712
)
where_clause &= (
~exists()
.where(UG__CCpair.cc_pair_id == ConnectorCredentialPair.id)
.where(~UG__CCpair.user_group_id.in_(user_groups))
.correlate(ConnectorCredentialPair)
)
where_clause |= ConnectorCredentialPair.creator_id == user.id
else:
where_clause |= ConnectorCredentialPair.access_type == AccessType.PUBLIC
where_clause |= ConnectorCredentialPair.access_type == AccessType.SYNC
return stmt.where(where_clause)
def get_connector_credential_pairs_for_user(
    db_session: Session,
    user: User | None,
    get_editable: bool = True,
    ids: list[int] | None = None,
    eager_load_connector: bool = False,
    eager_load_credential: bool = False,
    eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
    """Fetch the cc_pairs `user` can access, optionally eager loading relations.

    Access filtering is delegated to `_add_user_filters`. A falsy `ids`
    (None or an empty list) applies no id filter. `eager_load_user` loads the
    user through the credential relationship, so it requires
    `eager_load_credential`.
    """
    assert (
        eager_load_credential or not eager_load_user
    ), "eager_load_credential must be True if eager_load_user is True"

    query = select(ConnectorCredentialPair).distinct()

    if eager_load_connector:
        query = query.options(selectinload(ConnectorCredentialPair.connector))

    if eager_load_credential:
        credential_loader = selectinload(ConnectorCredentialPair.credential)
        if eager_load_user:
            # Chain the user load off the credential load.
            credential_loader = credential_loader.joinedload(Credential.user)
        query = query.options(credential_loader)

    query = _add_user_filters(query, user, get_editable)

    if ids:
        query = query.where(ConnectorCredentialPair.id.in_(ids))

    matching_pairs = db_session.scalars(query).unique().all()
    return list(matching_pairs)
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_connector_credential_pairs_for_user_parallel(
    user: User | None,
    get_editable: bool = True,
    ids: list[int] | None = None,
    eager_load_connector: bool = False,
    eager_load_credential: bool = False,
    eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
    """Run `get_connector_credential_pairs_for_user` inside its own session.

    The session comes from `get_session_context_manager` and is released
    before the results are handed back to the caller.
    """
    with get_session_context_manager() as session:
        cc_pairs = get_connector_credential_pairs_for_user(
            db_session=session,
            user=user,
            get_editable=get_editable,
            ids=ids,
            eager_load_connector=eager_load_connector,
            eager_load_credential=eager_load_credential,
            eager_load_user=eager_load_user,
        )
    return cc_pairs
def get_connector_credential_pairs(
    db_session: Session,
    ids: list[int] | None = None,
) -> list[ConnectorCredentialPair]:
    """Fetch cc_pairs without any user-based access filtering.

    A falsy `ids` (None or an empty list) returns every cc_pair; otherwise
    only the cc_pairs whose id is in `ids` are returned.
    """
    all_pairs = select(ConnectorCredentialPair).distinct()
    query = all_pairs.where(ConnectorCredentialPair.id.in_(ids)) if ids else all_pairs
    rows = db_session.scalars(query)
    return list(rows.all())
def add_deletion_failure_message(
    db_session: Session,
    cc_pair_id: int,
    failure_message: str,
) -> None:
    """Store `failure_message` on the cc_pair and commit.

    Does nothing (and does not commit) when no cc_pair with `cc_pair_id`
    is found.
    """
    target_pair = get_connector_credential_pair_from_id(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
    )
    if target_pair:
        target_pair.deletion_failure_message = failure_message
        db_session.commit()
def get_cc_pair_groups_for_ids(
    db_session: Session,
    cc_pair_ids: list[int],
) -> list[UserGroup__ConnectorCredentialPair]:
    """Return the distinct group <-> cc_pair link rows for the given cc_pair ids."""
    link = UserGroup__ConnectorCredentialPair
    query = (
        select(link)
        .distinct()
        .outerjoin(
            ConnectorCredentialPair,
            link.cc_pair_id == ConnectorCredentialPair.id,
        )
        .where(link.cc_pair_id.in_(cc_pair_ids))
    )
    return list(db_session.scalars(query).all())
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_cc_pair_groups_for_ids_parallel(
    cc_pair_ids: list[int],
) -> list[UserGroup__ConnectorCredentialPair]:
    """Run `get_cc_pair_groups_for_ids` inside its own session."""
    with get_session_context_manager() as session:
        group_links = get_cc_pair_groups_for_ids(session, cc_pair_ids)
    return group_links
def get_connector_credential_pair_for_user(
    db_session: Session,
    connector_id: int,
    credential_id: int,
    user: User | None,
    get_editable: bool = True,
) -> ConnectorCredentialPair | None:
    """Look up the cc_pair for a connector/credential combination, as `user`.

    Returns None when no such cc_pair exists or `user` fails the access
    filters applied by `_add_user_filters`.
    """
    query = _add_user_filters(select(ConnectorCredentialPair), user, get_editable)
    query = query.where(
        ConnectorCredentialPair.connector_id == connector_id,
        ConnectorCredentialPair.credential_id == credential_id,
    )
    return db_session.execute(query).scalar_one_or_none()
def get_connector_credential_pair(
    db_session: Session,
    connector_id: int,
    credential_id: int,
) -> ConnectorCredentialPair | None:
    """Look up the cc_pair for a connector/credential combination.

    No user-based access filtering is applied. Returns None when there is
    no matching row.
    """
    query = select(ConnectorCredentialPair).where(
        ConnectorCredentialPair.connector_id == connector_id,
        ConnectorCredentialPair.credential_id == credential_id,
    )
    return db_session.execute(query).scalar_one_or_none()
def get_connector_credential_pair_from_id_for_user(
    cc_pair_id: int,
    db_session: Session,
    user: User | None,
    get_editable: bool = True,
) -> ConnectorCredentialPair | None:
    """Look up a cc_pair by id, subject to the access filters for `user`.

    Returns None when the id does not exist or `user` fails the filters
    applied by `_add_user_filters`.
    """
    query = _add_user_filters(
        select(ConnectorCredentialPair).distinct(), user, get_editable
    )
    query = query.where(ConnectorCredentialPair.id == cc_pair_id)
    return db_session.execute(query).scalar_one_or_none()
def get_connector_credential_pair_from_id(
    db_session: Session,
    cc_pair_id: int,
    eager_load_credential: bool = False,
) -> ConnectorCredentialPair | None:
    """Look up a cc_pair by id without user-based access filtering.

    When `eager_load_credential` is set, the credential relationship is
    loaded in the same query via a joined load.
    """
    query = (
        select(ConnectorCredentialPair)
        .distinct()
        .where(ConnectorCredentialPair.id == cc_pair_id)
    )
    if eager_load_credential:
        query = query.options(joinedload(ConnectorCredentialPair.credential))
    return db_session.execute(query).scalar_one_or_none()
def get_last_successful_attempt_time(
    connector_id: int,
    credential_id: int,
    earliest_index: float,
    search_settings: SearchSettings,
    db_session: Session,
) -> float:
    """Return the timestamp of the last successful index run for a cc_pair.

    For search settings whose status is PRESENT the value stored on the
    cc_pair row is used. For any other status it is derived from the most
    recently started successful IndexAttempt for those search settings.
    `earliest_index` is returned whenever no such time is available.
    """
    if search_settings.status == IndexModelStatus.PRESENT:
        cc_pair = get_connector_credential_pair(
            db_session=db_session,
            connector_id=connector_id,
            credential_id=credential_id,
        )
        stored_time = (
            cc_pair.last_successful_index_time if cc_pair is not None else None
        )
        if stored_time is None:
            return earliest_index
        return stored_time.timestamp()

    # For Secondary Index we don't keep track of the latest success, so have to calculate it live
    successful_attempts = (
        db_session.query(IndexAttempt)
        .join(
            ConnectorCredentialPair,
            IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
        )
        .filter(
            ConnectorCredentialPair.connector_id == connector_id,
            ConnectorCredentialPair.credential_id == credential_id,
            IndexAttempt.search_settings_id == search_settings.id,
            IndexAttempt.status == IndexingStatus.SUCCESS,
        )
    )
    latest_success = successful_attempts.order_by(
        IndexAttempt.time_started.desc()
    ).first()

    if latest_success and latest_success.time_started:
        return latest_success.time_started.timestamp()
    return earliest_index
"""Updates"""
def _update_connector_credential_pair(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
    status: ConnectorCredentialPairStatus | None = None,
    net_docs: int | None = None,
    run_dt: datetime | None = None,
) -> None:
    """Apply the provided (non-None) updates to `cc_pair` and commit.

    `net_docs` is added to the running `total_docs_indexed` counter rather
    than replacing it.
    """
    if status is not None:
        cc_pair.status = status

    if net_docs is not None:
        cc_pair.total_docs_indexed += net_docs

    # simply don't update last_successful_index_time if run_dt is not specified
    # at worst, this would result in re-indexing documents that were already indexed
    if run_dt is not None:
        cc_pair.last_successful_index_time = run_dt

    db_session.commit()
def update_connector_credential_pair_from_id(
    db_session: Session,
    cc_pair_id: int,
    status: ConnectorCredentialPairStatus | None = None,
    net_docs: int | None = None,
    run_dt: datetime | None = None,
) -> None:
    """Update the cc_pair identified by `cc_pair_id`.

    Logs a warning and makes no changes when the cc_pair cannot be found.
    """
    pair = get_connector_credential_pair_from_id(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
    )
    if pair:
        _update_connector_credential_pair(
            db_session=db_session,
            cc_pair=pair,
            status=status,
            net_docs=net_docs,
            run_dt=run_dt,
        )
        return

    logger.warning(
        f"Attempted to update pair for Connector Credential Pair '{cc_pair_id}'"
        f" but it does not exist"
    )
def update_connector_credential_pair(
    db_session: Session,
    connector_id: int,
    credential_id: int,
    status: ConnectorCredentialPairStatus | None = None,
    net_docs: int | None = None,
    run_dt: datetime | None = None,
) -> None:
    """Update the cc_pair identified by a connector/credential combination.

    Logs a warning and makes no changes when the cc_pair cannot be found.
    """
    pair = get_connector_credential_pair(
        db_session=db_session,
        connector_id=connector_id,
        credential_id=credential_id,
    )
    if pair:
        _update_connector_credential_pair(
            db_session=db_session,
            cc_pair=pair,
            status=status,
            net_docs=net_docs,
            run_dt=run_dt,
        )
        return

    logger.warning(
        f"Attempted to update pair for connector id {connector_id} "
        f"and credential id {credential_id}"
    )
def delete_connector_credential_pair__no_commit(
    db_session: Session,
    connector_id: int,
    credential_id: int,
) -> None:
    """Issue a DELETE for the matching cc_pair row; the caller must commit."""
    pair_criteria = (
        ConnectorCredentialPair.connector_id == connector_id,
        ConnectorCredentialPair.credential_id == credential_id,
    )
    db_session.execute(delete(ConnectorCredentialPair).where(*pair_criteria))
def associate_default_cc_pair(db_session: Session) -> None:
    """Create the default cc_pair (connector 0 / credential 0) if it is missing.

    Commits only when a new row is added.
    """
    default_pair_query = db_session.query(ConnectorCredentialPair).filter(
        ConnectorCredentialPair.connector_id == 0,
        ConnectorCredentialPair.credential_id == 0,
    )
    if default_pair_query.one_or_none() is not None:
        return

    # DefaultCCPair has id 1 since it is the first CC pair created
    # It is DEFAULT_CC_PAIR_ID, but can't set it explicitly because it messed with the
    # auto-incrementing id
    default_pair = ConnectorCredentialPair(
        connector_id=0,
        credential_id=0,
        access_type=AccessType.PUBLIC,
        name="DefaultCCPair",
        status=ConnectorCredentialPairStatus.ACTIVE,
    )
    db_session.add(default_pair)
    db_session.commit()
def _relate_groups_to_cc_pair__no_commit(
    db_session: Session,
    cc_pair_id: int,
    user_group_ids: list[int] | None = None,
) -> None:
    """Stage one group <-> cc_pair link row per group id; the caller must commit.

    A falsy `user_group_ids` (None or empty) stages nothing.
    """
    if not user_group_ids:
        return

    group_links = [
        UserGroup__ConnectorCredentialPair(
            user_group_id=group_id, cc_pair_id=cc_pair_id
        )
        for group_id in user_group_ids
    ]
    db_session.add_all(group_links)
def add_credential_to_connector(
    db_session: Session,
    user: User | None,
    connector_id: int,
    credential_id: int,
    cc_pair_name: str | None,
    access_type: AccessType,
    groups: list[int] | None,
    auto_sync_options: dict | None = None,
    initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.ACTIVE,
    last_successful_index_time: datetime | None = None,
    seeding_flow: bool = False,
) -> StatusResponse:
    """Create a cc_pair linking a connector to a credential.

    Returns:
        A StatusResponse. On success `data` is the new cc_pair id. If the
        pair already exists, `success` is False and `data` is the connector id.

    Raises:
        HTTPException: 404 if the connector does not exist; 400 if SYNC access
            is requested for a source that does not pass the sync-source
            check; 401 if the credential lookup returns nothing.
    """
    connector = fetch_connector_by_id(connector_id, db_session)

    # If we are in the seeding flow, we shouldn't need to check if the credential belongs to the user
    credential = (
        fetch_credential_by_id(
            credential_id=credential_id,
            db_session=db_session,
        )
        if seeding_flow
        else fetch_credential_by_id_for_user(
            credential_id,
            user,
            db_session,
            get_editable=False,
        )
    )

    if connector is None:
        raise HTTPException(status_code=404, detail="Connector does not exist")

    if access_type == AccessType.SYNC:
        # Resolved dynamically; falls back to a no-op returning True.
        is_valid_sync_source = fetch_ee_implementation_or_noop(
            "onyx.external_permissions.sync_params",
            "check_if_valid_sync_source",
            noop_return_value=True,
        )
        if not is_valid_sync_source(connector.source):
            raise HTTPException(
                status_code=400,
                detail=f"Connector of type {connector.source} does not support SYNC access type",
            )

    if credential is None:
        error_msg = (
            f"Credential {credential_id} does not exist or does not belong to user"
        )
        logger.error(error_msg)
        raise HTTPException(
            status_code=401,
            detail=error_msg,
        )

    duplicate_pair = (
        db_session.query(ConnectorCredentialPair)
        .filter(
            ConnectorCredentialPair.connector_id == connector_id,
            ConnectorCredentialPair.credential_id == credential_id,
        )
        .one_or_none()
    )
    if duplicate_pair is not None:
        return StatusResponse(
            success=False,
            message=f"Connector {connector_id} already has Credential {credential_id}",
            data=connector_id,
        )

    new_pair = ConnectorCredentialPair(
        creator_id=user.id if user else None,
        connector_id=connector_id,
        credential_id=credential_id,
        name=cc_pair_name,
        status=initial_status,
        access_type=access_type,
        auto_sync_options=auto_sync_options,
        last_successful_index_time=last_successful_index_time,
    )
    db_session.add(new_pair)
    db_session.flush()  # make sure the association has an id
    db_session.refresh(new_pair)

    # Group links are staged in the same transaction and committed together.
    _relate_groups_to_cc_pair__no_commit(
        db_session=db_session,
        cc_pair_id=new_pair.id,
        user_group_ids=groups,
    )
    db_session.commit()

    return StatusResponse(
        success=True,
        message=f"Creating new association between Connector {connector_id} and Credential {credential_id}",
        data=new_pair.id,
    )
def remove_credential_from_connector(
    connector_id: int,
    credential_id: int,
    user: User | None,
    db_session: Session,
) -> StatusResponse[int]:
    """Delete the cc_pair linking a connector and a credential, if `user` may edit it.

    Returns:
        A StatusResponse whose `data` is the connector id. `success` is False
        when no editable cc_pair was found for the combination.

    Raises:
        HTTPException: 404 if the connector does not exist or the credential
            lookup returns nothing.
    """
    connector = fetch_connector_by_id(connector_id, db_session)
    credential = fetch_credential_by_id_for_user(
        credential_id,
        user,
        db_session,
        get_editable=False,
    )

    if connector is None:
        raise HTTPException(status_code=404, detail="Connector does not exist")

    if credential is None:
        raise HTTPException(
            status_code=404,
            detail="Credential does not exist or does not belong to user",
        )

    pair = get_connector_credential_pair_for_user(
        db_session=db_session,
        connector_id=connector_id,
        credential_id=credential_id,
        user=user,
        get_editable=True,
    )

    if pair is None:
        return StatusResponse(
            success=False,
            message=f"Connector already does not have Credential {credential_id}",
            data=connector_id,
        )

    # Resolved dynamically; runs before the cc_pair itself is deleted.
    delete_external_groups = fetch_ee_implementation_or_noop(
        "onyx.db.external_perm",
        "delete_user__ext_group_for_cc_pair__no_commit",
    )
    delete_external_groups(
        db_session=db_session,
        cc_pair_id=pair.id,
    )
    db_session.delete(pair)
    db_session.commit()

    return StatusResponse(
        success=True,
        message=f"Credential {credential_id} removed from Connector",
        data=connector_id,
    )
def fetch_connector_credential_pairs(
    db_session: Session,
) -> list[ConnectorCredentialPair]:
    """Return every cc_pair row, with no filtering of any kind."""
    all_pairs_query = db_session.query(ConnectorCredentialPair)
    return all_pairs_query.all()
def resync_cc_pair(
    cc_pair: ConnectorCredentialPair,
    db_session: Session,
) -> None:
    """Recompute `cc_pair.last_successful_index_time` from index attempts and commit.

    Uses the most recently started successful IndexAttempt for this cc_pair
    among search settings whose status is PRESENT. The field is set to None
    when there is no such attempt.
    """
    successful_attempts = (
        db_session.query(IndexAttempt)
        .join(
            ConnectorCredentialPair,
            IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
        )
        .join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
        .filter(
            ConnectorCredentialPair.connector_id == cc_pair.connector_id,
            ConnectorCredentialPair.credential_id == cc_pair.credential_id,
            SearchSettings.status == IndexModelStatus.PRESENT,
        )
        .filter(IndexAttempt.status == IndexingStatus.SUCCESS)
    )
    latest_success = successful_attempts.order_by(
        desc(IndexAttempt.time_started)
    ).first()

    if latest_success:
        cc_pair.last_successful_index_time = latest_success.time_started
    else:
        cc_pair.last_successful_index_time = None

    db_session.commit()