mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-01 18:20:49 +02:00
Perm sync behavior change (#3262)
* Change external permissions behavior * fixed behavior * added error handling * LLM the goat * comment * simplify * fixed * done * limits increased * added a ton of logging * uhhhh
This commit is contained in:
@ -241,9 +241,11 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if doc_sync_func is None:
|
||||
raise ValueError(f"No doc sync func found for {source_type}")
|
||||
raise ValueError(
|
||||
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
logger.info(f"Syncing docs for {source_type}")
|
||||
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
|
||||
|
||||
payload = RedisConnectorPermissionSyncData(
|
||||
started=datetime.now(timezone.utc),
|
||||
|
@ -49,7 +49,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
# skip pruning if not active
|
||||
# skip external group sync if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
|
@ -51,7 +51,7 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
|
||||
"restrictions.read.restrictions.group",
|
||||
]
|
||||
|
||||
_SLIM_DOC_BATCH_SIZE = 1000
|
||||
_SLIM_DOC_BATCH_SIZE = 5000
|
||||
|
||||
|
||||
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
@ -301,5 +301,8 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
perm_sync_data=perm_sync_data,
|
||||
)
|
||||
)
|
||||
yield doc_metadata_list
|
||||
doc_metadata_list = []
|
||||
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
|
||||
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
|
||||
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
|
||||
|
||||
yield doc_metadata_list
|
||||
|
@ -120,7 +120,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
_DEFAULT_PAGINATION_LIMIT = 100
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
|
||||
|
||||
class OnyxConfluence(Confluence):
|
||||
|
@ -324,8 +324,11 @@ def associate_default_cc_pair(db_session: Session) -> None:
|
||||
def _relate_groups_to_cc_pair__no_commit(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
user_group_ids: list[int],
|
||||
user_group_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
if not user_group_ids:
|
||||
return
|
||||
|
||||
for group_id in user_group_ids:
|
||||
db_session.add(
|
||||
UserGroup__ConnectorCredentialPair(
|
||||
@ -402,12 +405,11 @@ def add_credential_to_connector(
|
||||
db_session.flush() # make sure the association has an id
|
||||
db_session.refresh(association)
|
||||
|
||||
if groups and access_type != AccessType.SYNC:
|
||||
_relate_groups_to_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=association.id,
|
||||
user_group_ids=groups,
|
||||
)
|
||||
_relate_groups_to_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=association.id,
|
||||
user_group_ids=groups,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
@ -11,6 +11,7 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential__UserGroup
|
||||
@ -298,6 +299,11 @@ def fetch_user_groups_for_documents(
|
||||
db_session: Session,
|
||||
document_ids: list[str],
|
||||
) -> Sequence[tuple[str, list[str]]]:
|
||||
"""
|
||||
Fetches all user groups that have access to the given documents.
|
||||
|
||||
NOTE: this doesn't include groups if the cc_pair is access type SYNC
|
||||
"""
|
||||
stmt = (
|
||||
select(Document.id, func.array_agg(UserGroup.name))
|
||||
.join(
|
||||
@ -306,7 +312,11 @@ def fetch_user_groups_for_documents(
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
ConnectorCredentialPair.id == UserGroup__ConnectorCredentialPair.cc_pair_id,
|
||||
and_(
|
||||
ConnectorCredentialPair.id
|
||||
== UserGroup__ConnectorCredentialPair.cc_pair_id,
|
||||
ConnectorCredentialPair.access_type != AccessType.SYNC,
|
||||
),
|
||||
)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
|
@ -97,6 +97,7 @@ def _get_space_permissions(
|
||||
confluence_client: OnyxConfluence,
|
||||
is_cloud: bool,
|
||||
) -> dict[str, ExternalAccess]:
|
||||
logger.debug("Getting space permissions")
|
||||
# Gets all the spaces in the Confluence instance
|
||||
all_space_keys = []
|
||||
start = 0
|
||||
@ -113,6 +114,7 @@ def _get_space_permissions(
|
||||
start += len(spaces_batch.get("results", []))
|
||||
|
||||
# Gets the permissions for each space
|
||||
logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess] = {}
|
||||
for space_key in all_space_keys:
|
||||
if is_cloud:
|
||||
@ -242,6 +244,7 @@ def _fetch_all_page_restrictions_for_space(
|
||||
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
|
||||
logger.debug("Finished fetching all page restrictions for space")
|
||||
return document_restrictions
|
||||
|
||||
|
||||
@ -254,27 +257,28 @@ def confluence_doc_sync(
|
||||
it in postgres so that when it gets created later, the permissions are
|
||||
already populated
|
||||
"""
|
||||
logger.debug("Starting confluence doc sync")
|
||||
confluence_connector = ConfluenceConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
confluence_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
if confluence_connector.confluence_client is None:
|
||||
raise ValueError("Failed to load credentials")
|
||||
confluence_client = confluence_connector.confluence_client
|
||||
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
|
||||
space_permissions_by_space_key = _get_space_permissions(
|
||||
confluence_client=confluence_client,
|
||||
confluence_client=confluence_connector.confluence_client,
|
||||
is_cloud=is_cloud,
|
||||
)
|
||||
|
||||
slim_docs = []
|
||||
logger.debug("Fetching all slim documents from confluence")
|
||||
for doc_batch in confluence_connector.retrieve_all_slim_documents():
|
||||
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
|
||||
slim_docs.extend(doc_batch)
|
||||
|
||||
logger.debug("Fetching all page restrictions for space")
|
||||
return _fetch_all_page_restrictions_for_space(
|
||||
confluence_client=confluence_client,
|
||||
confluence_client=confluence_connector.confluence_client,
|
||||
slim_docs=slim_docs,
|
||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||
)
|
||||
|
@ -14,7 +14,10 @@ def _build_group_member_email_map(
|
||||
) -> dict[str, set[str]]:
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
for user_result in confluence_client.paginated_cql_user_retrieval():
|
||||
user = user_result["user"]
|
||||
user = user_result.get("user", {})
|
||||
if not user:
|
||||
logger.warning(f"user result missing user field: {user_result}")
|
||||
continue
|
||||
email = user.get("email")
|
||||
if not email:
|
||||
# This field is only present in Confluence Server
|
||||
|
@ -57,9 +57,9 @@ DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
|
||||
# If nothing is specified here, we run the external group sync every time the celery beat runs
|
||||
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
# Polling is not supported so we fetch all group permissions every 60 seconds
|
||||
DocumentSource.GOOGLE_DRIVE: 60,
|
||||
DocumentSource.CONFLUENCE: 60,
|
||||
# Polling is not supported so we fetch all group permissions every 5 minutes
|
||||
DocumentSource.GOOGLE_DRIVE: 5 * 60,
|
||||
DocumentSource.CONFLUENCE: 5 * 60,
|
||||
}
|
||||
|
||||
|
||||
|
@ -14,6 +14,7 @@ from tests.integration.common_utils.managers.document_search import (
|
||||
)
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestConnector
|
||||
from tests.integration.common_utils.test_models import DATestCredential
|
||||
@ -215,3 +216,124 @@ def test_slack_permission_sync(
|
||||
# Ensure test_user_1 can only see messages from the public channel
|
||||
assert public_message in danswer_doc_message_strings
|
||||
assert private_message not in danswer_doc_message_strings
|
||||
|
||||
|
||||
def test_slack_group_permission_sync(
    reset: None,
    vespa_client: vespa_fixture,
    slack_test_setup: tuple[dict[str, Any], dict[str, Any]],
) -> None:
    """
    Check that synced Slack permissions win over danswer user-group access.

    A non-admin user is placed in a danswer user group that is attached to a
    SYNC-access Slack cc_pair. Only the admin is a member of the private Slack
    channel, so after indexing and permission sync the admin must find the
    private message while the group member must not.
    """
    # Only the private channel is used by this test
    _, private_channel = slack_test_setup

    # The first user created is automatically an admin
    admin: DATestUser = UserManager.create(
        email="admin@onyx-test.com",
    )

    # Second user is a regular (non-admin) user
    group_member: DATestUser = UserManager.create(
        email="test_user_1@onyx-test.com",
    )

    # Put the non-admin user into a fresh user group and wait for it to sync
    group = UserGroupManager.create(
        name="test_group",
        user_ids=[group_member.id],
        cc_pair_ids=[],
        user_performing_action=admin,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[group],
        user_performing_action=admin,
    )

    # Resolve the admin's Slack user id from their email
    slack = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
    slack_ids_by_email = SlackManager.build_slack_user_email_id_map(slack)
    admin_slack_id = slack_ids_by_email[admin.email]

    LLMProviderManager.create(user_performing_action=admin)

    # The private channel contains the admin and nobody else
    SlackManager.set_channel_members(
        slack_client=slack,
        admin_user_id=admin_slack_id,
        channel=private_channel,
        user_ids=[admin_slack_id],
    )

    # Timestamp used as the lower bound when waiting for indexing / sync
    started_at = datetime.now(timezone.utc)

    slack_credentials = {
        "slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
    }
    slack_credential = CredentialManager.create(
        source=DocumentSource.SLACK,
        credential_json=slack_credentials,
        user_performing_action=admin,
    )

    # Connector uses SYNC access and is also assigned to the user group
    slack_connector = ConnectorManager.create(
        name="Slack",
        input_type=InputType.POLL,
        source=DocumentSource.SLACK,
        connector_specific_config={
            "workspace": "onyx-test-workspace",
            "channels": [private_channel["name"]],
        },
        access_type=AccessType.SYNC,
        groups=[group.id],
        user_performing_action=admin,
    )

    slack_cc_pair = CCPairManager.create(
        credential_id=slack_credential.id,
        connector_id=slack_connector.id,
        access_type=AccessType.SYNC,
        user_performing_action=admin,
        groups=[group.id],
    )

    # Post the message whose visibility is being tested
    private_message = "This is a secret message: 987654"
    SlackManager.add_message_to_channel(
        slack_client=slack,
        channel=private_channel,
        message=private_message,
    )

    # Index the channel
    CCPairManager.run_once(slack_cc_pair, admin)
    CCPairManager.wait_for_indexing(
        cc_pair=slack_cc_pair,
        after=started_at,
        user_performing_action=admin,
    )

    # Sync permissions from Slack
    CCPairManager.sync(
        cc_pair=slack_cc_pair,
        user_performing_action=admin,
    )
    CCPairManager.wait_for_sync(
        cc_pair=slack_cc_pair,
        after=started_at,
        number_of_updated_docs=1,
        user_performing_action=admin,
    )

    # The admin is in the private channel, so the message must be returned
    docs_seen_by_admin = DocumentSearchManager.search_documents(
        query="secret message",
        user_performing_action=admin,
    )
    assert private_message in docs_seen_by_admin

    # The group member is not in the private channel; group membership alone
    # must not grant access (Slack permissions should take precedence)
    docs_seen_by_member = DocumentSearchManager.search_documents(
        query="secret message",
        user_performing_action=group_member,
    )
    assert private_message not in docs_seen_by_member
|
||||
|
Reference in New Issue
Block a user