mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-08 11:58:34 +02:00
Fix group prefix (#11)
This commit is contained in:
parent
d016e8335e
commit
db8ce61ff4
@ -1,6 +1,7 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.access.utils import prefix_user
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
from danswer.db.document import get_acccess_info_for_documents
|
||||
from danswer.db.models import User
|
||||
@ -19,7 +20,7 @@ def _get_access_for_documents(
|
||||
cc_pair_to_delete=cc_pair_to_delete,
|
||||
)
|
||||
return {
|
||||
document_id: DocumentAccess.build(user_ids, is_public)
|
||||
document_id: DocumentAccess.build(user_ids, [], is_public)
|
||||
for document_id, user_ids, is_public in document_access_info
|
||||
}
|
||||
|
||||
@ -38,12 +39,6 @@ def get_access_for_documents(
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def prefix_user(user_id: str) -> str:
|
||||
"""Prefixes a user ID to eliminate collision with group names.
|
||||
This assumes that groups are prefixed with a different prefix."""
|
||||
return f"user_id:{user_id}"
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
|
@ -1,20 +1,30 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from danswer.access.utils import prefix_user
|
||||
from danswer.access.utils import prefix_user_group
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocumentAccess:
|
||||
user_ids: set[str] # stringified UUIDs
|
||||
user_groups: set[str] # names of user groups associated with this document
|
||||
is_public: bool
|
||||
|
||||
def to_acl(self) -> list[str]:
|
||||
return list(self.user_ids) + ([PUBLIC_DOC_PAT] if self.is_public else [])
|
||||
return (
|
||||
[prefix_user(user_id) for user_id in self.user_ids]
|
||||
+ [prefix_user_group(group_name) for group_name in self.user_groups]
|
||||
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build(cls, user_ids: list[UUID | None], is_public: bool) -> "DocumentAccess":
|
||||
def build(
|
||||
cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool
|
||||
) -> "DocumentAccess":
|
||||
return cls(
|
||||
user_ids={str(user_id) for user_id in user_ids if user_id},
|
||||
user_groups=set(user_groups),
|
||||
is_public=is_public,
|
||||
)
|
||||
|
10
backend/danswer/access/utils.py
Normal file
10
backend/danswer/access/utils.py
Normal file
@ -0,0 +1,10 @@
|
||||
def prefix_user(user_id: str) -> str:
|
||||
"""Prefixes a user ID to eliminate collision with group names.
|
||||
This assumes that groups are prefixed with a different prefix."""
|
||||
return f"user_id:{user_id}"
|
||||
|
||||
|
||||
def prefix_user_group(user_group_name: str) -> str:
|
||||
"""Prefixes a user group name to eliminate collision with user IDs.
|
||||
This assumes that user ids are prefixed with a different prefix."""
|
||||
return f"group:{user_group_name}"
|
@ -3,8 +3,7 @@ from threading import Thread
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.db.document import get_acccess_info_for_documents
|
||||
from danswer.access.access import get_access_for_documents
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import Document
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
@ -37,7 +36,7 @@ def set_acl_for_vespa(should_check_if_already_done: bool = False) -> None:
|
||||
# for all documents, set the `access_control_list` field appropriately
|
||||
# based on the state of Postgres
|
||||
documents = db_session.scalars(select(Document)).all()
|
||||
document_access_info = get_acccess_info_for_documents(
|
||||
document_access_dict = get_access_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=[document.id for document in documents],
|
||||
)
|
||||
@ -46,15 +45,16 @@ def set_acl_for_vespa(should_check_if_already_done: bool = False) -> None:
|
||||
vespa_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
if not isinstance(vespa_index, VespaIndex):
|
||||
raise ValueError("This script is only for Vespa indexes")
|
||||
|
||||
update_requests = [
|
||||
UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
access=DocumentAccess.build(user_ids, is_public),
|
||||
access=access,
|
||||
)
|
||||
for document_id, user_ids, is_public in document_access_info
|
||||
for document_id, access in document_access_dict.items()
|
||||
]
|
||||
vespa_index.update(update_requests=update_requests)
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_groups
|
||||
from danswer.access.access import (
|
||||
get_access_for_documents as get_access_for_documents_without_groups,
|
||||
_get_access_for_documents as get_access_for_documents_without_groups,
|
||||
)
|
||||
from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_groups
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.access.utils import prefix_user_group
|
||||
from danswer.db.models import User
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from ee.danswer.db.user_group import fetch_user_groups_for_documents
|
||||
@ -14,13 +14,13 @@ from ee.danswer.db.user_group import fetch_user_groups_for_user
|
||||
|
||||
def _get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None,
|
||||
db_session: Session,
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
access_dict = get_access_for_documents_without_groups(
|
||||
non_ee_access_dict = get_access_for_documents_without_groups(
|
||||
document_ids=document_ids,
|
||||
cc_pair_to_delete=cc_pair_to_delete,
|
||||
db_session=db_session,
|
||||
cc_pair_to_delete=cc_pair_to_delete,
|
||||
)
|
||||
user_group_info = {
|
||||
document_id: group_names
|
||||
@ -31,36 +31,16 @@ def _get_access_for_documents(
|
||||
)
|
||||
}
|
||||
|
||||
# overload user_ids a bit - use it for both actual User IDs + group IDs
|
||||
return {
|
||||
document_id: DocumentAccess(
|
||||
user_ids=access.user_ids.union(user_group_info.get(document_id, [])), # type: ignore
|
||||
is_public=access.is_public,
|
||||
user_ids=non_ee_access.user_ids,
|
||||
user_groups=user_group_info.get(document_id, []), # type: ignore
|
||||
is_public=non_ee_access.is_public,
|
||||
)
|
||||
for document_id, access in access_dict.items()
|
||||
for document_id, non_ee_access in non_ee_access_dict.items()
|
||||
}
|
||||
|
||||
|
||||
def get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
|
||||
db_session: Session | None = None,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
if db_session is None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
return _get_access_for_documents(
|
||||
document_ids, cc_pair_to_delete, db_session
|
||||
)
|
||||
|
||||
return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session)
|
||||
|
||||
|
||||
def prefix_user_group(user_group_name: str) -> str:
|
||||
"""Prefixes a user group name to eliminate collision with user IDs.
|
||||
This assumes that user ids are prefixed with a different prefix."""
|
||||
return f"group:{user_group_name}"
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
|
@ -11,12 +11,16 @@ from danswer.db.tasks import mark_task_finished
|
||||
from danswer.db.tasks import mark_task_start
|
||||
from danswer.db.tasks import register_task
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from ee.danswer.background.user_group_sync import name_user_group_sync_task
|
||||
from ee.danswer.db.user_group import fetch_user_groups
|
||||
from ee.danswer.user_groups.sync import sync_user_groups
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# mark as EE for all tasks in this file
|
||||
global_version.set_ee()
|
||||
|
||||
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_user_group_task(user_group_id: int) -> None:
|
||||
@ -25,9 +29,14 @@ def sync_user_group_task(user_group_id: int) -> None:
|
||||
mark_task_start(task_name, db_session)
|
||||
|
||||
# actual sync logic
|
||||
sync_user_groups(user_group_id=user_group_id, db_session=db_session)
|
||||
error_msg = None
|
||||
try:
|
||||
sync_user_groups(user_group_id=user_group_id, db_session=db_session)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.exception(f"Failed to sync user group - {error_msg}")
|
||||
|
||||
mark_task_finished(task_name, db_session)
|
||||
mark_task_finished(task_name, db_session, success=error_msg is None)
|
||||
|
||||
|
||||
#####
|
||||
|
@ -3,11 +3,11 @@ from sqlalchemy.orm import Session
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.access.access import get_access_for_documents
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.access.access import get_access_for_documents
|
||||
from ee.danswer.db.user_group import delete_user_group
|
||||
from ee.danswer.db.user_group import fetch_documents_for_user_group
|
||||
from ee.danswer.db.user_group import fetch_user_group
|
||||
|
Loading…
x
Reference in New Issue
Block a user