From db8ce61ff4f056fa2656d386ba390773bb12b7cf Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Fri, 27 Oct 2023 12:13:52 -0700 Subject: [PATCH] Fix group prefix (#11) --- backend/danswer/access/access.py | 9 +---- backend/danswer/access/models.py | 14 ++++++- backend/danswer/access/utils.py | 10 +++++ backend/danswer/utils/acl.py | 10 ++--- backend/ee/danswer/access/access.py | 40 +++++-------------- .../ee/danswer/background/celery/celery.py | 13 +++++- backend/ee/danswer/user_groups/sync.py | 2 +- 7 files changed, 51 insertions(+), 47 deletions(-) create mode 100644 backend/danswer/access/utils.py diff --git a/backend/danswer/access/access.py b/backend/danswer/access/access.py index 8b46c711e..51f5a300c 100644 --- a/backend/danswer/access/access.py +++ b/backend/danswer/access/access.py @@ -1,6 +1,7 @@ from sqlalchemy.orm import Session from danswer.access.models import DocumentAccess +from danswer.access.utils import prefix_user from danswer.configs.constants import PUBLIC_DOC_PAT from danswer.db.document import get_acccess_info_for_documents from danswer.db.models import User @@ -19,7 +20,7 @@ def _get_access_for_documents( cc_pair_to_delete=cc_pair_to_delete, ) return { - document_id: DocumentAccess.build(user_ids, is_public) + document_id: DocumentAccess.build(user_ids, [], is_public) for document_id, user_ids, is_public in document_access_info } @@ -38,12 +39,6 @@ def get_access_for_documents( ) # type: ignore -def prefix_user(user_id: str) -> str: - """Prefixes a user ID to eliminate collision with group names. - This assumes that groups are prefixed with a different prefix.""" - return f"user_id:{user_id}" - - def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]: """Returns a list of ACL entries that the user has access to. This is meant to be used downstream to filter out documents that the user does not have access to. The diff --git a/backend/danswer/access/models.py b/backend/danswer/access/models.py index 94a18528c..a87e2d94f 100644 --- a/backend/danswer/access/models.py +++ b/backend/danswer/access/models.py @@ -1,20 +1,30 @@ from dataclasses import dataclass from uuid import UUID +from danswer.access.utils import prefix_user +from danswer.access.utils import prefix_user_group from danswer.configs.constants import PUBLIC_DOC_PAT @dataclass(frozen=True) class DocumentAccess: user_ids: set[str] # stringified UUIDs + user_groups: set[str] # names of user groups associated with this document is_public: bool def to_acl(self) -> list[str]: - return list(self.user_ids) + ([PUBLIC_DOC_PAT] if self.is_public else []) + return ( + [prefix_user(user_id) for user_id in self.user_ids] + + [prefix_user_group(group_name) for group_name in self.user_groups] + + ([PUBLIC_DOC_PAT] if self.is_public else []) + ) @classmethod - def build(cls, user_ids: list[UUID | None], is_public: bool) -> "DocumentAccess": + def build( + cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool + ) -> "DocumentAccess": return cls( user_ids={str(user_id) for user_id in user_ids if user_id}, + user_groups=set(user_groups), is_public=is_public, ) diff --git a/backend/danswer/access/utils.py b/backend/danswer/access/utils.py new file mode 100644 index 000000000..060560eae --- /dev/null +++ b/backend/danswer/access/utils.py @@ -0,0 +1,10 @@ +def prefix_user(user_id: str) -> str: + """Prefixes a user ID to eliminate collision with group names. + This assumes that groups are prefixed with a different prefix.""" + return f"user_id:{user_id}" + + +def prefix_user_group(user_group_name: str) -> str: + """Prefixes a user group name to eliminate collision with user IDs. + This assumes that user ids are prefixed with a different prefix.""" + return f"group:{user_group_name}" diff --git a/backend/danswer/utils/acl.py b/backend/danswer/utils/acl.py index 8fbadb300..5608530fa 100644 --- a/backend/danswer/utils/acl.py +++ b/backend/danswer/utils/acl.py @@ -3,8 +3,7 @@ from threading import Thread from sqlalchemy import select from sqlalchemy.orm import Session -from danswer.access.models import DocumentAccess -from danswer.db.document import get_acccess_info_for_documents +from danswer.access.access import get_access_for_documents from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import Document from danswer.document_index.document_index_utils import get_both_index_names @@ -37,7 +36,7 @@ def set_acl_for_vespa(should_check_if_already_done: bool = False) -> None: # for all documents, set the `access_control_list` field appropriately # based on the state of Postgres documents = db_session.scalars(select(Document)).all() - document_access_info = get_acccess_info_for_documents( + document_access_dict = get_access_for_documents( db_session=db_session, document_ids=[document.id for document in documents], ) @@ -46,15 +45,16 @@ def set_acl_for_vespa(should_check_if_already_done: bool = False) -> None: vespa_index = get_default_document_index( primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name ) + if not isinstance(vespa_index, VespaIndex): raise ValueError("This script is only for Vespa indexes") update_requests = [ UpdateRequest( document_ids=[document_id], - access=DocumentAccess.build(user_ids, is_public), + access=access, ) - for document_id, user_ids, is_public in document_access_info + for document_id, access in document_access_dict.items() ] vespa_index.update(update_requests=update_requests) diff --git a/backend/ee/danswer/access/access.py b/backend/ee/danswer/access/access.py index 308d56e37..254e76e66 100644 --- a/backend/ee/danswer/access/access.py +++ b/backend/ee/danswer/access/access.py @@ -1,11 +1,11 @@ from sqlalchemy.orm import Session -from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_groups from danswer.access.access import ( - get_access_for_documents as get_access_for_documents_without_groups, + _get_access_for_documents as get_access_for_documents_without_groups, ) +from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_groups from danswer.access.models import DocumentAccess -from danswer.db.engine import get_sqlalchemy_engine +from danswer.access.utils import prefix_user_group from danswer.db.models import User from danswer.server.documents.models import ConnectorCredentialPairIdentifier from ee.danswer.db.user_group import fetch_user_groups_for_documents @@ -14,13 +14,13 @@ from ee.danswer.db.user_group import fetch_user_groups_for_user def _get_access_for_documents( document_ids: list[str], - cc_pair_to_delete: ConnectorCredentialPairIdentifier | None, db_session: Session, + cc_pair_to_delete: ConnectorCredentialPairIdentifier | None, ) -> dict[str, DocumentAccess]: - access_dict = get_access_for_documents_without_groups( + non_ee_access_dict = get_access_for_documents_without_groups( document_ids=document_ids, - cc_pair_to_delete=cc_pair_to_delete, db_session=db_session, + cc_pair_to_delete=cc_pair_to_delete, ) user_group_info = { document_id: group_names @@ -31,36 +31,16 @@ def _get_access_for_documents( ) } - # overload user_ids a bit - use it for both actual User IDs + group IDs return { document_id: DocumentAccess( - user_ids=access.user_ids.union(user_group_info.get(document_id, [])), # type: ignore - is_public=access.is_public, + user_ids=non_ee_access.user_ids, + user_groups=user_group_info.get(document_id, []), # type: ignore + is_public=non_ee_access.is_public, ) - for document_id, access in access_dict.items() + for document_id, non_ee_access in non_ee_access_dict.items() } -def get_access_for_documents( - document_ids: list[str], - cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None, - db_session: Session | None = None, -) -> dict[str, DocumentAccess]: - if db_session is None: - with Session(get_sqlalchemy_engine()) as db_session: - return _get_access_for_documents( - document_ids, cc_pair_to_delete, db_session - ) - - return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session) - - -def prefix_user_group(user_group_name: str) -> str: - """Prefixes a user group name to eliminate collision with user IDs. - This assumes that user ids are prefixed with a different prefix.""" - return f"group:{user_group_name}" - - def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]: """Returns a list of ACL entries that the user has access to. This is meant to be used downstream to filter out documents that the user does not have access to. The diff --git a/backend/ee/danswer/background/celery/celery.py b/backend/ee/danswer/background/celery/celery.py index 0a742bd86..16de00da2 100644 --- a/backend/ee/danswer/background/celery/celery.py +++ b/backend/ee/danswer/background/celery/celery.py @@ -11,12 +11,16 @@ from danswer.db.tasks import mark_task_finished from danswer.db.tasks import mark_task_start from danswer.db.tasks import register_task from danswer.utils.logger import setup_logger +from danswer.utils.variable_functionality import global_version from ee.danswer.background.user_group_sync import name_user_group_sync_task from ee.danswer.db.user_group import fetch_user_groups from ee.danswer.user_groups.sync import sync_user_groups logger = setup_logger() +# mark as EE for all tasks in this file +global_version.set_ee() + @celery_app.task(soft_time_limit=JOB_TIMEOUT) def sync_user_group_task(user_group_id: int) -> None: @@ -25,9 +29,14 @@ def sync_user_group_task(user_group_id: int) -> None: mark_task_start(task_name, db_session) # actual sync logic - sync_user_groups(user_group_id=user_group_id, db_session=db_session) + error_msg = None + try: + sync_user_groups(user_group_id=user_group_id, db_session=db_session) + except Exception as e: + error_msg = str(e) + logger.exception(f"Failed to sync user group - {error_msg}") - mark_task_finished(task_name, db_session) + mark_task_finished(task_name, db_session, success=error_msg is None) ##### diff --git a/backend/ee/danswer/user_groups/sync.py b/backend/ee/danswer/user_groups/sync.py index 91de86431..948e50452 100644 --- a/backend/ee/danswer/user_groups/sync.py +++ b/backend/ee/danswer/user_groups/sync.py @@ -3,11 +3,11 @@ from sqlalchemy.orm import Session from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import DocumentIndex from danswer.document_index.interfaces import UpdateRequest +from danswer.access.access import get_access_for_documents from danswer.db.document import prepare_to_modify_documents from danswer.db.engine import get_sqlalchemy_engine from danswer.utils.batching import batch_generator from danswer.utils.logger import setup_logger -from ee.danswer.access.access import get_access_for_documents from ee.danswer.db.user_group import delete_user_group from ee.danswer.db.user_group import fetch_documents_for_user_group from ee.danswer.db.user_group import fetch_user_group