mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 12:03:54 +02:00
Fix group prefix (#11)
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_groups
|
||||
from danswer.access.access import (
|
||||
get_access_for_documents as get_access_for_documents_without_groups,
|
||||
_get_access_for_documents as get_access_for_documents_without_groups,
|
||||
)
|
||||
from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_groups
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.access.utils import prefix_user_group
|
||||
from danswer.db.models import User
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from ee.danswer.db.user_group import fetch_user_groups_for_documents
|
||||
@@ -14,13 +14,13 @@ from ee.danswer.db.user_group import fetch_user_groups_for_user
|
||||
|
||||
def _get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None,
|
||||
db_session: Session,
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
access_dict = get_access_for_documents_without_groups(
|
||||
non_ee_access_dict = get_access_for_documents_without_groups(
|
||||
document_ids=document_ids,
|
||||
cc_pair_to_delete=cc_pair_to_delete,
|
||||
db_session=db_session,
|
||||
cc_pair_to_delete=cc_pair_to_delete,
|
||||
)
|
||||
user_group_info = {
|
||||
document_id: group_names
|
||||
@@ -31,36 +31,16 @@ def _get_access_for_documents(
|
||||
)
|
||||
}
|
||||
|
||||
# overload user_ids a bit - use it for both actual User IDs + group IDs
|
||||
return {
|
||||
document_id: DocumentAccess(
|
||||
user_ids=access.user_ids.union(user_group_info.get(document_id, [])), # type: ignore
|
||||
is_public=access.is_public,
|
||||
user_ids=non_ee_access.user_ids,
|
||||
user_groups=user_group_info.get(document_id, []), # type: ignore
|
||||
is_public=non_ee_access.is_public,
|
||||
)
|
||||
for document_id, access in access_dict.items()
|
||||
for document_id, non_ee_access in non_ee_access_dict.items()
|
||||
}
|
||||
|
||||
|
||||
def get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
|
||||
db_session: Session | None = None,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
if db_session is None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
return _get_access_for_documents(
|
||||
document_ids, cc_pair_to_delete, db_session
|
||||
)
|
||||
|
||||
return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session)
|
||||
|
||||
|
||||
def prefix_user_group(user_group_name: str) -> str:
|
||||
"""Prefixes a user group name to eliminate collision with user IDs.
|
||||
This assumes that user ids are prefixed with a different prefix."""
|
||||
return f"group:{user_group_name}"
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
|
@@ -11,12 +11,16 @@ from danswer.db.tasks import mark_task_finished
|
||||
from danswer.db.tasks import mark_task_start
|
||||
from danswer.db.tasks import register_task
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from ee.danswer.background.user_group_sync import name_user_group_sync_task
|
||||
from ee.danswer.db.user_group import fetch_user_groups
|
||||
from ee.danswer.user_groups.sync import sync_user_groups
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# mark as EE for all tasks in this file
|
||||
global_version.set_ee()
|
||||
|
||||
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_user_group_task(user_group_id: int) -> None:
|
||||
@@ -25,9 +29,14 @@ def sync_user_group_task(user_group_id: int) -> None:
|
||||
mark_task_start(task_name, db_session)
|
||||
|
||||
# actual sync logic
|
||||
sync_user_groups(user_group_id=user_group_id, db_session=db_session)
|
||||
error_msg = None
|
||||
try:
|
||||
sync_user_groups(user_group_id=user_group_id, db_session=db_session)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.exception(f"Failed to sync user group - {error_msg}")
|
||||
|
||||
mark_task_finished(task_name, db_session)
|
||||
mark_task_finished(task_name, db_session, success=error_msg is None)
|
||||
|
||||
|
||||
#####
|
||||
|
@@ -3,11 +3,11 @@ from sqlalchemy.orm import Session
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.access.access import get_access_for_documents
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.access.access import get_access_for_documents
|
||||
from ee.danswer.db.user_group import delete_user_group
|
||||
from ee.danswer.db.user_group import fetch_documents_for_user_group
|
||||
from ee.danswer.db.user_group import fetch_user_group
|
||||
|
Reference in New Issue
Block a user