from datetime import datetime
from datetime import timezone
from typing import cast

from sqlalchemy import and_
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session

from onyx.configs.app_configs import DEFAULT_PRUNING_FREQ
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import InputType
from onyx.db.enums import IndexingMode
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.server.documents.models import ConnectorBase
from onyx.server.documents.models import ObjectCreationIdResponse
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger

logger = setup_logger()


def check_connectors_exist(db_session: Session) -> bool:
    # Connector 0 is created on server startup as the default for ingestion;
    # it always exists, so it is excluded from this check.
    stmt = select(exists(Connector).where(Connector.id > 0))
    result = db_session.execute(stmt)
    return result.scalar() or False


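# Roughly the SQL the statement above compiles to (a sketch; the actual table
# name comes from the Connector model's mapping):
#
#     SELECT EXISTS (SELECT 1 FROM connector WHERE connector.id > 0)

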
def fetch_connectors(
    db_session: Session,
    sources: list[DocumentSource] | None = None,
    input_types: list[InputType] | None = None,
) -> list[Connector]:
    stmt = select(Connector)
    if sources is not None:
        stmt = stmt.where(Connector.source.in_(sources))
    if input_types is not None:
        stmt = stmt.where(Connector.input_type.in_(input_types))
    results = db_session.scalars(stmt)
    return list(results.all())


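# Example usage (an illustrative sketch, not part of the original module):
# fetch only Slack connectors that poll for changes. Assumes `db_session` is a
# live Session and that DocumentSource.SLACK and InputType.POLL exist in this
# deployment's enums.
#
#     slack_pollers = fetch_connectors(
#         db_session,
#         sources=[DocumentSource.SLACK],
#         input_types=[InputType.POLL],
#     )

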
def connector_by_name_source_exists(
    connector_name: str, source: DocumentSource, db_session: Session
) -> bool:
    stmt = select(Connector).where(
        Connector.name == connector_name, Connector.source == source
    )
    result = db_session.execute(stmt)
    connector = result.scalar_one_or_none()
    return connector is not None


def fetch_connector_by_id(connector_id: int, db_session: Session) -> Connector | None:
    stmt = select(Connector).where(Connector.id == connector_id)
    result = db_session.execute(stmt)
    connector = result.scalar_one_or_none()
    return connector


def fetch_ingestion_connector_by_name(
    connector_name: str, db_session: Session
) -> Connector | None:
    stmt = (
        select(Connector)
        .where(Connector.name == connector_name)
        .where(Connector.source == DocumentSource.INGESTION_API)
    )
    result = db_session.execute(stmt)
    connector = result.scalar_one_or_none()
    return connector


def create_connector(
    db_session: Session,
    connector_data: ConnectorBase,
) -> ObjectCreationIdResponse:
    if connector_by_name_source_exists(
        connector_data.name, connector_data.source, db_session
    ):
        raise ValueError(
            "Connector by this name already exists, duplicate naming not allowed."
        )

    connector = Connector(
        name=connector_data.name,
        source=connector_data.source,
        input_type=connector_data.input_type,
        connector_specific_config=connector_data.connector_specific_config,
        refresh_freq=connector_data.refresh_freq,
        indexing_start=connector_data.indexing_start,
        prune_freq=connector_data.prune_freq,
    )
    db_session.add(connector)
    db_session.commit()

    return ObjectCreationIdResponse(id=connector.id)


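# Example usage (an illustrative sketch): creating a connector from a
# ConnectorBase payload. connector_specific_config keys vary by source;
# "base_url" is only a placeholder here, not a documented schema. Assumes
# DocumentSource.WEB exists in this deployment's enums.
#
#     response = create_connector(
#         db_session=db_session,
#         connector_data=ConnectorBase(
#             name="Example Web Connector",
#             source=DocumentSource.WEB,
#             input_type=InputType.LOAD_STATE,
#             connector_specific_config={"base_url": "https://example.com"},
#             refresh_freq=60 * 60 * 24,  # seconds between refreshes
#             indexing_start=None,
#             prune_freq=None,
#         ),
#     )
#     new_connector_id = response.id

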
def update_connector(
    connector_id: int,
    connector_data: ConnectorBase,
    db_session: Session,
) -> Connector | None:
    connector = fetch_connector_by_id(connector_id, db_session)
    if connector is None:
        return None

    if connector_data.name != connector.name and connector_by_name_source_exists(
        connector_data.name, connector_data.source, db_session
    ):
        raise ValueError(
            "Connector by this name already exists, duplicate naming not allowed."
        )

    connector.name = connector_data.name
    connector.source = connector_data.source
    connector.input_type = connector_data.input_type
    connector.connector_specific_config = connector_data.connector_specific_config
    connector.refresh_freq = connector_data.refresh_freq
    connector.prune_freq = (
        connector_data.prune_freq
        if connector_data.prune_freq is not None
        else DEFAULT_PRUNING_FREQ
    )

    db_session.commit()
    return connector


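# Note on the semantics above: refresh_freq is copied through as-is (including
# None), but a None prune_freq falls back to DEFAULT_PRUNING_FREQ, so pruning
# cannot be disabled via update_connector. A minimal call sketch, where
# `payload` is a ConnectorBase built by the caller:
#
#     updated = update_connector(
#         connector_id=42, connector_data=payload, db_session=db_session
#     )
#     if updated is None:
#         ...  # no connector with id 42

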
def delete_connector(
    db_session: Session,
    connector_id: int,
) -> StatusResponse[int]:
    """Only used in special cases (e.g. a connector is in a bad state and we need to delete it).
    Be VERY careful using this, as it could lead to a bad state if not used correctly.
    """
    connector = fetch_connector_by_id(connector_id, db_session)
    if connector is None:
        return StatusResponse(
            success=True, message="Connector was already deleted", data=connector_id
        )

    db_session.delete(connector)
    return StatusResponse(
        success=True, message="Connector deleted successfully", data=connector_id
    )


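# Note: delete_connector removes the row from the session but never commits;
# as written above, committing (or rolling back) is the caller's job.
#
#     response = delete_connector(db_session, connector_id=42)
#     db_session.commit()

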
def get_connector_credential_ids(
    connector_id: int,
    db_session: Session,
) -> list[int]:
    connector = fetch_connector_by_id(connector_id, db_session)
    if connector is None:
        raise ValueError(f"Connector by id {connector_id} does not exist")

    return [association.credential.id for association in connector.credentials]


def fetch_latest_index_attempt_by_connector(
    db_session: Session,
    source: DocumentSource | None = None,
) -> list[IndexAttempt]:
    latest_index_attempts: list[IndexAttempt] = []

    if source:
        connectors = fetch_connectors(db_session, sources=[source])
    else:
        connectors = fetch_connectors(db_session)

    if not connectors:
        return []

    for connector in connectors:
        latest_index_attempt = (
            db_session.query(IndexAttempt)
            .join(ConnectorCredentialPair)
            .filter(ConnectorCredentialPair.connector_id == connector.id)
            .order_by(IndexAttempt.time_updated.desc())
            .first()
        )

        if latest_index_attempt is not None:
            latest_index_attempts.append(latest_index_attempt)

    return latest_index_attempts


def fetch_latest_index_attempts_by_status(
    db_session: Session,
) -> list[IndexAttempt]:
    subquery = (
        db_session.query(
            IndexAttempt.connector_credential_pair_id,
            IndexAttempt.status,
            func.max(IndexAttempt.time_updated).label("time_updated"),
        )
        .group_by(IndexAttempt.connector_credential_pair_id)
        .group_by(IndexAttempt.status)
        .subquery()
    )

    alias = aliased(IndexAttempt, subquery)

    query = db_session.query(IndexAttempt).join(
        alias,
        and_(
            IndexAttempt.connector_credential_pair_id
            == alias.connector_credential_pair_id,
            IndexAttempt.status == alias.status,
            IndexAttempt.time_updated == alias.time_updated,
        ),
    )

    return cast(list[IndexAttempt], query.all())


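# The join above is a "greatest row per group" query: the subquery finds the
# latest time_updated per (connector_credential_pair_id, status) pair, and the
# join recovers the full IndexAttempt rows at those maxima. Roughly the SQL
# below (a sketch; real table/column names come from the ORM mapping):
#
#     SELECT ia.*
#     FROM index_attempt ia
#     JOIN (
#         SELECT connector_credential_pair_id,
#                status,
#                MAX(time_updated) AS time_updated
#         FROM index_attempt
#         GROUP BY connector_credential_pair_id, status
#     ) latest
#       ON  ia.connector_credential_pair_id = latest.connector_credential_pair_id
#       AND ia.status = latest.status
#       AND ia.time_updated = latest.time_updated

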
def fetch_unique_document_sources(db_session: Session) -> list[DocumentSource]:
    distinct_sources = db_session.query(Connector.source).distinct().all()

    sources = [
        source[0]
        for source in distinct_sources
        if source[0] != DocumentSource.INGESTION_API
    ]

    return sources


def create_initial_default_connector(db_session: Session) -> None:
    default_connector_id = 0
    default_connector = fetch_connector_by_id(default_connector_id, db_session)
    if default_connector is not None:
        if (
            default_connector.source != DocumentSource.INGESTION_API
            or default_connector.input_type != InputType.LOAD_STATE
            or default_connector.refresh_freq is not None
            or default_connector.name != "Ingestion API"
            or default_connector.connector_specific_config != {}
            or default_connector.prune_freq is not None
        ):
            logger.warning(
                "Default connector does not have expected values. Updating to proper state."
            )
            # Ensure default connector has correct values
            default_connector.source = DocumentSource.INGESTION_API
            default_connector.input_type = InputType.LOAD_STATE
            default_connector.refresh_freq = None
            default_connector.name = "Ingestion API"
            default_connector.connector_specific_config = {}
            default_connector.prune_freq = None
            db_session.commit()
        return

    # Create a new default connector if it doesn't exist
    connector = Connector(
        id=default_connector_id,
        name="Ingestion API",
        source=DocumentSource.INGESTION_API,
        input_type=InputType.LOAD_STATE,
        connector_specific_config={},
        refresh_freq=None,
        prune_freq=None,
    )
    db_session.add(connector)
    db_session.commit()


def mark_ccpair_as_pruned(cc_pair_id: int, db_session: Session) -> None:
    stmt = select(ConnectorCredentialPair).where(
        ConnectorCredentialPair.id == cc_pair_id
    )
    cc_pair = db_session.scalar(stmt)
    if cc_pair is None:
        raise ValueError(f"No cc_pair with ID: {cc_pair_id}")

    cc_pair.last_pruned = datetime.now(timezone.utc)
    db_session.commit()


def mark_cc_pair_as_permissions_synced(
    db_session: Session, cc_pair_id: int, start_time: datetime | None
) -> None:
    stmt = select(ConnectorCredentialPair).where(
        ConnectorCredentialPair.id == cc_pair_id
    )
    cc_pair = db_session.scalar(stmt)
    if cc_pair is None:
        raise ValueError(f"No cc_pair with ID: {cc_pair_id}")

    cc_pair.last_time_perm_sync = start_time
    db_session.commit()


def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int) -> None:
    stmt = select(ConnectorCredentialPair).where(
        ConnectorCredentialPair.id == cc_pair_id
    )
    cc_pair = db_session.scalar(stmt)
    if cc_pair is None:
        raise ValueError(f"No cc_pair with ID: {cc_pair_id}")

    # The sync time can be marked after the sync ran because group syncs are
    # always run in full rather than polling for changes.
    # If this changes, this function needs to be updated.
    cc_pair.last_time_external_group_sync = datetime.now(timezone.utc)
    db_session.commit()


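# Example usage (an illustrative sketch): stamp the sync time only after a full
# group sync has finished, matching the comment above. `run_group_sync` is a
# hypothetical entry point, not something defined in this module.
#
#     run_group_sync(cc_pair_id=7)  # hypothetical
#     mark_cc_pair_as_external_group_synced(db_session, cc_pair_id=7)

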
def mark_ccpair_with_indexing_trigger(
    cc_pair_id: int, indexing_mode: IndexingMode | None, db_session: Session
) -> None:
    """indexing_mode sets a field which will be picked up by a background task
    to trigger indexing. Set to None to disable the trigger."""
    try:
        # scalar_one_or_none (rather than scalar_one) so that a missing row
        # reaches the ValueError below instead of raising NoResultFound
        cc_pair = db_session.execute(
            select(ConnectorCredentialPair)
            .where(ConnectorCredentialPair.id == cc_pair_id)
            .with_for_update()
        ).scalar_one_or_none()

        if cc_pair is None:
            raise ValueError(f"No cc_pair with ID: {cc_pair_id}")

        cc_pair.indexing_trigger = indexing_mode
        db_session.commit()
    except Exception:
        db_session.rollback()
        raise


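# Example usage (an illustrative sketch): ask the background task to index a
# pair, then clear the trigger. Assumes IndexingMode exposes an UPDATE member;
# the enum lives in onyx.db.enums and its members may differ by version.
#
#     mark_ccpair_with_indexing_trigger(
#         cc_pair_id=7, indexing_mode=IndexingMode.UPDATE, db_session=db_session
#     )
#     # ...later, once the task has picked it up...
#     mark_ccpair_with_indexing_trigger(
#         cc_pair_id=7, indexing_mode=None, db_session=db_session
#     )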