From 12442c1c0600d8f6b7f0ab01bb67b20dd1f959d6 Mon Sep 17 00:00:00 2001 From: Weves Date: Wed, 11 Oct 2023 17:24:24 -0700 Subject: [PATCH] Make it harder to use unversioned access functions --- backend/danswer/access/access.py | 32 +++++++++++++++--------- backend/danswer/search/access_filters.py | 11 +------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/backend/danswer/access/access.py b/backend/danswer/access/access.py index 507cf71043..9f9a1d6341 100644 --- a/backend/danswer/access/access.py +++ b/backend/danswer/access/access.py @@ -3,15 +3,15 @@ from sqlalchemy.orm import Session from danswer.access.models import DocumentAccess from danswer.configs.constants import PUBLIC_DOC_PAT from danswer.db.document import get_acccess_info_for_documents -from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import User from danswer.server.models import ConnectorCredentialPairIdentifier +from danswer.utils.variable_functionality import fetch_versioned_implementation def _get_access_for_documents( document_ids: list[str], - cc_pair_to_delete: ConnectorCredentialPairIdentifier | None, db_session: Session, + cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None, ) -> dict[str, DocumentAccess]: document_access_info = get_acccess_info_for_documents( db_session=db_session, @@ -26,16 +26,16 @@ def _get_access_for_documents( def get_access_for_documents( document_ids: list[str], + db_session: Session, cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None, - db_session: Session | None = None, ) -> dict[str, DocumentAccess]: - if db_session is None: - with Session(get_sqlalchemy_engine()) as db_session: - return _get_access_for_documents( - document_ids, cc_pair_to_delete, db_session - ) - - return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session) + """Fetches all access information for the given documents.""" + versioned_get_access_for_documents_fn = fetch_versioned_implementation( + "danswer.access.access", "_get_access_for_documents" + ) + return versioned_get_access_for_documents_fn( + document_ids, cc_pair_to_delete, db_session + ) # type: ignore def prefix_user(user_id: str) -> str: @@ -44,11 +44,19 @@ def prefix_user(user_id: str) -> str: return f"user_id:{user_id}" -def get_acl_for_user(user: User | None, db_session: Session) -> set[str]: +def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]: """Returns a list of ACL entries that the user has access to. This is meant to be used downstream to filter out documents that the user does not have access to. The user should have access to a document if at least one entry in the document's ACL - matches one entry in the returned set.""" + matches one entry in the returned set. + """ if user: return {prefix_user(str(user.id)), PUBLIC_DOC_PAT} return {PUBLIC_DOC_PAT} + + +def get_acl_for_user(user: User | None, db_session: Session | None = None) -> set[str]: + versioned_acl_for_user_fn = fetch_versioned_implementation( + "danswer.access.access", "_get_acl_for_user" + ) + return versioned_acl_for_user_fn(user, db_session) # type: ignore diff --git a/backend/danswer/search/access_filters.py b/backend/danswer/search/access_filters.py index eeac5cd235..e6781d612d 100644 --- a/backend/danswer/search/access_filters.py +++ b/backend/danswer/search/access_filters.py @@ -1,22 +1,13 @@ -from collections.abc import Callable -from typing import cast - from sqlalchemy.orm import Session +from danswer.access.access import get_acl_for_user from danswer.configs.constants import ACCESS_CONTROL_LIST from danswer.datastores.interfaces import IndexFilter from danswer.db.models import User -from danswer.utils.variable_functionality import fetch_versioned_implementation def build_access_filters_for_user( user: User | None, session: Session ) -> list[IndexFilter]: - get_acl_for_user = cast( - Callable[[User | None, Session], set[str]], - fetch_versioned_implementation( - module="danswer.access.access", attribute="get_acl_for_user" - ), - ) user_acl = get_acl_for_user(user, session) return [{ACCESS_CONTROL_LIST: list(user_acl)}]