mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 20:24:32 +02:00
Perm sync behavior change (#3262)
* Change external permissions behavior * fixed behavior * added error handling * LLM the goat * comment * simplify * fixed * done * limits increased * added a ton of logging * uhhhh
This commit is contained in:
@@ -241,9 +241,11 @@ def connector_permission_sync_generator_task(
|
|||||||
|
|
||||||
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
|
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||||
if doc_sync_func is None:
|
if doc_sync_func is None:
|
||||||
raise ValueError(f"No doc sync func found for {source_type}")
|
raise ValueError(
|
||||||
|
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Syncing docs for {source_type}")
|
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
|
||||||
|
|
||||||
payload = RedisConnectorPermissionSyncData(
|
payload = RedisConnectorPermissionSyncData(
|
||||||
started=datetime.now(timezone.utc),
|
started=datetime.now(timezone.utc),
|
||||||
|
@@ -49,7 +49,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
|||||||
if cc_pair.access_type != AccessType.SYNC:
|
if cc_pair.access_type != AccessType.SYNC:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# skip pruning if not active
|
# skip external group sync if not active
|
||||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@@ -51,7 +51,7 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
|
|||||||
"restrictions.read.restrictions.group",
|
"restrictions.read.restrictions.group",
|
||||||
]
|
]
|
||||||
|
|
||||||
_SLIM_DOC_BATCH_SIZE = 1000
|
_SLIM_DOC_BATCH_SIZE = 5000
|
||||||
|
|
||||||
|
|
||||||
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||||
@@ -301,5 +301,8 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
|||||||
perm_sync_data=perm_sync_data,
|
perm_sync_data=perm_sync_data,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
yield doc_metadata_list
|
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
|
||||||
doc_metadata_list = []
|
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
|
||||||
|
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
|
||||||
|
|
||||||
|
yield doc_metadata_list
|
||||||
|
@@ -120,7 +120,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
|||||||
return cast(F, wrapped_call)
|
return cast(F, wrapped_call)
|
||||||
|
|
||||||
|
|
||||||
_DEFAULT_PAGINATION_LIMIT = 100
|
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||||
|
|
||||||
|
|
||||||
class OnyxConfluence(Confluence):
|
class OnyxConfluence(Confluence):
|
||||||
|
@@ -324,8 +324,11 @@ def associate_default_cc_pair(db_session: Session) -> None:
|
|||||||
def _relate_groups_to_cc_pair__no_commit(
|
def _relate_groups_to_cc_pair__no_commit(
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
cc_pair_id: int,
|
cc_pair_id: int,
|
||||||
user_group_ids: list[int],
|
user_group_ids: list[int] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if not user_group_ids:
|
||||||
|
return
|
||||||
|
|
||||||
for group_id in user_group_ids:
|
for group_id in user_group_ids:
|
||||||
db_session.add(
|
db_session.add(
|
||||||
UserGroup__ConnectorCredentialPair(
|
UserGroup__ConnectorCredentialPair(
|
||||||
@@ -402,12 +405,11 @@ def add_credential_to_connector(
|
|||||||
db_session.flush() # make sure the association has an id
|
db_session.flush() # make sure the association has an id
|
||||||
db_session.refresh(association)
|
db_session.refresh(association)
|
||||||
|
|
||||||
if groups and access_type != AccessType.SYNC:
|
_relate_groups_to_cc_pair__no_commit(
|
||||||
_relate_groups_to_cc_pair__no_commit(
|
db_session=db_session,
|
||||||
db_session=db_session,
|
cc_pair_id=association.id,
|
||||||
cc_pair_id=association.id,
|
user_group_ids=groups,
|
||||||
user_group_ids=groups,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
|
@@ -11,6 +11,7 @@ from sqlalchemy import update
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||||
|
from danswer.db.enums import AccessType
|
||||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||||
from danswer.db.models import ConnectorCredentialPair
|
from danswer.db.models import ConnectorCredentialPair
|
||||||
from danswer.db.models import Credential__UserGroup
|
from danswer.db.models import Credential__UserGroup
|
||||||
@@ -298,6 +299,11 @@ def fetch_user_groups_for_documents(
|
|||||||
db_session: Session,
|
db_session: Session,
|
||||||
document_ids: list[str],
|
document_ids: list[str],
|
||||||
) -> Sequence[tuple[str, list[str]]]:
|
) -> Sequence[tuple[str, list[str]]]:
|
||||||
|
"""
|
||||||
|
Fetches all user groups that have access to the given documents.
|
||||||
|
|
||||||
|
NOTE: this doesn't include groups if the cc_pair is access type SYNC
|
||||||
|
"""
|
||||||
stmt = (
|
stmt = (
|
||||||
select(Document.id, func.array_agg(UserGroup.name))
|
select(Document.id, func.array_agg(UserGroup.name))
|
||||||
.join(
|
.join(
|
||||||
@@ -306,7 +312,11 @@ def fetch_user_groups_for_documents(
|
|||||||
)
|
)
|
||||||
.join(
|
.join(
|
||||||
ConnectorCredentialPair,
|
ConnectorCredentialPair,
|
||||||
ConnectorCredentialPair.id == UserGroup__ConnectorCredentialPair.cc_pair_id,
|
and_(
|
||||||
|
ConnectorCredentialPair.id
|
||||||
|
== UserGroup__ConnectorCredentialPair.cc_pair_id,
|
||||||
|
ConnectorCredentialPair.access_type != AccessType.SYNC,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
.join(
|
.join(
|
||||||
DocumentByConnectorCredentialPair,
|
DocumentByConnectorCredentialPair,
|
||||||
|
@@ -97,6 +97,7 @@ def _get_space_permissions(
|
|||||||
confluence_client: OnyxConfluence,
|
confluence_client: OnyxConfluence,
|
||||||
is_cloud: bool,
|
is_cloud: bool,
|
||||||
) -> dict[str, ExternalAccess]:
|
) -> dict[str, ExternalAccess]:
|
||||||
|
logger.debug("Getting space permissions")
|
||||||
# Gets all the spaces in the Confluence instance
|
# Gets all the spaces in the Confluence instance
|
||||||
all_space_keys = []
|
all_space_keys = []
|
||||||
start = 0
|
start = 0
|
||||||
@@ -113,6 +114,7 @@ def _get_space_permissions(
|
|||||||
start += len(spaces_batch.get("results", []))
|
start += len(spaces_batch.get("results", []))
|
||||||
|
|
||||||
# Gets the permissions for each space
|
# Gets the permissions for each space
|
||||||
|
logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
|
||||||
space_permissions_by_space_key: dict[str, ExternalAccess] = {}
|
space_permissions_by_space_key: dict[str, ExternalAccess] = {}
|
||||||
for space_key in all_space_keys:
|
for space_key in all_space_keys:
|
||||||
if is_cloud:
|
if is_cloud:
|
||||||
@@ -242,6 +244,7 @@ def _fetch_all_page_restrictions_for_space(
|
|||||||
|
|
||||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||||
|
|
||||||
|
logger.debug("Finished fetching all page restrictions for space")
|
||||||
return document_restrictions
|
return document_restrictions
|
||||||
|
|
||||||
|
|
||||||
@@ -254,27 +257,28 @@ def confluence_doc_sync(
|
|||||||
it in postgres so that when it gets created later, the permissions are
|
it in postgres so that when it gets created later, the permissions are
|
||||||
already populated
|
already populated
|
||||||
"""
|
"""
|
||||||
|
logger.debug("Starting confluence doc sync")
|
||||||
confluence_connector = ConfluenceConnector(
|
confluence_connector = ConfluenceConnector(
|
||||||
**cc_pair.connector.connector_specific_config
|
**cc_pair.connector.connector_specific_config
|
||||||
)
|
)
|
||||||
confluence_connector.load_credentials(cc_pair.credential.credential_json)
|
confluence_connector.load_credentials(cc_pair.credential.credential_json)
|
||||||
if confluence_connector.confluence_client is None:
|
|
||||||
raise ValueError("Failed to load credentials")
|
|
||||||
confluence_client = confluence_connector.confluence_client
|
|
||||||
|
|
||||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||||
|
|
||||||
space_permissions_by_space_key = _get_space_permissions(
|
space_permissions_by_space_key = _get_space_permissions(
|
||||||
confluence_client=confluence_client,
|
confluence_client=confluence_connector.confluence_client,
|
||||||
is_cloud=is_cloud,
|
is_cloud=is_cloud,
|
||||||
)
|
)
|
||||||
|
|
||||||
slim_docs = []
|
slim_docs = []
|
||||||
|
logger.debug("Fetching all slim documents from confluence")
|
||||||
for doc_batch in confluence_connector.retrieve_all_slim_documents():
|
for doc_batch in confluence_connector.retrieve_all_slim_documents():
|
||||||
|
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
|
||||||
slim_docs.extend(doc_batch)
|
slim_docs.extend(doc_batch)
|
||||||
|
|
||||||
|
logger.debug("Fetching all page restrictions for space")
|
||||||
return _fetch_all_page_restrictions_for_space(
|
return _fetch_all_page_restrictions_for_space(
|
||||||
confluence_client=confluence_client,
|
confluence_client=confluence_connector.confluence_client,
|
||||||
slim_docs=slim_docs,
|
slim_docs=slim_docs,
|
||||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||||
)
|
)
|
||||||
|
@@ -14,7 +14,10 @@ def _build_group_member_email_map(
|
|||||||
) -> dict[str, set[str]]:
|
) -> dict[str, set[str]]:
|
||||||
group_member_emails: dict[str, set[str]] = {}
|
group_member_emails: dict[str, set[str]] = {}
|
||||||
for user_result in confluence_client.paginated_cql_user_retrieval():
|
for user_result in confluence_client.paginated_cql_user_retrieval():
|
||||||
user = user_result["user"]
|
user = user_result.get("user", {})
|
||||||
|
if not user:
|
||||||
|
logger.warning(f"user result missing user field: {user_result}")
|
||||||
|
continue
|
||||||
email = user.get("email")
|
email = user.get("email")
|
||||||
if not email:
|
if not email:
|
||||||
# This field is only present in Confluence Server
|
# This field is only present in Confluence Server
|
||||||
|
@@ -57,9 +57,9 @@ DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
|
|||||||
|
|
||||||
# If nothing is specified here, we run the doc_sync every time the celery beat runs
|
# If nothing is specified here, we run the doc_sync every time the celery beat runs
|
||||||
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
|
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||||
# Polling is not supported so we fetch all group permissions every 60 seconds
|
# Polling is not supported so we fetch all group permissions every 5 minutes
|
||||||
DocumentSource.GOOGLE_DRIVE: 60,
|
DocumentSource.GOOGLE_DRIVE: 5 * 60,
|
||||||
DocumentSource.CONFLUENCE: 60,
|
DocumentSource.CONFLUENCE: 5 * 60,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -14,6 +14,7 @@ from tests.integration.common_utils.managers.document_search import (
|
|||||||
)
|
)
|
||||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||||
from tests.integration.common_utils.managers.user import UserManager
|
from tests.integration.common_utils.managers.user import UserManager
|
||||||
|
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||||
from tests.integration.common_utils.test_models import DATestCCPair
|
from tests.integration.common_utils.test_models import DATestCCPair
|
||||||
from tests.integration.common_utils.test_models import DATestConnector
|
from tests.integration.common_utils.test_models import DATestConnector
|
||||||
from tests.integration.common_utils.test_models import DATestCredential
|
from tests.integration.common_utils.test_models import DATestCredential
|
||||||
@@ -215,3 +216,124 @@ def test_slack_permission_sync(
|
|||||||
# Ensure test_user_1 can only see messages from the public channel
|
# Ensure test_user_1 can only see messages from the public channel
|
||||||
assert public_message in danswer_doc_message_strings
|
assert public_message in danswer_doc_message_strings
|
||||||
assert private_message not in danswer_doc_message_strings
|
assert private_message not in danswer_doc_message_strings
|
||||||
|
|
||||||
|
|
||||||
|
def test_slack_group_permission_sync(
    reset: None,
    vespa_client: vespa_fixture,
    slack_test_setup: tuple[dict[str, Any], dict[str, Any]],
) -> None:
    """
    Checks that synced Slack permissions win over danswer group access.

    A non-admin user is put into a danswer user group, and that group is
    attached to a Slack connector / cc_pair created with AccessType.SYNC.
    Only the admin is a member of the private Slack channel, so after
    indexing and permission sync the admin must find the private message
    while the group member must not.
    """
    # Only the private channel is exercised by this test
    _, private_channel = slack_test_setup

    secret_text = "This is a secret message: 987654"
    search_query = "secret message"

    # The first user created is automatically an admin
    admin: DATestUser = UserManager.create(email="admin@onyx-test.com")

    # Second user is a regular (non-admin) user
    group_member: DATestUser = UserManager.create(
        email="test_user_1@onyx-test.com"
    )

    # Create a user group containing the non-admin user, with no cc_pairs yet
    group = UserGroupManager.create(
        name="test_group",
        user_ids=[group_member.id],
        cc_pair_ids=[],
        user_performing_action=admin,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[group],
        user_performing_action=admin,
    )

    bot_token = os.environ["SLACK_BOT_TOKEN"]
    client = SlackManager.get_slack_client(bot_token)
    slack_ids_by_email = SlackManager.build_slack_user_email_id_map(client)
    admin_slack_id = slack_ids_by_email[admin.email]

    LLMProviderManager.create(user_performing_action=admin)

    # Restrict the private channel to the admin only
    SlackManager.set_channel_members(
        slack_client=client,
        admin_user_id=admin_slack_id,
        channel=private_channel,
        user_ids=[admin_slack_id],
    )

    # Timestamp used below as the lower bound for the indexing / sync waits
    started_at = datetime.now(timezone.utc)

    slack_credential = CredentialManager.create(
        source=DocumentSource.SLACK,
        credential_json={"slack_bot_token": bot_token},
        user_performing_action=admin,
    )

    # Connector uses SYNC access and is also assigned to the user group
    connector_config = {
        "workspace": "onyx-test-workspace",
        "channels": [private_channel["name"]],
    }
    slack_connector = ConnectorManager.create(
        name="Slack",
        input_type=InputType.POLL,
        source=DocumentSource.SLACK,
        connector_specific_config=connector_config,
        access_type=AccessType.SYNC,
        groups=[group.id],
        user_performing_action=admin,
    )

    slack_cc_pair = CCPairManager.create(
        credential_id=slack_credential.id,
        connector_id=slack_connector.id,
        access_type=AccessType.SYNC,
        user_performing_action=admin,
        groups=[group.id],
    )

    # Post the secret message into the private channel
    SlackManager.add_message_to_channel(
        slack_client=client,
        channel=private_channel,
        message=secret_text,
    )

    # Index the channel
    CCPairManager.run_once(slack_cc_pair, admin)
    CCPairManager.wait_for_indexing(
        cc_pair=slack_cc_pair,
        after=started_at,
        user_performing_action=admin,
    )

    # Trigger the permission sync and wait for the single doc to be updated
    CCPairManager.sync(
        cc_pair=slack_cc_pair,
        user_performing_action=admin,
    )
    CCPairManager.wait_for_sync(
        cc_pair=slack_cc_pair,
        after=started_at,
        number_of_updated_docs=1,
        user_performing_action=admin,
    )

    # The admin is in the private channel, so the message must be visible
    admin_results = DocumentSearchManager.search_documents(
        query=search_query,
        user_performing_action=admin,
    )
    assert secret_text in admin_results

    # The group member is not in the channel: Slack permissions should take
    # precedence over the danswer group assignment
    member_results = DocumentSearchManager.search_documents(
        query=search_query,
        user_performing_action=group_member,
    )
    assert secret_text not in member_results
|
||||||
|
Reference in New Issue
Block a user