danswer/backend/onyx/db/connector.py

from datetime import datetime
from datetime import timezone
from typing import cast

from sqlalchemy import and_
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session

from onyx.configs.app_configs import DEFAULT_PRUNING_FREQ
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import InputType
from onyx.db.enums import IndexingMode
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.server.documents.models import ConnectorBase
from onyx.server.documents.models import ObjectCreationIdResponse
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger

logger = setup_logger()


def check_connectors_exist(db_session: Session) -> bool:
    # Connector 0 is created on server startup as a default for ingestion;
    # it will always exist, so we don't need to count it for this check
stmt = select(exists(Connector).where(Connector.id > 0))
result = db_session.execute(stmt)
    return result.scalar() or False


def fetch_connectors(
db_session: Session,
sources: list[DocumentSource] | None = None,
input_types: list[InputType] | None = None,
) -> list[Connector]:
stmt = select(Connector)
if sources is not None:
stmt = stmt.where(Connector.source.in_(sources))
if input_types is not None:
stmt = stmt.where(Connector.input_type.in_(input_types))
results = db_session.scalars(stmt)
return list(results.all())
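

# Usage sketch (illustrative, not part of the module). Assumes a SQLAlchemy
# session bound to the app database; `engine` is a hypothetical name here.
#
#   with Session(engine) as db_session:
#       web_connectors = fetch_connectors(db_session, sources=[DocumentSource.WEB])
#       load_connectors = fetch_connectors(
#           db_session, input_types=[InputType.LOAD_STATE]
#       )

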
def connector_by_name_source_exists(
connector_name: str, source: DocumentSource, db_session: Session
) -> bool:
stmt = select(Connector).where(
Connector.name == connector_name, Connector.source == source
)
result = db_session.execute(stmt)
connector = result.scalar_one_or_none()
    return connector is not None


def fetch_connector_by_id(connector_id: int, db_session: Session) -> Connector | None:
stmt = select(Connector).where(Connector.id == connector_id)
result = db_session.execute(stmt)
connector = result.scalar_one_or_none()
    return connector


def fetch_ingestion_connector_by_name(
connector_name: str, db_session: Session
) -> Connector | None:
stmt = (
select(Connector)
.where(Connector.name == connector_name)
.where(Connector.source == DocumentSource.INGESTION_API)
)
result = db_session.execute(stmt)
connector = result.scalar_one_or_none()
    return connector


def create_connector(
db_session: Session,
connector_data: ConnectorBase,
) -> ObjectCreationIdResponse:
if connector_by_name_source_exists(
connector_data.name, connector_data.source, db_session
):
raise ValueError(
"Connector by this name already exists, duplicate naming not allowed."
)
connector = Connector(
name=connector_data.name,
source=connector_data.source,
input_type=connector_data.input_type,
connector_specific_config=connector_data.connector_specific_config,
refresh_freq=connector_data.refresh_freq,
indexing_start=connector_data.indexing_start,
prune_freq=connector_data.prune_freq,
)
db_session.add(connector)
db_session.commit()
return ObjectCreationIdResponse(id=connector.id)
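

# Usage sketch (illustrative, not part of the module). The config payload is a
# made-up example; real keys depend on the connector type.
#
#   response = create_connector(
#       db_session,
#       ConnectorBase(
#           name="Example Web Connector",
#           source=DocumentSource.WEB,
#           input_type=InputType.LOAD_STATE,
#           connector_specific_config={"base_url": "https://example.com"},
#           refresh_freq=None,
#           indexing_start=None,
#           prune_freq=None,
#       ),
#   )
#   new_connector_id = response.id

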
def update_connector(
connector_id: int,
connector_data: ConnectorBase,
db_session: Session,
) -> Connector | None:
connector = fetch_connector_by_id(connector_id, db_session)
if connector is None:
return None
if connector_data.name != connector.name and connector_by_name_source_exists(
connector_data.name, connector_data.source, db_session
):
raise ValueError(
"Connector by this name already exists, duplicate naming not allowed."
)
connector.name = connector_data.name
connector.source = connector_data.source
connector.input_type = connector_data.input_type
connector.connector_specific_config = connector_data.connector_specific_config
connector.refresh_freq = connector_data.refresh_freq
connector.prune_freq = (
connector_data.prune_freq
if connector_data.prune_freq is not None
else DEFAULT_PRUNING_FREQ
)
db_session.commit()
    return connector


def delete_connector(
db_session: Session,
connector_id: int,
) -> StatusResponse[int]:
"""Only used in special cases (e.g. a connector is in a bad state and we need to delete it).
Be VERY careful using this, as it could lead to a bad state if not used correctly.
"""
connector = fetch_connector_by_id(connector_id, db_session)
if connector is None:
return StatusResponse(
success=True, message="Connector was already deleted", data=connector_id
)
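
    # No commit here: the caller is expected to commit (or roll back) the
    # enclosing transaction after this returns.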
db_session.delete(connector)
return StatusResponse(
success=True, message="Connector deleted successfully", data=connector_id
    )


def get_connector_credential_ids(
connector_id: int,
db_session: Session,
) -> list[int]:
connector = fetch_connector_by_id(connector_id, db_session)
if connector is None:
raise ValueError(f"Connector by id {connector_id} does not exist")
    return [association.credential.id for association in connector.credentials]


def fetch_latest_index_attempt_by_connector(
db_session: Session,
source: DocumentSource | None = None,
) -> list[IndexAttempt]:
latest_index_attempts: list[IndexAttempt] = []
if source:
connectors = fetch_connectors(db_session, sources=[source])
else:
connectors = fetch_connectors(db_session)
if not connectors:
return []
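
    # Note: this issues one query per connector (latest attempt per connector)
    # rather than a single grouped query.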
for connector in connectors:
latest_index_attempt = (
db_session.query(IndexAttempt)
.join(ConnectorCredentialPair)
.filter(ConnectorCredentialPair.connector_id == connector.id)
.order_by(IndexAttempt.time_updated.desc())
.first()
)
if latest_index_attempt is not None:
latest_index_attempts.append(latest_index_attempt)
    return latest_index_attempts


def fetch_latest_index_attempts_by_status(
db_session: Session,
) -> list[IndexAttempt]:
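    # Greatest-row-per-group pattern: the subquery finds, for each
    # (connector_credential_pair_id, status) pair, the latest time_updated;
    # joining IndexAttempt back against it keeps only those newest rows.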
subquery = (
db_session.query(
IndexAttempt.connector_credential_pair_id,
IndexAttempt.status,
func.max(IndexAttempt.time_updated).label("time_updated"),
)
.group_by(IndexAttempt.connector_credential_pair_id)
.group_by(IndexAttempt.status)
.subquery()
)
alias = aliased(IndexAttempt, subquery)
query = db_session.query(IndexAttempt).join(
alias,
and_(
IndexAttempt.connector_credential_pair_id
== alias.connector_credential_pair_id,
IndexAttempt.status == alias.status,
IndexAttempt.time_updated == alias.time_updated,
),
)
return cast(list[IndexAttempt], query.all())
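
# The query above roughly compiles to this SQL shape (illustrative; table and
# column names are assumed from the ORM model):
#
#   SELECT ia.*
#   FROM index_attempt AS ia
#   JOIN (
#       SELECT connector_credential_pair_id, status, MAX(time_updated) AS time_updated
#       FROM index_attempt
#       GROUP BY connector_credential_pair_id, status
#   ) AS latest
#     ON ia.connector_credential_pair_id = latest.connector_credential_pair_id
#    AND ia.status = latest.status
#    AND ia.time_updated = latest.time_updated
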
def fetch_unique_document_sources(db_session: Session) -> list[DocumentSource]:
distinct_sources = db_session.query(Connector.source).distinct().all()
sources = [
source[0]
for source in distinct_sources
if source[0] != DocumentSource.INGESTION_API
]
return sources
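

# Runs at server startup (cf. the comment in check_connectors_exist): ensures
# the default ingestion connector exists as row 0, repairing its fields if
# they have drifted from the expected values.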
def create_initial_default_connector(db_session: Session) -> None:
default_connector_id = 0
default_connector = fetch_connector_by_id(default_connector_id, db_session)
if default_connector is not None:
if (
default_connector.source != DocumentSource.INGESTION_API
or default_connector.input_type != InputType.LOAD_STATE
or default_connector.refresh_freq is not None
or default_connector.name != "Ingestion API"
or default_connector.connector_specific_config != {}
or default_connector.prune_freq is not None
):
logger.warning(
"Default connector does not have expected values. Updating to proper state."
)
# Ensure default connector has correct values
default_connector.source = DocumentSource.INGESTION_API
default_connector.input_type = InputType.LOAD_STATE
default_connector.refresh_freq = None
default_connector.name = "Ingestion API"
default_connector.connector_specific_config = {}
default_connector.prune_freq = None
db_session.commit()
return
# Create a new default connector if it doesn't exist
connector = Connector(
id=default_connector_id,
name="Ingestion API",
source=DocumentSource.INGESTION_API,
input_type=InputType.LOAD_STATE,
connector_specific_config={},
refresh_freq=None,
prune_freq=None,
)
db_session.add(connector)
    db_session.commit()


def mark_ccpair_as_pruned(cc_pair_id: int, db_session: Session) -> None:
stmt = select(ConnectorCredentialPair).where(
ConnectorCredentialPair.id == cc_pair_id
)
cc_pair = db_session.scalar(stmt)
if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
cc_pair.last_pruned = datetime.now(timezone.utc)
    db_session.commit()


def mark_cc_pair_as_permissions_synced(
db_session: Session, cc_pair_id: int, start_time: datetime | None
) -> None:
stmt = select(ConnectorCredentialPair).where(
ConnectorCredentialPair.id == cc_pair_id
)
cc_pair = db_session.scalar(stmt)
if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
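
    # Recording the sync *start* time (rather than completion time) presumably
    # ensures documents changed while the sync was running get re-checked on
    # the next run.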
cc_pair.last_time_perm_sync = start_time
    db_session.commit()


def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int) -> None:
stmt = select(ConnectorCredentialPair).where(
ConnectorCredentialPair.id == cc_pair_id
)
cc_pair = db_session.scalar(stmt)
if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
    # The sync time can be marked after the sync has run because all group
    # syncs are run in full, not by polling for changes.
    # If this changes, we need to update this function.
cc_pair.last_time_external_group_sync = datetime.now(timezone.utc)
    db_session.commit()


def mark_ccpair_with_indexing_trigger(
cc_pair_id: int, indexing_mode: IndexingMode | None, db_session: Session
) -> None:
"""indexing_mode sets a field which will be picked up by a background task
to trigger indexing. Set to None to disable the trigger."""
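    # Row-level lock (SELECT ... FOR UPDATE) so that concurrent writers of the
    # trigger field serialize on this cc_pair.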
try:
cc_pair = db_session.execute(
select(ConnectorCredentialPair)
.where(ConnectorCredentialPair.id == cc_pair_id)
.with_for_update()
        ).scalar_one_or_none()

if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
cc_pair.indexing_trigger = indexing_mode
db_session.commit()
except Exception:
db_session.rollback()
raise
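

# Usage sketch (illustrative, not part of the module). The IndexingMode member
# name below is an assumption about the enum; passing None to clear a pending
# trigger is documented in the docstring above.
#
#   mark_ccpair_with_indexing_trigger(cc_pair.id, IndexingMode.UPDATE, db_session)
#   mark_ccpair_with_indexing_trigger(cc_pair.id, None, db_session)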