Perm sync behavior change (#3262)

* Change external permissions behavior

* fixed behavior

* added error handling

* LLM the goat

* comment

* simplify

* fixed

* done

* limits increased

* added a ton of logging

* uhhhh
Authored by hagen-danswer on 2024-11-27 12:04:15 -08:00, committed by GitHub
parent 9c0cc94f15
commit 09d3e47c03
10 changed files with 170 additions and 24 deletions

View File

@@ -241,9 +241,11 @@ def connector_permission_sync_generator_task(
     doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
     if doc_sync_func is None:
-        raise ValueError(f"No doc sync func found for {source_type}")
+        raise ValueError(
+            f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
+        )
 
-    logger.info(f"Syncing docs for {source_type}")
+    logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
 
     payload = RedisConnectorPermissionSyncData(
         started=datetime.now(timezone.utc),

View File

@@ -49,7 +49,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
     if cc_pair.access_type != AccessType.SYNC:
         return False
 
-    # skip pruning if not active
+    # skip external group sync if not active
     if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
         return False

View File

@@ -51,7 +51,7 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
"restrictions.read.restrictions.group", "restrictions.read.restrictions.group",
] ]
_SLIM_DOC_BATCH_SIZE = 1000 _SLIM_DOC_BATCH_SIZE = 5000
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
@@ -301,5 +301,8 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
                         perm_sync_data=perm_sync_data,
                     )
                 )
-            yield doc_metadata_list
-            doc_metadata_list = []
+            if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
+                yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
+                doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
+
+        yield doc_metadata_list
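
A minimal standalone sketch of the capped-batch yield pattern the new connector code follows; the plain-list input and the function name are illustrative, not part of this commit:

    from collections.abc import Iterator

    _SLIM_DOC_BATCH_SIZE = 5000

    def yield_in_capped_batches(items: list[str]) -> Iterator[list[str]]:
        batch: list[str] = []
        for item in items:
            batch.append(item)
            # Once the buffer grows past the cap, emit a full batch and carry the rest over.
            if len(batch) > _SLIM_DOC_BATCH_SIZE:
                yield batch[:_SLIM_DOC_BATCH_SIZE]
                batch = batch[_SLIM_DOC_BATCH_SIZE:]
        # Emit whatever remains once the input is exhausted.
        yield batch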

View File

@@ -120,7 +120,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
     return cast(F, wrapped_call)
 
 
-_DEFAULT_PAGINATION_LIMIT = 100
+_DEFAULT_PAGINATION_LIMIT = 1000
 
 
 class OnyxConfluence(Confluence):

View File

@@ -324,8 +324,11 @@ def associate_default_cc_pair(db_session: Session) -> None:
 def _relate_groups_to_cc_pair__no_commit(
     db_session: Session,
     cc_pair_id: int,
-    user_group_ids: list[int],
+    user_group_ids: list[int] | None = None,
 ) -> None:
+    if not user_group_ids:
+        return
+
     for group_id in user_group_ids:
         db_session.add(
             UserGroup__ConnectorCredentialPair(
@@ -402,12 +405,11 @@ def add_credential_to_connector(
     db_session.flush()  # make sure the association has an id
     db_session.refresh(association)
 
-    if groups and access_type != AccessType.SYNC:
-        _relate_groups_to_cc_pair__no_commit(
-            db_session=db_session,
-            cc_pair_id=association.id,
-            user_group_ids=groups,
-        )
+    _relate_groups_to_cc_pair__no_commit(
+        db_session=db_session,
+        cc_pair_id=association.id,
+        user_group_ids=groups,
+    )
 
     db_session.commit()

View File

@@ -11,6 +11,7 @@ from sqlalchemy import update
 from sqlalchemy.orm import Session
 
 from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
+from danswer.db.enums import AccessType
 from danswer.db.enums import ConnectorCredentialPairStatus
 from danswer.db.models import ConnectorCredentialPair
 from danswer.db.models import Credential__UserGroup
@@ -298,6 +299,11 @@ def fetch_user_groups_for_documents(
     db_session: Session,
     document_ids: list[str],
 ) -> Sequence[tuple[str, list[str]]]:
+    """
+    Fetches all user groups that have access to the given documents.
+
+    NOTE: this doesn't include groups if the cc_pair is access type SYNC
+    """
     stmt = (
         select(Document.id, func.array_agg(UserGroup.name))
         .join(
@@ -306,7 +312,11 @@ def fetch_user_groups_for_documents(
         )
         .join(
             ConnectorCredentialPair,
-            ConnectorCredentialPair.id == UserGroup__ConnectorCredentialPair.cc_pair_id,
+            and_(
+                ConnectorCredentialPair.id
+                == UserGroup__ConnectorCredentialPair.cc_pair_id,
+                ConnectorCredentialPair.access_type != AccessType.SYNC,
+            ),
         )
         .join(
             DocumentByConnectorCredentialPair,
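
A hypothetical usage sketch for fetch_user_groups_for_documents as documented above; the wrapper function, import path, and document ids are illustrative assumptions, not part of this commit:

    from sqlalchemy.orm import Session

    def print_groups_for_docs(db_session: Session) -> None:
        # assumes fetch_user_groups_for_documents is imported from its danswer.db module
        doc_group_pairs = fetch_user_groups_for_documents(
            db_session=db_session,
            document_ids=["doc-1", "doc-2"],
        )
        for doc_id, group_names in doc_group_pairs:
            # groups attached only through SYNC-type cc_pairs are filtered out by the join above
            print(doc_id, group_names)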

View File

@@ -97,6 +97,7 @@ def _get_space_permissions(
     confluence_client: OnyxConfluence,
     is_cloud: bool,
 ) -> dict[str, ExternalAccess]:
+    logger.debug("Getting space permissions")
     # Gets all the spaces in the Confluence instance
     all_space_keys = []
     start = 0
@@ -113,6 +114,7 @@ def _get_space_permissions(
         start += len(spaces_batch.get("results", []))
 
     # Gets the permissions for each space
+    logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
     space_permissions_by_space_key: dict[str, ExternalAccess] = {}
     for space_key in all_space_keys:
         if is_cloud:
@@ -242,6 +244,7 @@ def _fetch_all_page_restrictions_for_space(
             logger.warning(f"No permissions found for document {slim_doc.id}")
 
+    logger.debug("Finished fetching all page restrictions for space")
     return document_restrictions
@@ -254,27 +257,28 @@ def confluence_doc_sync(
     it in postgres so that when it gets created later, the permissions are
     already populated
     """
+    logger.debug("Starting confluence doc sync")
     confluence_connector = ConfluenceConnector(
         **cc_pair.connector.connector_specific_config
     )
     confluence_connector.load_credentials(cc_pair.credential.credential_json)
-    if confluence_connector.confluence_client is None:
-        raise ValueError("Failed to load credentials")
-    confluence_client = confluence_connector.confluence_client
 
     is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
 
     space_permissions_by_space_key = _get_space_permissions(
-        confluence_client=confluence_client,
+        confluence_client=confluence_connector.confluence_client,
         is_cloud=is_cloud,
     )
 
     slim_docs = []
+    logger.debug("Fetching all slim documents from confluence")
     for doc_batch in confluence_connector.retrieve_all_slim_documents():
+        logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
         slim_docs.extend(doc_batch)
 
+    logger.debug("Fetching all page restrictions for space")
     return _fetch_all_page_restrictions_for_space(
-        confluence_client=confluence_client,
+        confluence_client=confluence_connector.confluence_client,
         slim_docs=slim_docs,
         space_permissions_by_space_key=space_permissions_by_space_key,
     )

View File

@@ -14,7 +14,10 @@ def _build_group_member_email_map(
 ) -> dict[str, set[str]]:
     group_member_emails: dict[str, set[str]] = {}
     for user_result in confluence_client.paginated_cql_user_retrieval():
-        user = user_result["user"]
+        user = user_result.get("user", {})
+        if not user:
+            logger.warning(f"user result missing user field: {user_result}")
+            continue
         email = user.get("email")
         if not email:
             # This field is only present in Confluence Server

View File

@@ -57,9 +57,9 @@ DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
 # If nothing is specified here, we run the doc_sync every time the celery beat runs
 EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
-    # Polling is not supported so we fetch all group permissions every 60 seconds
-    DocumentSource.GOOGLE_DRIVE: 60,
-    DocumentSource.CONFLUENCE: 60,
+    # Polling is not supported so we fetch all group permissions every 5 minutes
+    DocumentSource.GOOGLE_DRIVE: 5 * 60,
+    DocumentSource.CONFLUENCE: 5 * 60,
 }

View File

@@ -14,6 +14,7 @@ from tests.integration.common_utils.managers.document_search import (
 )
 from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
 from tests.integration.common_utils.managers.user import UserManager
+from tests.integration.common_utils.managers.user_group import UserGroupManager
 from tests.integration.common_utils.test_models import DATestCCPair
 from tests.integration.common_utils.test_models import DATestConnector
 from tests.integration.common_utils.test_models import DATestCredential
@@ -215,3 +216,124 @@ def test_slack_permission_sync(
     # Ensure test_user_1 can only see messages from the public channel
     assert public_message in danswer_doc_message_strings
     assert private_message not in danswer_doc_message_strings
+
+
+def test_slack_group_permission_sync(
+    reset: None,
+    vespa_client: vespa_fixture,
+    slack_test_setup: tuple[dict[str, Any], dict[str, Any]],
+) -> None:
+    """
+    This test ensures that permission sync overrides danswer group access.
+    """
+    public_channel, private_channel = slack_test_setup
+
+    # Creating an admin user (first user created is automatically an admin)
+    admin_user: DATestUser = UserManager.create(
+        email="admin@onyx-test.com",
+    )
+
+    # Creating a non-admin user
+    test_user_1: DATestUser = UserManager.create(
+        email="test_user_1@onyx-test.com",
+    )
+
+    # Create a user group and add the non-admin user to it
+    user_group = UserGroupManager.create(
+        name="test_group",
+        user_ids=[test_user_1.id],
+        cc_pair_ids=[],
+        user_performing_action=admin_user,
+    )
+    UserGroupManager.wait_for_sync(
+        user_groups_to_check=[user_group],
+        user_performing_action=admin_user,
+    )
+
+    slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
+    email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
+    admin_user_id = email_id_map[admin_user.email]
+
+    LLMProviderManager.create(user_performing_action=admin_user)
+
+    # Add only admin to the private channel
+    SlackManager.set_channel_members(
+        slack_client=slack_client,
+        admin_user_id=admin_user_id,
+        channel=private_channel,
+        user_ids=[admin_user_id],
+    )
+
+    before = datetime.now(timezone.utc)
+    credential = CredentialManager.create(
+        source=DocumentSource.SLACK,
+        credential_json={
+            "slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
+        },
+        user_performing_action=admin_user,
+    )
+
+    # Create connector with sync access and assign it to the user group
+    connector = ConnectorManager.create(
+        name="Slack",
+        input_type=InputType.POLL,
+        source=DocumentSource.SLACK,
+        connector_specific_config={
+            "workspace": "onyx-test-workspace",
+            "channels": [private_channel["name"]],
+        },
+        access_type=AccessType.SYNC,
+        groups=[user_group.id],
+        user_performing_action=admin_user,
+    )
+
+    cc_pair = CCPairManager.create(
+        credential_id=credential.id,
+        connector_id=connector.id,
+        access_type=AccessType.SYNC,
+        user_performing_action=admin_user,
+        groups=[user_group.id],
+    )
+
+    # Add a test message to the private channel
+    private_message = "This is a secret message: 987654"
+    SlackManager.add_message_to_channel(
+        slack_client=slack_client,
+        channel=private_channel,
+        message=private_message,
+    )
+
+    # Run indexing
+    CCPairManager.run_once(cc_pair, admin_user)
+    CCPairManager.wait_for_indexing(
+        cc_pair=cc_pair,
+        after=before,
+        user_performing_action=admin_user,
+    )
+
+    # Run permission sync
+    CCPairManager.sync(
+        cc_pair=cc_pair,
+        user_performing_action=admin_user,
+    )
+    CCPairManager.wait_for_sync(
+        cc_pair=cc_pair,
+        after=before,
+        number_of_updated_docs=1,
+        user_performing_action=admin_user,
+    )
+
+    # Verify admin can see the message
+    admin_docs = DocumentSearchManager.search_documents(
+        query="secret message",
+        user_performing_action=admin_user,
+    )
+    assert private_message in admin_docs
+
+    # Verify test_user_1 cannot see the message despite being in the group
+    # (Slack permissions should take precedence)
+    user_1_docs = DocumentSearchManager.search_documents(
+        query="secret message",
+        user_performing_action=test_user_1,
+    )
+    assert private_message not in user_1_docs