diff --git a/backend/danswer/access/access.py b/backend/danswer/access/access.py index 507cf71043..9f9a1d6341 100644 --- a/backend/danswer/access/access.py +++ b/backend/danswer/access/access.py @@ -3,15 +3,15 @@ from sqlalchemy.orm import Session from danswer.access.models import DocumentAccess from danswer.configs.constants import PUBLIC_DOC_PAT from danswer.db.document import get_acccess_info_for_documents -from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import User from danswer.server.models import ConnectorCredentialPairIdentifier +from danswer.utils.variable_functionality import fetch_versioned_implementation def _get_access_for_documents( document_ids: list[str], - cc_pair_to_delete: ConnectorCredentialPairIdentifier | None, db_session: Session, + cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None, ) -> dict[str, DocumentAccess]: document_access_info = get_acccess_info_for_documents( db_session=db_session, @@ -26,16 +26,16 @@ def _get_access_for_documents( def get_access_for_documents( document_ids: list[str], + db_session: Session, cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None, - db_session: Session | None = None, ) -> dict[str, DocumentAccess]: - if db_session is None: - with Session(get_sqlalchemy_engine()) as db_session: - return _get_access_for_documents( - document_ids, cc_pair_to_delete, db_session - ) - - return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session) + """Fetches all access information for the given documents.""" + versioned_get_access_for_documents_fn = fetch_versioned_implementation( + "danswer.access.access", "_get_access_for_documents" + ) + return versioned_get_access_for_documents_fn( + document_ids, cc_pair_to_delete, db_session + ) # type: ignore def prefix_user(user_id: str) -> str: @@ -44,11 +44,19 @@ def prefix_user(user_id: str) -> str: return f"user_id:{user_id}" -def get_acl_for_user(user: User | None, db_session: Session) -> set[str]: +def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]: """Returns a list of ACL entries that the user has access to. This is meant to be used downstream to filter out documents that the user does not have access to. The user should have access to a document if at least one entry in the document's ACL - matches one entry in the returned set.""" + matches one entry in the returned set. + """ if user: return {prefix_user(str(user.id)), PUBLIC_DOC_PAT} return {PUBLIC_DOC_PAT} + + +def get_acl_for_user(user: User | None, db_session: Session | None = None) -> set[str]: + versioned_acl_for_user_fn = fetch_versioned_implementation( + "danswer.access.access", "_get_acl_for_user" + ) + return versioned_acl_for_user_fn(user, db_session) # type: ignore diff --git a/backend/danswer/search/access_filters.py b/backend/danswer/search/access_filters.py index eeac5cd235..e6781d612d 100644 --- a/backend/danswer/search/access_filters.py +++ b/backend/danswer/search/access_filters.py @@ -1,22 +1,13 @@ -from collections.abc import Callable -from typing import cast - from sqlalchemy.orm import Session +from danswer.access.access import get_acl_for_user from danswer.configs.constants import ACCESS_CONTROL_LIST from danswer.datastores.interfaces import IndexFilter from danswer.db.models import User -from danswer.utils.variable_functionality import fetch_versioned_implementation def build_access_filters_for_user( user: User | None, session: Session ) -> list[IndexFilter]: - get_acl_for_user = cast( - Callable[[User | None, Session], set[str]], - fetch_versioned_implementation( - module="danswer.access.access", attribute="get_acl_for_user" - ), - ) user_acl = get_acl_for_user(user, session) return [{ACCESS_CONTROL_LIST: list(user_acl)}]