EE Connector Deletion Bugfix + Refactor (#2042)

---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
This commit is contained in:
Nathan Schwerdfeger
2024-08-11 20:33:07 -07:00
committed by GitHub
parent 79523f2e0a
commit c7e5b11c63
49 changed files with 998 additions and 800 deletions

View File

@@ -7,7 +7,6 @@ from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_
from danswer.access.models import DocumentAccess
from danswer.access.utils import prefix_user_group
from danswer.db.models import User
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from ee.danswer.db.user_group import fetch_user_groups_for_documents
from ee.danswer.db.user_group import fetch_user_groups_for_user
@@ -15,19 +14,16 @@ from ee.danswer.db.user_group import fetch_user_groups_for_user
def _get_access_for_documents(
document_ids: list[str],
db_session: Session,
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None,
) -> dict[str, DocumentAccess]:
non_ee_access_dict = get_access_for_documents_without_groups(
document_ids=document_ids,
db_session=db_session,
cc_pair_to_delete=cc_pair_to_delete,
)
user_group_info = {
document_id: group_names
for document_id, group_names in fetch_user_groups_for_documents(
db_session=db_session,
document_ids=document_ids,
cc_pair_to_delete=cc_pair_to_delete,
)
}

View File

@@ -1,13 +1,36 @@
from sqlalchemy import delete
from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import UserGroup__ConnectorCredentialPair
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _delete_connector_credential_pair_user_groups_relationship__no_commit(
db_session: Session, connector_id: int, credential_id: int
) -> None:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if cc_pair is None:
raise ValueError(
f"ConnectorCredentialPair with connector_id: {connector_id} "
f"and credential_id: {credential_id} not found"
)
stmt = delete(UserGroup__ConnectorCredentialPair).where(
UserGroup__ConnectorCredentialPair.cc_pair_id == cc_pair.id,
)
db_session.execute(stmt)
def get_cc_pairs_by_source(
source_type: DocumentSource,
db_session: Session,

View File

@@ -2,10 +2,13 @@ from collections.abc import Sequence
from operator import and_
from uuid import UUID
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Document
from danswer.db.models import DocumentByConnectorCredentialPair
@@ -15,7 +18,6 @@ from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup
from danswer.db.models import UserGroup__ConnectorCredentialPair
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from ee.danswer.server.user_group.models import UserGroupCreate
from ee.danswer.server.user_group.models import UserGroupUpdate
@@ -90,7 +92,6 @@ def fetch_documents_for_user_group_paginated(
def fetch_user_groups_for_documents(
db_session: Session,
document_ids: list[str],
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
) -> Sequence[tuple[int, list[str]]]:
stmt = (
select(Document.id, func.array_agg(UserGroup.name))
@@ -114,19 +115,12 @@ def fetch_user_groups_for_documents(
.join(Document, Document.id == DocumentByConnectorCredentialPair.id)
.where(Document.id.in_(document_ids))
.where(UserGroup__ConnectorCredentialPair.is_current == True) # noqa: E712
# don't include CC pairs that are being deleted
# NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
.where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING)
.group_by(Document.id)
)
# pretend that the specified cc pair doesn't exist
if cc_pair_to_delete is not None:
stmt = stmt.where(
and_(
ConnectorCredentialPair.connector_id != cc_pair_to_delete.connector_id,
ConnectorCredentialPair.credential_id
!= cc_pair_to_delete.credential_id,
)
)
return db_session.execute(stmt).all() # type: ignore
@@ -343,3 +337,25 @@ def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
db_session.delete(user_group)
db_session.commit()
def delete_user_group_cc_pair_relationship__no_commit(
cc_pair_id: int, db_session: Session
) -> None:
"""Deletes all rows from UserGroup__ConnectorCredentialPair where the
connector_credential_pair_id matches the given cc_pair_id.
Should be used very carefully (only for connectors that are being deleted)."""
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
raise ValueError(f"Connector Credential Pair '{cc_pair_id}' does not exist")
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
raise ValueError(
f"Connector Credential Pair '{cc_pair_id}' is not in the DELETING state"
)
delete_stmt = delete(UserGroup__ConnectorCredentialPair).where(
UserGroup__ConnectorCredentialPair.cc_pair_id == cc_pair_id,
)
db_session.execute(delete_stmt)