mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-26 17:51:54 +01:00
497 lines
16 KiB
Python
497 lines
16 KiB
Python
from typing import Any
|
|
|
|
from sqlalchemy import exists
|
|
from sqlalchemy import Select
|
|
from sqlalchemy import select
|
|
from sqlalchemy import update
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.sql.expression import and_
|
|
from sqlalchemy.sql.expression import or_
|
|
|
|
from onyx.auth.schemas import UserRole
|
|
from onyx.configs.app_configs import DISABLE_AUTH
|
|
from onyx.configs.constants import DocumentSource
|
|
from onyx.connectors.google_utils.shared_constants import (
|
|
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
|
)
|
|
from onyx.db.enums import ConnectorCredentialPairStatus
|
|
from onyx.db.models import ConnectorCredentialPair
|
|
from onyx.db.models import Credential
|
|
from onyx.db.models import Credential__UserGroup
|
|
from onyx.db.models import DocumentByConnectorCredentialPair
|
|
from onyx.db.models import User
|
|
from onyx.db.models import User__UserGroup
|
|
from onyx.server.documents.models import CredentialBase
|
|
from onyx.utils.logger import setup_logger
|
|
|
|
|
|
logger = setup_logger()
|
|
|
|
# The credentials for these sources are not real so
|
|
# permissions are not enforced for them
|
|
CREDENTIAL_PERMISSIONS_TO_IGNORE = {
|
|
DocumentSource.FILE,
|
|
DocumentSource.WEB,
|
|
DocumentSource.NOT_APPLICABLE,
|
|
DocumentSource.GOOGLE_SITES,
|
|
DocumentSource.WIKIPEDIA,
|
|
DocumentSource.MEDIAWIKI,
|
|
}
|
|
|
|
PUBLIC_CREDENTIAL_ID = 0
|
|
|
|
|
|
def _add_user_filters(
|
|
stmt: Select,
|
|
user: User | None,
|
|
get_editable: bool = True,
|
|
) -> Select:
|
|
"""Attaches filters to the statement to ensure that the user can only
|
|
access the appropriate credentials"""
|
|
if user is None:
|
|
if not DISABLE_AUTH:
|
|
raise ValueError("Anonymous users are not allowed to access credentials")
|
|
# If user is None and auth is disabled, assume the user is an admin
|
|
return stmt.where(
|
|
or_(
|
|
Credential.user_id.is_(None),
|
|
Credential.admin_public == True, # noqa: E712
|
|
Credential.source.in_(CREDENTIAL_PERMISSIONS_TO_IGNORE),
|
|
)
|
|
)
|
|
|
|
if user.role == UserRole.ADMIN:
|
|
# Admins can access all credentials that are public or owned by them
|
|
# or are not associated with any user
|
|
return stmt.where(
|
|
or_(
|
|
Credential.user_id == user.id,
|
|
Credential.user_id.is_(None),
|
|
Credential.admin_public == True, # noqa: E712
|
|
Credential.source.in_(CREDENTIAL_PERMISSIONS_TO_IGNORE),
|
|
)
|
|
)
|
|
if user.role == UserRole.BASIC:
|
|
# Basic users can only access credentials that are owned by them
|
|
return stmt.where(Credential.user_id == user.id)
|
|
|
|
stmt = stmt.distinct()
|
|
"""
|
|
THIS PART IS FOR CURATORS AND GLOBAL CURATORS
|
|
Here we select cc_pairs by relation:
|
|
User -> User__UserGroup -> Credential__UserGroup -> Credential
|
|
"""
|
|
stmt = stmt.outerjoin(Credential__UserGroup).outerjoin(
|
|
User__UserGroup,
|
|
User__UserGroup.user_group_id == Credential__UserGroup.user_group_id,
|
|
)
|
|
"""
|
|
Filter Credentials by:
|
|
- if the user is in the user_group that owns the Credential
|
|
- if the user is a curator, they must also have a curator relationship
|
|
to the user_group
|
|
- if editing is being done, we also filter out Credentials that are owned by groups
|
|
that the user isn't a curator for
|
|
- if we are not editing, we show all Credentials in the groups the user is a curator
|
|
for (as well as public Credentials)
|
|
- if we are not editing, we return all Credentials directly connected to the user
|
|
"""
|
|
where_clause = User__UserGroup.user_id == user.id
|
|
if user.role == UserRole.CURATOR:
|
|
where_clause &= User__UserGroup.is_curator == True # noqa: E712
|
|
|
|
if get_editable:
|
|
user_groups = select(User__UserGroup.user_group_id).where(
|
|
User__UserGroup.user_id == user.id
|
|
)
|
|
if user.role == UserRole.CURATOR:
|
|
user_groups = user_groups.where(
|
|
User__UserGroup.is_curator == True # noqa: E712
|
|
)
|
|
where_clause &= (
|
|
~exists()
|
|
.where(Credential__UserGroup.credential_id == Credential.id)
|
|
.where(~Credential__UserGroup.user_group_id.in_(user_groups))
|
|
.correlate(Credential)
|
|
)
|
|
else:
|
|
where_clause |= Credential.curator_public == True # noqa: E712
|
|
where_clause |= Credential.user_id == user.id # noqa: E712
|
|
|
|
where_clause |= Credential.source.in_(CREDENTIAL_PERMISSIONS_TO_IGNORE)
|
|
|
|
return stmt.where(where_clause)
|
|
|
|
|
|
def _relate_credential_to_user_groups__no_commit(
|
|
db_session: Session,
|
|
credential_id: int,
|
|
user_group_ids: list[int],
|
|
) -> None:
|
|
credential_user_groups = []
|
|
for group_id in user_group_ids:
|
|
credential_user_groups.append(
|
|
Credential__UserGroup(
|
|
credential_id=credential_id,
|
|
user_group_id=group_id,
|
|
)
|
|
)
|
|
db_session.add_all(credential_user_groups)
|
|
|
|
|
|
def fetch_credentials_for_user(
|
|
db_session: Session,
|
|
user: User | None,
|
|
get_editable: bool = True,
|
|
) -> list[Credential]:
|
|
stmt = select(Credential)
|
|
stmt = _add_user_filters(stmt, user, get_editable=get_editable)
|
|
results = db_session.scalars(stmt)
|
|
return list(results.all())
|
|
|
|
|
|
def fetch_credential_by_id_for_user(
|
|
credential_id: int,
|
|
user: User | None,
|
|
db_session: Session,
|
|
get_editable: bool = True,
|
|
) -> Credential | None:
|
|
stmt = select(Credential).distinct()
|
|
stmt = stmt.where(Credential.id == credential_id)
|
|
stmt = _add_user_filters(
|
|
stmt=stmt,
|
|
user=user,
|
|
get_editable=get_editable,
|
|
)
|
|
result = db_session.execute(stmt)
|
|
credential = result.scalar_one_or_none()
|
|
return credential
|
|
|
|
|
|
def fetch_credential_by_id(
|
|
credential_id: int,
|
|
db_session: Session,
|
|
) -> Credential | None:
|
|
stmt = select(Credential).distinct()
|
|
stmt = stmt.where(Credential.id == credential_id)
|
|
result = db_session.execute(stmt)
|
|
credential = result.scalar_one_or_none()
|
|
return credential
|
|
|
|
|
|
def fetch_credentials_by_source_for_user(
|
|
db_session: Session,
|
|
user: User | None,
|
|
document_source: DocumentSource | None = None,
|
|
get_editable: bool = True,
|
|
) -> list[Credential]:
|
|
base_query = select(Credential).where(Credential.source == document_source)
|
|
base_query = _add_user_filters(base_query, user, get_editable=get_editable)
|
|
credentials = db_session.execute(base_query).scalars().all()
|
|
return list(credentials)
|
|
|
|
|
|
def fetch_credentials_by_source(
|
|
db_session: Session,
|
|
document_source: DocumentSource | None = None,
|
|
) -> list[Credential]:
|
|
base_query = select(Credential).where(Credential.source == document_source)
|
|
credentials = db_session.execute(base_query).scalars().all()
|
|
return list(credentials)
|
|
|
|
|
|
def swap_credentials_connector(
|
|
new_credential_id: int, connector_id: int, user: User | None, db_session: Session
|
|
) -> ConnectorCredentialPair:
|
|
# Check if the user has permission to use the new credential
|
|
new_credential = fetch_credential_by_id_for_user(
|
|
new_credential_id, user, db_session
|
|
)
|
|
if not new_credential:
|
|
raise ValueError(
|
|
f"No Credential found with id {new_credential_id} or user doesn't have permission to use it"
|
|
)
|
|
|
|
# Existing pair
|
|
existing_pair = db_session.execute(
|
|
select(ConnectorCredentialPair).where(
|
|
ConnectorCredentialPair.connector_id == connector_id
|
|
)
|
|
).scalar_one_or_none()
|
|
|
|
if not existing_pair:
|
|
raise ValueError(
|
|
f"No ConnectorCredentialPair found for connector_id {connector_id}"
|
|
)
|
|
|
|
# Check if the new credential is compatible with the connector
|
|
if new_credential.source != existing_pair.connector.source:
|
|
raise ValueError(
|
|
f"New credential source {new_credential.source} does not match connector source {existing_pair.connector.source}"
|
|
)
|
|
|
|
db_session.execute(
|
|
update(DocumentByConnectorCredentialPair)
|
|
.where(
|
|
and_(
|
|
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
|
DocumentByConnectorCredentialPair.credential_id
|
|
== existing_pair.credential_id,
|
|
)
|
|
)
|
|
.values(credential_id=new_credential_id)
|
|
)
|
|
|
|
# Update the existing pair with the new credential
|
|
existing_pair.credential_id = new_credential_id
|
|
existing_pair.credential = new_credential
|
|
|
|
# Update ccpair status if it's in INVALID state
|
|
if existing_pair.status == ConnectorCredentialPairStatus.INVALID:
|
|
existing_pair.status = ConnectorCredentialPairStatus.ACTIVE
|
|
|
|
# Commit the changes
|
|
db_session.commit()
|
|
|
|
# Refresh the object to ensure all relationships are up-to-date
|
|
db_session.refresh(existing_pair)
|
|
return existing_pair
|
|
|
|
|
|
def create_credential(
|
|
credential_data: CredentialBase,
|
|
user: User | None,
|
|
db_session: Session,
|
|
) -> Credential:
|
|
credential = Credential(
|
|
credential_json=credential_data.credential_json,
|
|
user_id=user.id if user else None,
|
|
admin_public=credential_data.admin_public,
|
|
source=credential_data.source,
|
|
name=credential_data.name,
|
|
curator_public=credential_data.curator_public,
|
|
)
|
|
db_session.add(credential)
|
|
db_session.flush() # This ensures the credential gets an ID
|
|
_relate_credential_to_user_groups__no_commit(
|
|
db_session=db_session,
|
|
credential_id=credential.id,
|
|
user_group_ids=credential_data.groups,
|
|
)
|
|
|
|
db_session.commit()
|
|
return credential
|
|
|
|
|
|
def _cleanup_credential__user_group_relationships__no_commit(
|
|
db_session: Session, credential_id: int
|
|
) -> None:
|
|
"""NOTE: does not commit the transaction."""
|
|
db_session.query(Credential__UserGroup).filter(
|
|
Credential__UserGroup.credential_id == credential_id
|
|
).delete(synchronize_session=False)
|
|
|
|
|
|
def alter_credential(
|
|
credential_id: int,
|
|
name: str,
|
|
credential_json: dict[str, Any],
|
|
user: User,
|
|
db_session: Session,
|
|
) -> Credential | None:
|
|
# TODO: add user group relationship update
|
|
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
|
|
|
if credential is None:
|
|
return None
|
|
|
|
credential.name = name
|
|
|
|
# Assign a new dictionary to credential.credential_json
|
|
credential.credential_json = {
|
|
**credential.credential_json,
|
|
**credential_json,
|
|
}
|
|
|
|
credential.user_id = user.id if user is not None else None
|
|
db_session.commit()
|
|
return credential
|
|
|
|
|
|
def update_credential(
|
|
credential_id: int,
|
|
credential_data: CredentialBase,
|
|
user: User,
|
|
db_session: Session,
|
|
) -> Credential | None:
|
|
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
|
if credential is None:
|
|
return None
|
|
|
|
credential.credential_json = credential_data.credential_json
|
|
credential.user_id = user.id if user is not None else None
|
|
|
|
db_session.commit()
|
|
return credential
|
|
|
|
|
|
def update_credential_json(
|
|
credential_id: int,
|
|
credential_json: dict[str, Any],
|
|
user: User,
|
|
db_session: Session,
|
|
) -> Credential | None:
|
|
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
|
if credential is None:
|
|
return None
|
|
|
|
credential.credential_json = credential_json
|
|
db_session.commit()
|
|
return credential
|
|
|
|
|
|
def backend_update_credential_json(
|
|
credential: Credential,
|
|
credential_json: dict[str, Any],
|
|
db_session: Session,
|
|
) -> None:
|
|
"""This should not be used in any flows involving the frontend or users"""
|
|
credential.credential_json = credential_json
|
|
db_session.commit()
|
|
|
|
|
|
def _delete_credential_internal(
|
|
credential: Credential,
|
|
credential_id: int,
|
|
db_session: Session,
|
|
force: bool = False,
|
|
) -> None:
|
|
"""Internal utility function to handle the actual deletion of a credential"""
|
|
associated_connectors = (
|
|
db_session.query(ConnectorCredentialPair)
|
|
.filter(ConnectorCredentialPair.credential_id == credential_id)
|
|
.all()
|
|
)
|
|
|
|
associated_doc_cc_pairs = (
|
|
db_session.query(DocumentByConnectorCredentialPair)
|
|
.filter(DocumentByConnectorCredentialPair.credential_id == credential_id)
|
|
.all()
|
|
)
|
|
|
|
if associated_connectors or associated_doc_cc_pairs:
|
|
if force:
|
|
logger.warning(
|
|
f"Force deleting credential {credential_id} and its associated records"
|
|
)
|
|
|
|
# Delete DocumentByConnectorCredentialPair records first
|
|
for doc_cc_pair in associated_doc_cc_pairs:
|
|
db_session.delete(doc_cc_pair)
|
|
|
|
# Then delete ConnectorCredentialPair records
|
|
for connector in associated_connectors:
|
|
db_session.delete(connector)
|
|
|
|
# Commit these deletions before deleting the credential
|
|
db_session.flush()
|
|
else:
|
|
raise ValueError(
|
|
f"Cannot delete credential as it is still associated with "
|
|
f"{len(associated_connectors)} connector(s) and {len(associated_doc_cc_pairs)} document(s). "
|
|
)
|
|
|
|
if force:
|
|
logger.warning(f"Force deleting credential {credential_id}")
|
|
else:
|
|
logger.notice(f"Deleting credential {credential_id}")
|
|
|
|
_cleanup_credential__user_group_relationships__no_commit(db_session, credential_id)
|
|
db_session.delete(credential)
|
|
db_session.commit()
|
|
|
|
|
|
def delete_credential_for_user(
|
|
credential_id: int,
|
|
user: User,
|
|
db_session: Session,
|
|
force: bool = False,
|
|
) -> None:
|
|
"""Delete a credential that belongs to a specific user"""
|
|
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
|
if credential is None:
|
|
raise ValueError(
|
|
f"Credential by provided id {credential_id} does not exist or does not belong to user"
|
|
)
|
|
|
|
_delete_credential_internal(credential, credential_id, db_session, force)
|
|
|
|
|
|
def delete_credential(
|
|
credential_id: int,
|
|
db_session: Session,
|
|
force: bool = False,
|
|
) -> None:
|
|
"""Delete a credential regardless of ownership (admin function)"""
|
|
credential = fetch_credential_by_id(credential_id, db_session)
|
|
if credential is None:
|
|
raise ValueError(f"Credential by provided id {credential_id} does not exist")
|
|
|
|
_delete_credential_internal(credential, credential_id, db_session, force)
|
|
|
|
|
|
def create_initial_public_credential(db_session: Session) -> None:
|
|
error_msg = (
|
|
"DB is not in a valid initial state."
|
|
"There must exist an empty public credential for data connectors that do not require additional Auth."
|
|
)
|
|
first_credential = fetch_credential_by_id(
|
|
credential_id=PUBLIC_CREDENTIAL_ID,
|
|
db_session=db_session,
|
|
)
|
|
|
|
if first_credential is not None:
|
|
if first_credential.credential_json != {} or first_credential.user is not None:
|
|
raise ValueError(error_msg)
|
|
return
|
|
|
|
credential = Credential(
|
|
id=PUBLIC_CREDENTIAL_ID,
|
|
credential_json={},
|
|
user_id=None,
|
|
)
|
|
db_session.add(credential)
|
|
db_session.commit()
|
|
|
|
|
|
def cleanup_gmail_credentials(db_session: Session) -> None:
|
|
gmail_credentials = fetch_credentials_by_source(
|
|
db_session=db_session, document_source=DocumentSource.GMAIL
|
|
)
|
|
for credential in gmail_credentials:
|
|
db_session.delete(credential)
|
|
db_session.commit()
|
|
|
|
|
|
def cleanup_google_drive_credentials(db_session: Session) -> None:
|
|
google_drive_credentials = fetch_credentials_by_source(
|
|
db_session=db_session, document_source=DocumentSource.GOOGLE_DRIVE
|
|
)
|
|
for credential in google_drive_credentials:
|
|
db_session.delete(credential)
|
|
db_session.commit()
|
|
|
|
|
|
def delete_service_account_credentials(
|
|
user: User | None, db_session: Session, source: DocumentSource
|
|
) -> None:
|
|
credentials = fetch_credentials_for_user(db_session=db_session, user=user)
|
|
for credential in credentials:
|
|
if (
|
|
credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY)
|
|
and credential.source == source
|
|
):
|
|
db_session.delete(credential)
|
|
|
|
db_session.commit()
|