Fix group prefix (#11)

This commit is contained in:
Chris Weaver
2023-10-27 12:13:52 -07:00
parent d016e8335e
commit db8ce61ff4
7 changed files with 51 additions and 47 deletions

View File

@@ -1,11 +1,11 @@
from sqlalchemy.orm import Session
from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_groups
from danswer.access.access import (
get_access_for_documents as get_access_for_documents_without_groups,
_get_access_for_documents as get_access_for_documents_without_groups,
)
from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_groups
from danswer.access.models import DocumentAccess
from danswer.db.engine import get_sqlalchemy_engine
from danswer.access.utils import prefix_user_group
from danswer.db.models import User
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from ee.danswer.db.user_group import fetch_user_groups_for_documents
@@ -14,13 +14,13 @@ from ee.danswer.db.user_group import fetch_user_groups_for_user
def _get_access_for_documents(
document_ids: list[str],
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None,
db_session: Session,
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None,
) -> dict[str, DocumentAccess]:
access_dict = get_access_for_documents_without_groups(
non_ee_access_dict = get_access_for_documents_without_groups(
document_ids=document_ids,
cc_pair_to_delete=cc_pair_to_delete,
db_session=db_session,
cc_pair_to_delete=cc_pair_to_delete,
)
user_group_info = {
document_id: group_names
@@ -31,36 +31,16 @@ def _get_access_for_documents(
)
}
# overload user_ids a bit - use it for both actual User IDs + group IDs
return {
document_id: DocumentAccess(
user_ids=access.user_ids.union(user_group_info.get(document_id, [])), # type: ignore
is_public=access.is_public,
user_ids=non_ee_access.user_ids,
user_groups=user_group_info.get(document_id, []), # type: ignore
is_public=non_ee_access.is_public,
)
for document_id, access in access_dict.items()
for document_id, non_ee_access in non_ee_access_dict.items()
}
def get_access_for_documents(
document_ids: list[str],
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
db_session: Session | None = None,
) -> dict[str, DocumentAccess]:
if db_session is None:
with Session(get_sqlalchemy_engine()) as db_session:
return _get_access_for_documents(
document_ids, cc_pair_to_delete, db_session
)
return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session)
def prefix_user_group(user_group_name: str) -> str:
"""Prefixes a user group name to eliminate collision with user IDs.
This assumes that user ids are prefixed with a different prefix."""
return f"group:{user_group_name}"
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
"""Returns a list of ACL entries that the user has access to. This is meant to be
used downstream to filter out documents that the user does not have access to. The

View File

@@ -11,12 +11,16 @@ from danswer.db.tasks import mark_task_finished
from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.background.user_group_sync import name_user_group_sync_task
from ee.danswer.db.user_group import fetch_user_groups
from ee.danswer.user_groups.sync import sync_user_groups
logger = setup_logger()
# mark as EE for all tasks in this file
global_version.set_ee()
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_user_group_task(user_group_id: int) -> None:
@@ -25,9 +29,14 @@ def sync_user_group_task(user_group_id: int) -> None:
mark_task_start(task_name, db_session)
# actual sync logic
sync_user_groups(user_group_id=user_group_id, db_session=db_session)
error_msg = None
try:
sync_user_groups(user_group_id=user_group_id, db_session=db_session)
except Exception as e:
error_msg = str(e)
logger.exception(f"Failed to sync user group - {error_msg}")
mark_task_finished(task_name, db_session)
mark_task_finished(task_name, db_session, success=error_msg is None)
#####

View File

@@ -3,11 +3,11 @@ from sqlalchemy.orm import Session
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.access.access import get_access_for_documents
from danswer.db.document import prepare_to_modify_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
from ee.danswer.access.access import get_access_for_documents
from ee.danswer.db.user_group import delete_user_group
from ee.danswer.db.user_group import fetch_documents_for_user_group
from ee.danswer.db.user_group import fetch_user_group