Make it harder to use unversioned access functions

This commit is contained in:
Weves
2023-10-11 17:24:24 -07:00
committed by Chris Weaver
parent 876c6fdaa6
commit 12442c1c06
2 changed files with 21 additions and 22 deletions

View File

@ -3,15 +3,15 @@ from sqlalchemy.orm import Session
from danswer.access.models import DocumentAccess
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import User
from danswer.server.models import ConnectorCredentialPairIdentifier
from danswer.utils.variable_functionality import fetch_versioned_implementation
def _get_access_for_documents(
document_ids: list[str],
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None,
db_session: Session,
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
) -> dict[str, DocumentAccess]:
document_access_info = get_acccess_info_for_documents(
db_session=db_session,
@ -26,16 +26,16 @@ def _get_access_for_documents(
def get_access_for_documents(
document_ids: list[str],
db_session: Session,
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
db_session: Session | None = None,
) -> dict[str, DocumentAccess]:
if db_session is None:
with Session(get_sqlalchemy_engine()) as db_session:
return _get_access_for_documents(
document_ids, cc_pair_to_delete, db_session
)
return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session)
"""Fetches all access information for the given documents."""
versioned_get_access_for_documents_fn = fetch_versioned_implementation(
"danswer.access.access", "_get_access_for_documents"
)
return versioned_get_access_for_documents_fn(
document_ids, cc_pair_to_delete, db_session
) # type: ignore
def prefix_user(user_id: str) -> str:
@ -44,11 +44,19 @@ def prefix_user(user_id: str) -> str:
return f"user_id:{user_id}"
def get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
"""Returns a list of ACL entries that the user has access to. This is meant to be
used downstream to filter out documents that the user does not have access to. The
user should have access to a document if at least one entry in the document's ACL
matches one entry in the returned set."""
matches one entry in the returned set.
"""
if user:
return {prefix_user(str(user.id)), PUBLIC_DOC_PAT}
return {PUBLIC_DOC_PAT}
def get_acl_for_user(user: User | None, db_session: Session | None = None) -> set[str]:
versioned_acl_for_user_fn = fetch_versioned_implementation(
"danswer.access.access", "_get_acl_for_user"
)
return versioned_acl_for_user_fn(user, db_session) # type: ignore

View File

@ -1,22 +1,13 @@
from collections.abc import Callable
from typing import cast
from sqlalchemy.orm import Session
from danswer.access.access import get_acl_for_user
from danswer.configs.constants import ACCESS_CONTROL_LIST
from danswer.datastores.interfaces import IndexFilter
from danswer.db.models import User
from danswer.utils.variable_functionality import fetch_versioned_implementation
def build_access_filters_for_user(
user: User | None, session: Session
) -> list[IndexFilter]:
get_acl_for_user = cast(
Callable[[User | None, Session], set[str]],
fetch_versioned_implementation(
module="danswer.access.access", attribute="get_acl_for_user"
),
)
user_acl = get_acl_for_user(user, session)
return [{ACCESS_CONTROL_LIST: list(user_acl)}]