Fix group prefix (#11)

This commit is contained in:
Chris Weaver 2023-10-27 12:13:52 -07:00
parent d016e8335e
commit db8ce61ff4
7 changed files with 51 additions and 47 deletions

View File

@ -1,6 +1,7 @@
from sqlalchemy.orm import Session
from danswer.access.models import DocumentAccess
from danswer.access.utils import prefix_user
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.models import User
@ -19,7 +20,7 @@ def _get_access_for_documents(
cc_pair_to_delete=cc_pair_to_delete,
)
return {
document_id: DocumentAccess.build(user_ids, is_public)
document_id: DocumentAccess.build(user_ids, [], is_public)
for document_id, user_ids, is_public in document_access_info
}
@ -38,12 +39,6 @@ def get_access_for_documents(
) # type: ignore
def prefix_user(user_id: str) -> str:
    """Return `user_id` tagged with the `user_id:` prefix.

    The tag keeps user IDs from colliding with group names; this relies on
    groups being tagged with a different prefix.
    """
    tag = "user_id:"
    return f"{tag}{user_id}"
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
"""Returns a list of ACL entries that the user has access to. This is meant to be
used downstream to filter out documents that the user does not have access to. The

View File

@ -1,20 +1,30 @@
from dataclasses import dataclass
from uuid import UUID
from danswer.access.utils import prefix_user
from danswer.access.utils import prefix_user_group
from danswer.configs.constants import PUBLIC_DOC_PAT
@dataclass(frozen=True)
class DocumentAccess:
user_ids: set[str] # stringified UUIDs
user_groups: set[str] # names of user groups associated with this document
is_public: bool
def to_acl(self) -> list[str]:
return list(self.user_ids) + ([PUBLIC_DOC_PAT] if self.is_public else [])
return (
[prefix_user(user_id) for user_id in self.user_ids]
+ [prefix_user_group(group_name) for group_name in self.user_groups]
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
)
@classmethod
def build(cls, user_ids: list[UUID | None], is_public: bool) -> "DocumentAccess":
def build(
cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool
) -> "DocumentAccess":
return cls(
user_ids={str(user_id) for user_id in user_ids if user_id},
user_groups=set(user_groups),
is_public=is_public,
)

View File

@ -0,0 +1,10 @@
def prefix_user(user_id: str) -> str:
    """Build the prefixed form of a user ID.

    Prepending `user_id:` eliminates collisions between user IDs and group
    names, assuming groups are given a different prefix.
    """
    return "user_id:{}".format(user_id)
def prefix_user_group(user_group_name: str) -> str:
    """Build the prefixed form of a user group name.

    Prepending `group:` eliminates collisions between group names and user
    IDs, assuming user IDs are given a different prefix.
    """
    return "group:{}".format(user_group_name)

View File

@ -3,8 +3,7 @@ from threading import Thread
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.access.models import DocumentAccess
from danswer.db.document import get_acccess_info_for_documents
from danswer.access.access import get_access_for_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import Document
from danswer.document_index.document_index_utils import get_both_index_names
@ -37,7 +36,7 @@ def set_acl_for_vespa(should_check_if_already_done: bool = False) -> None:
# for all documents, set the `access_control_list` field appropriately
# based on the state of Postgres
documents = db_session.scalars(select(Document)).all()
document_access_info = get_acccess_info_for_documents(
document_access_dict = get_access_for_documents(
db_session=db_session,
document_ids=[document.id for document in documents],
)
@ -46,15 +45,16 @@ def set_acl_for_vespa(should_check_if_already_done: bool = False) -> None:
vespa_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
if not isinstance(vespa_index, VespaIndex):
raise ValueError("This script is only for Vespa indexes")
update_requests = [
UpdateRequest(
document_ids=[document_id],
access=DocumentAccess.build(user_ids, is_public),
access=access,
)
for document_id, user_ids, is_public in document_access_info
for document_id, access in document_access_dict.items()
]
vespa_index.update(update_requests=update_requests)

View File

@ -1,11 +1,11 @@
from sqlalchemy.orm import Session
from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_groups
from danswer.access.access import (
get_access_for_documents as get_access_for_documents_without_groups,
_get_access_for_documents as get_access_for_documents_without_groups,
)
from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_groups
from danswer.access.models import DocumentAccess
from danswer.db.engine import get_sqlalchemy_engine
from danswer.access.utils import prefix_user_group
from danswer.db.models import User
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from ee.danswer.db.user_group import fetch_user_groups_for_documents
@ -14,13 +14,13 @@ from ee.danswer.db.user_group import fetch_user_groups_for_user
def _get_access_for_documents(
document_ids: list[str],
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None,
db_session: Session,
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None,
) -> dict[str, DocumentAccess]:
access_dict = get_access_for_documents_without_groups(
non_ee_access_dict = get_access_for_documents_without_groups(
document_ids=document_ids,
cc_pair_to_delete=cc_pair_to_delete,
db_session=db_session,
cc_pair_to_delete=cc_pair_to_delete,
)
user_group_info = {
document_id: group_names
@ -31,36 +31,16 @@ def _get_access_for_documents(
)
}
# overload user_ids a bit - use it for both actual User IDs + group IDs
return {
document_id: DocumentAccess(
user_ids=access.user_ids.union(user_group_info.get(document_id, [])), # type: ignore
is_public=access.is_public,
user_ids=non_ee_access.user_ids,
user_groups=user_group_info.get(document_id, []), # type: ignore
is_public=non_ee_access.is_public,
)
for document_id, access in access_dict.items()
for document_id, non_ee_access in non_ee_access_dict.items()
}
def get_access_for_documents(
document_ids: list[str],
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
db_session: Session | None = None,
) -> dict[str, DocumentAccess]:
if db_session is None:
with Session(get_sqlalchemy_engine()) as db_session:
return _get_access_for_documents(
document_ids, cc_pair_to_delete, db_session
)
return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session)
def prefix_user_group(user_group_name: str) -> str:
    """Return `user_group_name` tagged with the `group:` prefix.

    The tag keeps group names from colliding with user IDs; this relies on
    user IDs being tagged with a different prefix.
    """
    tag = "group:"
    return f"{tag}{user_group_name}"
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
"""Returns a list of ACL entries that the user has access to. This is meant to be
used downstream to filter out documents that the user does not have access to. The

View File

@ -11,12 +11,16 @@ from danswer.db.tasks import mark_task_finished
from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.background.user_group_sync import name_user_group_sync_task
from ee.danswer.db.user_group import fetch_user_groups
from ee.danswer.user_groups.sync import sync_user_groups
logger = setup_logger()
# mark as EE for all tasks in this file
global_version.set_ee()
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_user_group_task(user_group_id: int) -> None:
@ -25,9 +29,14 @@ def sync_user_group_task(user_group_id: int) -> None:
mark_task_start(task_name, db_session)
# actual sync logic
sync_user_groups(user_group_id=user_group_id, db_session=db_session)
error_msg = None
try:
sync_user_groups(user_group_id=user_group_id, db_session=db_session)
except Exception as e:
error_msg = str(e)
logger.exception(f"Failed to sync user group - {error_msg}")
mark_task_finished(task_name, db_session)
mark_task_finished(task_name, db_session, success=error_msg is None)
#####

View File

@ -3,11 +3,11 @@ from sqlalchemy.orm import Session
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.access.access import get_access_for_documents
from danswer.db.document import prepare_to_modify_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
from ee.danswer.access.access import get_access_for_documents
from ee.danswer.db.user_group import delete_user_group
from ee.danswer.db.user_group import fetch_documents_for_user_group
from ee.danswer.db.user_group import fetch_user_group