From aab777f8447a8f22ffaf8cdfd90145be663fed01 Mon Sep 17 00:00:00 2001
From: rkuo-danswer
Date: Thu, 27 Mar 2025 22:52:35 -0700
Subject: [PATCH] Bugfix/acl prefix (#4377)

* fix acl prefixing

* increase timeout a tad

* block access to init'ing DocumentAccess directly, fix test to work with ee/MIT

* fix env var checks

---------

Co-authored-by: Richard Kuo (Onyx)
---
 backend/ee/onyx/access/access.py            |  10 +-
 backend/onyx/access/access.py               |  26 +--
 backend/onyx/access/models.py               |  97 ++++++-----
 .../query_time_check/seed_dummy_docs.py     |  18 +--
 .../integration/common_utils/constants.py   |   2 +-
 .../connector/test_connector_deletion.py    | 152 +++++++++++-------
 6 files changed, 176 insertions(+), 129 deletions(-)

diff --git a/backend/ee/onyx/access/access.py b/backend/ee/onyx/access/access.py
index 558699d61..421fd7000 100644
--- a/backend/ee/onyx/access/access.py
+++ b/backend/ee/onyx/access/access.py
@@ -93,12 +93,12 @@ def _get_access_for_documents(
         )
 
         # To avoid collisions of group namings between connectors, they need to be prefixed
-        access_map[document_id] = DocumentAccess(
-            user_emails=non_ee_access.user_emails,
-            user_groups=set(user_group_info.get(document_id, [])),
+        access_map[document_id] = DocumentAccess.build(
+            user_emails=list(non_ee_access.user_emails),
+            user_groups=user_group_info.get(document_id, []),
             is_public=is_public_anywhere,
-            external_user_emails=ext_u_emails,
-            external_user_group_ids=ext_u_groups,
+            external_user_emails=list(ext_u_emails),
+            external_user_group_ids=list(ext_u_groups),
         )
 
     return access_map
diff --git a/backend/onyx/access/access.py b/backend/onyx/access/access.py
index ec1fd99d3..8c3c36416 100644
--- a/backend/onyx/access/access.py
+++ b/backend/onyx/access/access.py
@@ -18,7 +18,7 @@ def _get_access_for_document(
         document_id=document_id,
     )
 
-    return DocumentAccess.build(
+    doc_access = DocumentAccess.build(
         user_emails=info[1] if info and info[1] else [],
         user_groups=[],
         external_user_emails=[],
@@ -26,6 +26,8 @@ def _get_access_for_document(
         is_public=info[2] if info else False,
     )
 
+    return doc_access
+
 
 def get_access_for_document(
     document_id: str,
@@ -38,12 +40,12 @@ def get_access_for_document(
 
 
 def get_null_document_access() -> DocumentAccess:
-    return DocumentAccess(
-        user_emails=set(),
-        user_groups=set(),
+    return DocumentAccess.build(
+        user_emails=[],
+        user_groups=[],
         is_public=False,
-        external_user_emails=set(),
-        external_user_group_ids=set(),
+        external_user_emails=[],
+        external_user_group_ids=[],
     )
 
 
@@ -56,18 +58,18 @@ def _get_access_for_documents(
         document_ids=document_ids,
     )
     doc_access = {
-        document_id: DocumentAccess(
-            user_emails=set([email for email in user_emails if email]),
+        document_id: DocumentAccess.build(
+            user_emails=[email for email in user_emails if email],
             # MIT version will wipe all groups and external groups on update
-            user_groups=set(),
+            user_groups=[],
             is_public=is_public,
-            external_user_emails=set(),
-            external_user_group_ids=set(),
+            external_user_emails=[],
+            external_user_group_ids=[],
         )
         for document_id, user_emails, is_public in document_access_info
     }
 
-    # Sometimes the document has not be indexed by the indexing job yet, in those cases
+    # Sometimes the document has not been indexed by the indexing job yet, in those cases
    # the document does not exist and so we use least permissive. Specifically the EE version
    # checks the MIT version permissions and creates a superset. This ensures that this flow
    # does not fail even if the Document has not yet been indexed.
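Note on the two hunks above: every MIT call site now goes through `DocumentAccess.build(...)` with plain lists instead of constructing `DocumentAccess(...)` from pre-built sets, and the trailing comment explains that documents the indexing job has not written yet fall back to the least-permissive access. A minimal, self-contained sketch of that fallback pattern follows; `SimpleAccess`, `NULL_ACCESS`, and `build_access_map` are illustrative stand-ins, not Onyx APIs.

```python
# Illustrative sketch only: SimpleAccess, NULL_ACCESS, and build_access_map are
# hypothetical stand-ins for DocumentAccess / get_null_document_access().
from dataclasses import dataclass


@dataclass(frozen=True)
class SimpleAccess:
    user_emails: frozenset[str] = frozenset()
    is_public: bool = False


# Least permissive access: no users, not public.
NULL_ACCESS = SimpleAccess()


def build_access_map(
    requested_ids: list[str], fetched: dict[str, SimpleAccess]
) -> dict[str, SimpleAccess]:
    # Documents that have not been indexed yet have no row in `fetched`, so
    # they fall back to the least-permissive access instead of raising.
    return {doc_id: fetched.get(doc_id, NULL_ACCESS) for doc_id in requested_ids}


if __name__ == "__main__":
    fetched = {"doc-1": SimpleAccess(frozenset({"a@example.com"}), is_public=False)}
    print(build_access_map(["doc-1", "doc-2"], fetched))
```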
diff --git a/backend/onyx/access/models.py b/backend/onyx/access/models.py
index 411e53a03..f1f5adf6e 100644
--- a/backend/onyx/access/models.py
+++ b/backend/onyx/access/models.py
@@ -56,34 +56,46 @@ class DocExternalAccess:
         )
 
 
-@dataclass(frozen=True)
+@dataclass(frozen=True, init=False)
 class DocumentAccess(ExternalAccess):
     # User emails for Onyx users, None indicates admin
     user_emails: set[str | None]
 
+    # Names of user groups associated with this document
     user_groups: set[str]
 
-    def to_acl(self) -> set[str]:
-        return set(
-            [
-                prefix_user_email(user_email)
-                for user_email in self.user_emails
-                if user_email
-            ]
-            + [prefix_user_group(group_name) for group_name in self.user_groups]
-            + [
-                prefix_user_email(user_email)
-                for user_email in self.external_user_emails
-            ]
-            + [
-                # The group names are already prefixed by the source type
-                # This adds an additional prefix of "external_group:"
-                prefix_external_group(group_name)
-                for group_name in self.external_user_group_ids
-            ]
-            + ([PUBLIC_DOC_PAT] if self.is_public else [])
+    external_user_emails: set[str]
+    external_user_group_ids: set[str]
+    is_public: bool
+
+    def __init__(self) -> None:
+        raise TypeError(
+            "Use `DocumentAccess.build(...)` instead of creating an instance directly."
         )
 
+    def to_acl(self) -> set[str]:
+        # the ACLs emitted by this function are prefixed by type
+        # to get the native objects, access the member variables directly
+
+        acl_set: set[str] = set()
+        for user_email in self.user_emails:
+            if user_email:
+                acl_set.add(prefix_user_email(user_email))
+
+        for group_name in self.user_groups:
+            acl_set.add(prefix_user_group(group_name))
+
+        for external_user_email in self.external_user_emails:
+            acl_set.add(prefix_user_email(external_user_email))
+
+        for external_group_id in self.external_user_group_ids:
+            acl_set.add(prefix_external_group(external_group_id))
+
+        if self.is_public:
+            acl_set.add(PUBLIC_DOC_PAT)
+
+        return acl_set
+
     @classmethod
     def build(
         cls,
@@ -93,29 +105,32 @@ class DocumentAccess(ExternalAccess):
         external_user_group_ids: list[str],
         is_public: bool,
     ) -> "DocumentAccess":
-        return cls(
-            external_user_emails={
-                prefix_user_email(external_email)
-                for external_email in external_user_emails
-            },
-            external_user_group_ids={
-                prefix_external_group(external_group_id)
-                for external_group_id in external_user_group_ids
-            },
-            user_emails={
-                prefix_user_email(user_email)
-                for user_email in user_emails
-                if user_email
-            },
-            user_groups=set(user_groups),
-            is_public=is_public,
+        """Don't prefix incoming data with the acl type, prefix on read from to_acl!"""
+
+        obj = object.__new__(cls)
+        object.__setattr__(
+            obj, "user_emails", {user_email for user_email in user_emails if user_email}
         )
+        object.__setattr__(obj, "user_groups", set(user_groups))
+        object.__setattr__(
+            obj,
+            "external_user_emails",
+            {external_email for external_email in external_user_emails},
+        )
+        object.__setattr__(
+            obj,
+            "external_user_group_ids",
+            {external_group_id for external_group_id in external_user_group_ids},
+        )
+        object.__setattr__(obj, "is_public", is_public)
+
+        return obj
 
 
-default_public_access = DocumentAccess(
-    external_user_emails=set(),
-    external_user_group_ids=set(),
-    user_emails=set(),
-    user_groups=set(),
+default_public_access = DocumentAccess.build(
+    external_user_emails=[],
+    external_user_group_ids=[],
+    user_emails=[],
+    user_groups=[],
     is_public=True,
 )
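The models.py hunks above are the heart of the fix: `build()` now stores raw, un-prefixed values and `to_acl()` applies the type prefixes exactly once on read, whereas previously values prefixed inside `build()` were prefixed again in `to_acl()`. Blocking `__init__` forces every caller through `build()`, so the invariant cannot be bypassed. Below is a minimal sketch of that frozen-dataclass factory pattern under hypothetical names (`Access`, `prefix_email`); it is not the real Onyx class.

```python
# Sketch of the pattern used above: direct __init__ is blocked, build() is the
# only constructor, and prefixes are applied on read (to_acl), never on write.
from dataclasses import dataclass


def prefix_email(email: str) -> str:
    # Hypothetical stand-in for prefix_user_email().
    return f"user_email:{email}"


@dataclass(frozen=True, init=False)
class Access:
    emails: set[str]
    is_public: bool

    def __init__(self) -> None:
        raise TypeError("Use `Access.build(...)` instead of creating an instance directly.")

    @classmethod
    def build(cls, emails: list[str], is_public: bool) -> "Access":
        # object.__new__/__setattr__ bypass both the blocked __init__ and the
        # frozen dataclass's __setattr__, so only build() can populate fields.
        obj = object.__new__(cls)
        object.__setattr__(obj, "emails", {e for e in emails if e})
        object.__setattr__(obj, "is_public", is_public)
        return obj

    def to_acl(self) -> set[str]:
        # Prefix exactly once, on read; the stored members stay un-prefixed.
        acl = {prefix_email(e) for e in self.emails}
        if self.is_public:
            acl.add("PUBLIC")
        return acl


access = Access.build(emails=["a@example.com", ""], is_public=True)
print(access.emails)    # {'a@example.com'}  (raw, un-prefixed)
print(access.to_acl())  # {'user_email:a@example.com', 'PUBLIC'}  (order may vary)
```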
diff --git a/backend/scripts/query_time_check/seed_dummy_docs.py b/backend/scripts/query_time_check/seed_dummy_docs.py
index da1edd8db..893dcaae0 100644
--- a/backend/scripts/query_time_check/seed_dummy_docs.py
+++ b/backend/scripts/query_time_check/seed_dummy_docs.py
@@ -78,19 +78,19 @@ def generate_dummy_chunk(
     for i in range(number_of_document_sets):
         document_set_names.append(f"Document Set {i}")
 
-    user_emails: set[str | None] = set()
-    user_groups: set[str] = set()
-    external_user_emails: set[str] = set()
-    external_user_group_ids: set[str] = set()
+    user_emails: list[str | None] = []
+    user_groups: list[str] = []
+    external_user_emails: list[str] = []
+    external_user_group_ids: list[str] = []
     for i in range(number_of_acl_entries):
-        user_emails.add(f"user_{i}@example.com")
-        user_groups.add(f"group_{i}")
-        external_user_emails.add(f"external_user_{i}@example.com")
-        external_user_group_ids.add(f"external_group_{i}")
+        user_emails.append(f"user_{i}@example.com")
+        user_groups.append(f"group_{i}")
+        external_user_emails.append(f"external_user_{i}@example.com")
+        external_user_group_ids.append(f"external_group_{i}")
 
     return DocMetadataAwareIndexChunk.from_index_chunk(
         index_chunk=chunk,
-        access=DocumentAccess(
+        access=DocumentAccess.build(
             user_emails=user_emails,
             user_groups=user_groups,
             external_user_emails=external_user_emails,
diff --git a/backend/tests/integration/common_utils/constants.py b/backend/tests/integration/common_utils/constants.py
index c6731e739..2a5f338b3 100644
--- a/backend/tests/integration/common_utils/constants.py
+++ b/backend/tests/integration/common_utils/constants.py
@@ -6,7 +6,7 @@ API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
 API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
 API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
 API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
-MAX_DELAY = 45
+MAX_DELAY = 60
 
 GENERAL_HEADERS = {"Content-Type": "application/json"}
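The seed script above simply follows the new `build()` signature, and the constants.py change raises the integration-test polling budget from 45 to 60 seconds (the commit message's "increase timeout a tad"), a budget the `wait_for_*` helpers presumably consume while polling for syncs and deletions. A rough sketch of that kind of bounded poll loop, as a hypothetical helper rather than the actual `common_utils` code:

```python
# Hypothetical polling helper showing how a MAX_DELAY budget is typically spent;
# the real wait_for_* managers in common_utils differ in detail.
import time
from typing import Callable

MAX_DELAY = 60  # seconds, matching the new value in constants.py


def wait_for(
    condition: Callable[[], bool], max_delay: int = MAX_DELAY, interval: float = 1.0
) -> None:
    start = time.monotonic()
    while True:
        if condition():
            return
        if time.monotonic() - start > max_delay:
            raise TimeoutError(f"condition not met within {max_delay}s")
        time.sleep(interval)


if __name__ == "__main__":
    deadline = time.monotonic() + 2
    wait_for(lambda: time.monotonic() > deadline)  # returns after roughly 2 seconds
    print("condition met")
```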
diff --git a/backend/tests/integration/tests/connector/test_connector_deletion.py b/backend/tests/integration/tests/connector/test_connector_deletion.py
index dc685e89f..4c7cd320c 100644
--- a/backend/tests/integration/tests/connector/test_connector_deletion.py
+++ b/backend/tests/integration/tests/connector/test_connector_deletion.py
@@ -5,6 +5,7 @@ This file contains tests for the following:
 - updates the document sets and user groups to remove the connector
 - Ensure that deleting a connector that is part of an overlapping document set and/or user group works as expected
 """
+import os
 from uuid import uuid4
 
 from sqlalchemy.orm import Session
@@ -32,6 +33,13 @@ from tests.integration.common_utils.vespa import vespa_fixture
 
 
 def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
+    user_group_1: DATestUserGroup
+    user_group_2: DATestUserGroup
+
+    is_ee = (
+        os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
+    )
+
     # Creating an admin user (first user created is automatically an admin)
     admin_user: DATestUser = UserManager.create(name="admin_user")
     # create api key
@@ -78,16 +86,17 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
 
     print("Document sets created and synced")
 
-    # create user groups
-    user_group_1: DATestUserGroup = UserGroupManager.create(
-        cc_pair_ids=[cc_pair_1.id],
-        user_performing_action=admin_user,
-    )
-    user_group_2: DATestUserGroup = UserGroupManager.create(
-        cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
-        user_performing_action=admin_user,
-    )
-    UserGroupManager.wait_for_sync(user_performing_action=admin_user)
+    if is_ee:
+        # create user groups
+        user_group_1 = UserGroupManager.create(
+            cc_pair_ids=[cc_pair_1.id],
+            user_performing_action=admin_user,
+        )
+        user_group_2 = UserGroupManager.create(
+            cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
+            user_performing_action=admin_user,
+        )
+        UserGroupManager.wait_for_sync(user_performing_action=admin_user)
 
     # inject a finished index attempt and index attempt error (exercises foreign key errors)
     with Session(get_sqlalchemy_engine()) as db_session:
@@ -147,12 +156,13 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
     )
 
     # Update local records to match the database for later comparison
-    user_group_1.cc_pair_ids = []
-    user_group_2.cc_pair_ids = [cc_pair_2.id]
     doc_set_1.cc_pair_ids = []
     doc_set_2.cc_pair_ids = [cc_pair_2.id]
     cc_pair_1.groups = []
-    cc_pair_2.groups = [user_group_2.id]
+    if is_ee:
+        cc_pair_2.groups = [user_group_2.id]
+    else:
+        cc_pair_2.groups = []
 
     CCPairManager.wait_for_deletion_completion(
         cc_pair_id=cc_pair_1.id, user_performing_action=admin_user
@@ -168,11 +178,15 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
         verify_deleted=True,
     )
 
+    cc_pair_2_group_name_expected = []
+    if is_ee:
+        cc_pair_2_group_name_expected = [user_group_2.name]
+
     DocumentManager.verify(
         vespa_client=vespa_client,
         cc_pair=cc_pair_2,
         doc_set_names=[doc_set_2.name],
-        group_names=[user_group_2.name],
+        group_names=cc_pair_2_group_name_expected,
         doc_creating_user=admin_user,
         verify_deleted=False,
     )
@@ -193,15 +207,19 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
         user_performing_action=admin_user,
     )
 
-    # validate user groups
-    UserGroupManager.verify(
-        user_group=user_group_1,
-        user_performing_action=admin_user,
-    )
-    UserGroupManager.verify(
-        user_group=user_group_2,
-        user_performing_action=admin_user,
-    )
+    if is_ee:
+        user_group_1.cc_pair_ids = []
+        user_group_2.cc_pair_ids = [cc_pair_2.id]
+
+        # validate user groups
+        UserGroupManager.verify(
+            user_group=user_group_1,
+            user_performing_action=admin_user,
+        )
+        UserGroupManager.verify(
+            user_group=user_group_2,
+            user_performing_action=admin_user,
+        )
 
 
 def test_connector_deletion_for_overlapping_connectors(
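The test changes above gate every user-group step on the `ENABLE_PAID_ENTERPRISE_EDITION_FEATURES` environment variable so the same test passes on both the EE and MIT builds; since the MIT version wipes groups on update, the MIT expectations collapse to empty group lists. A compact sketch of that gating pattern, with hypothetical helper names (`is_ee_enabled`, `expected_groups`):

```python
# Hypothetical sketch of the EE/MIT gating used by the test changes above.
import os


def is_ee_enabled() -> bool:
    # Same check the test performs: the EE path only runs when the env var is
    # literally "true" (case-insensitive).
    return (
        os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
    )


def expected_groups(ee_group_names: list[str]) -> list[str]:
    # EE keeps user-group ACLs; the MIT build wipes them, so the expected
    # value collapses to an empty list.
    return ee_group_names if is_ee_enabled() else []


if __name__ == "__main__":
    os.environ["ENABLE_PAID_ENTERPRISE_EDITION_FEATURES"] = "true"
    print(expected_groups(["user_group_2"]))  # ['user_group_2']
    os.environ["ENABLE_PAID_ENTERPRISE_EDITION_FEATURES"] = "false"
    print(expected_groups(["user_group_2"]))  # []
```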
""" + user_group_1: DATestUserGroup + user_group_2: DATestUserGroup + + is_ee = ( + os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true" + ) + # Creating an admin user (first user created is automatically an admin) admin_user: DATestUser = UserManager.create(name="admin_user") # create api key @@ -281,47 +306,48 @@ def test_connector_deletion_for_overlapping_connectors( doc_creating_user=admin_user, ) - # create a user group and attach it to connector 1 - user_group_1: DATestUserGroup = UserGroupManager.create( - name="Test User Group 1", - cc_pair_ids=[cc_pair_1.id], - user_performing_action=admin_user, - ) - UserGroupManager.wait_for_sync( - user_groups_to_check=[user_group_1], - user_performing_action=admin_user, - ) - cc_pair_1.groups = [user_group_1.id] + if is_ee: + # create a user group and attach it to connector 1 + user_group_1 = UserGroupManager.create( + name="Test User Group 1", + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], + user_performing_action=admin_user, + ) + cc_pair_1.groups = [user_group_1.id] - print("User group 1 created and synced") + print("User group 1 created and synced") - # create a user group and attach it to connector 2 - user_group_2: DATestUserGroup = UserGroupManager.create( - name="Test User Group 2", - cc_pair_ids=[cc_pair_2.id], - user_performing_action=admin_user, - ) - UserGroupManager.wait_for_sync( - user_groups_to_check=[user_group_2], - user_performing_action=admin_user, - ) - cc_pair_2.groups = [user_group_2.id] + # create a user group and attach it to connector 2 + user_group_2 = UserGroupManager.create( + name="Test User Group 2", + cc_pair_ids=[cc_pair_2.id], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_2], + user_performing_action=admin_user, + ) + cc_pair_2.groups = [user_group_2.id] - print("User group 2 created and synced") + print("User group 2 created and synced") - # verify vespa document is in the user group - DocumentManager.verify( - vespa_client=vespa_client, - cc_pair=cc_pair_1, - group_names=[user_group_1.name, user_group_2.name], - doc_creating_user=admin_user, - ) - DocumentManager.verify( - vespa_client=vespa_client, - cc_pair=cc_pair_2, - group_names=[user_group_1.name, user_group_2.name], - doc_creating_user=admin_user, - ) + # verify vespa document is in the user group + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + group_names=[user_group_1.name, user_group_2.name], + doc_creating_user=admin_user, + ) + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_2, + group_names=[user_group_1.name, user_group_2.name], + doc_creating_user=admin_user, + ) # delete connector 1 CCPairManager.pause_cc_pair( @@ -354,11 +380,15 @@ def test_connector_deletion_for_overlapping_connectors( # verify the document is not in any document sets # verify the document is only in user group 2 + group_names_expected = [] + if is_ee: + group_names_expected = [user_group_2.name] + DocumentManager.verify( vespa_client=vespa_client, cc_pair=cc_pair_2, doc_set_names=[], - group_names=[user_group_2.name], + group_names=group_names_expected, doc_creating_user=admin_user, verify_deleted=False, )