Perm sync behavior change (#3262)

* Change external permissions behavior

* fixed behavior

* added error handling

* LLM the goat

* comment

* simplify

* fixed

* done

* limits increased

* added a ton of logging

* uhhhh
This commit is contained in:
hagen-danswer
2024-11-27 12:04:15 -08:00
committed by GitHub
parent 9c0cc94f15
commit 09d3e47c03
10 changed files with 170 additions and 24 deletions

View File

@ -241,9 +241,11 @@ def connector_permission_sync_generator_task(
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
if doc_sync_func is None:
raise ValueError(f"No doc sync func found for {source_type}")
raise ValueError(
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
)
logger.info(f"Syncing docs for {source_type}")
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
payload = RedisConnectorPermissionSyncData(
started=datetime.now(timezone.utc),

View File

@ -49,7 +49,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
if cc_pair.access_type != AccessType.SYNC:
return False
# skip pruning if not active
# skip external group sync if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False

View File

@ -51,7 +51,7 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
"restrictions.read.restrictions.group",
]
_SLIM_DOC_BATCH_SIZE = 1000
_SLIM_DOC_BATCH_SIZE = 5000
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
@ -301,5 +301,8 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
perm_sync_data=perm_sync_data,
)
)
yield doc_metadata_list
doc_metadata_list = []
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
yield doc_metadata_list

View File

@ -120,7 +120,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return cast(F, wrapped_call)
_DEFAULT_PAGINATION_LIMIT = 100
_DEFAULT_PAGINATION_LIMIT = 1000
class OnyxConfluence(Confluence):

View File

@ -324,8 +324,11 @@ def associate_default_cc_pair(db_session: Session) -> None:
def _relate_groups_to_cc_pair__no_commit(
db_session: Session,
cc_pair_id: int,
user_group_ids: list[int],
user_group_ids: list[int] | None = None,
) -> None:
if not user_group_ids:
return
for group_id in user_group_ids:
db_session.add(
UserGroup__ConnectorCredentialPair(
@ -402,12 +405,11 @@ def add_credential_to_connector(
db_session.flush() # make sure the association has an id
db_session.refresh(association)
if groups and access_type != AccessType.SYNC:
_relate_groups_to_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
user_group_ids=groups,
)
_relate_groups_to_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
user_group_ids=groups,
)
db_session.commit()

View File

@ -11,6 +11,7 @@ from sqlalchemy import update
from sqlalchemy.orm import Session
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential__UserGroup
@ -298,6 +299,11 @@ def fetch_user_groups_for_documents(
db_session: Session,
document_ids: list[str],
) -> Sequence[tuple[str, list[str]]]:
"""
Fetches all user groups that have access to the given documents.
NOTE: this doesn't include groups if the cc_pair is access type SYNC
"""
stmt = (
select(Document.id, func.array_agg(UserGroup.name))
.join(
@ -306,7 +312,11 @@ def fetch_user_groups_for_documents(
)
.join(
ConnectorCredentialPair,
ConnectorCredentialPair.id == UserGroup__ConnectorCredentialPair.cc_pair_id,
and_(
ConnectorCredentialPair.id
== UserGroup__ConnectorCredentialPair.cc_pair_id,
ConnectorCredentialPair.access_type != AccessType.SYNC,
),
)
.join(
DocumentByConnectorCredentialPair,

View File

@ -97,6 +97,7 @@ def _get_space_permissions(
confluence_client: OnyxConfluence,
is_cloud: bool,
) -> dict[str, ExternalAccess]:
logger.debug("Getting space permissions")
# Gets all the spaces in the Confluence instance
all_space_keys = []
start = 0
@ -113,6 +114,7 @@ def _get_space_permissions(
start += len(spaces_batch.get("results", []))
# Gets the permissions for each space
logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
space_permissions_by_space_key: dict[str, ExternalAccess] = {}
for space_key in all_space_keys:
if is_cloud:
@ -242,6 +244,7 @@ def _fetch_all_page_restrictions_for_space(
logger.warning(f"No permissions found for document {slim_doc.id}")
logger.debug("Finished fetching all page restrictions for space")
return document_restrictions
@ -254,27 +257,28 @@ def confluence_doc_sync(
it in postgres so that when it gets created later, the permissions are
already populated
"""
logger.debug("Starting confluence doc sync")
confluence_connector = ConfluenceConnector(
**cc_pair.connector.connector_specific_config
)
confluence_connector.load_credentials(cc_pair.credential.credential_json)
if confluence_connector.confluence_client is None:
raise ValueError("Failed to load credentials")
confluence_client = confluence_connector.confluence_client
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
space_permissions_by_space_key = _get_space_permissions(
confluence_client=confluence_client,
confluence_client=confluence_connector.confluence_client,
is_cloud=is_cloud,
)
slim_docs = []
logger.debug("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents():
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
slim_docs.extend(doc_batch)
logger.debug("Fetching all page restrictions for space")
return _fetch_all_page_restrictions_for_space(
confluence_client=confluence_client,
confluence_client=confluence_connector.confluence_client,
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,
)

View File

@ -14,7 +14,10 @@ def _build_group_member_email_map(
) -> dict[str, set[str]]:
group_member_emails: dict[str, set[str]] = {}
for user_result in confluence_client.paginated_cql_user_retrieval():
user = user_result["user"]
user = user_result.get("user", {})
if not user:
logger.warning(f"user result missing user field: {user_result}")
continue
email = user.get("email")
if not email:
# This field is only present in Confluence Server

View File

@ -57,9 +57,9 @@ DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# If nothing is specified here, we run the doc_sync every time the celery beat runs
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all group permissions every 60 seconds
DocumentSource.GOOGLE_DRIVE: 60,
DocumentSource.CONFLUENCE: 60,
# Polling is not supported so we fetch all group permissions every 5 minutes
DocumentSource.GOOGLE_DRIVE: 5 * 60,
DocumentSource.CONFLUENCE: 5 * 60,
}

View File

@ -14,6 +14,7 @@ from tests.integration.common_utils.managers.document_search import (
)
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestConnector
from tests.integration.common_utils.test_models import DATestCredential
@ -215,3 +216,124 @@ def test_slack_permission_sync(
# Ensure test_user_1 can only see messages from the public channel
assert public_message in danswer_doc_message_strings
assert private_message not in danswer_doc_message_strings
def test_slack_group_permission_sync(
reset: None,
vespa_client: vespa_fixture,
slack_test_setup: tuple[dict[str, Any], dict[str, Any]],
) -> None:
"""
This test ensures that permission sync overrides danswer group access.
"""
public_channel, private_channel = slack_test_setup
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(
email="admin@onyx-test.com",
)
# Creating a non-admin user
test_user_1: DATestUser = UserManager.create(
email="test_user_1@onyx-test.com",
)
# Create a user group and adding the non-admin user to it
user_group = UserGroupManager.create(
name="test_group",
user_ids=[test_user_1.id],
cc_pair_ids=[],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group],
user_performing_action=admin_user,
)
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
admin_user_id = email_id_map[admin_user.email]
LLMProviderManager.create(user_performing_action=admin_user)
# Add only admin to the private channel
SlackManager.set_channel_members(
slack_client=slack_client,
admin_user_id=admin_user_id,
channel=private_channel,
user_ids=[admin_user_id],
)
before = datetime.now(timezone.utc)
credential = CredentialManager.create(
source=DocumentSource.SLACK,
credential_json={
"slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
},
user_performing_action=admin_user,
)
# Create connector with sync access and assign it to the user group
connector = ConnectorManager.create(
name="Slack",
input_type=InputType.POLL,
source=DocumentSource.SLACK,
connector_specific_config={
"workspace": "onyx-test-workspace",
"channels": [private_channel["name"]],
},
access_type=AccessType.SYNC,
groups=[user_group.id],
user_performing_action=admin_user,
)
cc_pair = CCPairManager.create(
credential_id=credential.id,
connector_id=connector.id,
access_type=AccessType.SYNC,
user_performing_action=admin_user,
groups=[user_group.id],
)
# Add a test message to the private channel
private_message = "This is a secret message: 987654"
SlackManager.add_message_to_channel(
slack_client=slack_client,
channel=private_channel,
message=private_message,
)
# Run indexing
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.wait_for_indexing(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
)
# Run permission sync
CCPairManager.sync(
cc_pair=cc_pair,
user_performing_action=admin_user,
)
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
number_of_updated_docs=1,
user_performing_action=admin_user,
)
# Verify admin can see the message
admin_docs = DocumentSearchManager.search_documents(
query="secret message",
user_performing_action=admin_user,
)
assert private_message in admin_docs
# Verify test_user_1 cannot see the message despite being in the group
# (Slack permissions should take precedence)
user_1_docs = DocumentSearchManager.search_documents(
query="secret message",
user_performing_action=test_user_1,
)
assert private_message not in user_1_docs