mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-30 09:40:50 +02:00
Make it harder to use unversioned access functions
This commit is contained in:
@ -3,15 +3,15 @@ from sqlalchemy.orm import Session
|
|||||||
from danswer.access.models import DocumentAccess
|
from danswer.access.models import DocumentAccess
|
||||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||||
from danswer.db.document import get_acccess_info_for_documents
|
from danswer.db.document import get_acccess_info_for_documents
|
||||||
from danswer.db.engine import get_sqlalchemy_engine
|
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
from danswer.server.models import ConnectorCredentialPairIdentifier
|
from danswer.server.models import ConnectorCredentialPairIdentifier
|
||||||
|
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||||
|
|
||||||
|
|
||||||
def _get_access_for_documents(
|
def _get_access_for_documents(
|
||||||
document_ids: list[str],
|
document_ids: list[str],
|
||||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None,
|
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
|
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
|
||||||
) -> dict[str, DocumentAccess]:
|
) -> dict[str, DocumentAccess]:
|
||||||
document_access_info = get_acccess_info_for_documents(
|
document_access_info = get_acccess_info_for_documents(
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
@ -26,16 +26,16 @@ def _get_access_for_documents(
|
|||||||
|
|
||||||
def get_access_for_documents(
|
def get_access_for_documents(
|
||||||
document_ids: list[str],
|
document_ids: list[str],
|
||||||
|
db_session: Session,
|
||||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
|
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
|
||||||
db_session: Session | None = None,
|
|
||||||
) -> dict[str, DocumentAccess]:
|
) -> dict[str, DocumentAccess]:
|
||||||
if db_session is None:
|
"""Fetches all access information for the given documents."""
|
||||||
with Session(get_sqlalchemy_engine()) as db_session:
|
versioned_get_access_for_documents_fn = fetch_versioned_implementation(
|
||||||
return _get_access_for_documents(
|
"danswer.access.access", "_get_access_for_documents"
|
||||||
document_ids, cc_pair_to_delete, db_session
|
|
||||||
)
|
)
|
||||||
|
return versioned_get_access_for_documents_fn(
|
||||||
return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session)
|
document_ids, cc_pair_to_delete, db_session
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def prefix_user(user_id: str) -> str:
|
def prefix_user(user_id: str) -> str:
|
||||||
@ -44,11 +44,19 @@ def prefix_user(user_id: str) -> str:
|
|||||||
return f"user_id:{user_id}"
|
return f"user_id:{user_id}"
|
||||||
|
|
||||||
|
|
||||||
def get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||||
used downstream to filter out documents that the user does not have access to. The
|
used downstream to filter out documents that the user does not have access to. The
|
||||||
user should have access to a document if at least one entry in the document's ACL
|
user should have access to a document if at least one entry in the document's ACL
|
||||||
matches one entry in the returned set."""
|
matches one entry in the returned set.
|
||||||
|
"""
|
||||||
if user:
|
if user:
|
||||||
return {prefix_user(str(user.id)), PUBLIC_DOC_PAT}
|
return {prefix_user(str(user.id)), PUBLIC_DOC_PAT}
|
||||||
return {PUBLIC_DOC_PAT}
|
return {PUBLIC_DOC_PAT}
|
||||||
|
|
||||||
|
|
||||||
|
def get_acl_for_user(user: User | None, db_session: Session | None = None) -> set[str]:
|
||||||
|
versioned_acl_for_user_fn = fetch_versioned_implementation(
|
||||||
|
"danswer.access.access", "_get_acl_for_user"
|
||||||
|
)
|
||||||
|
return versioned_acl_for_user_fn(user, db_session) # type: ignore
|
||||||
|
@ -1,22 +1,13 @@
|
|||||||
from collections.abc import Callable
|
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from danswer.access.access import get_acl_for_user
|
||||||
from danswer.configs.constants import ACCESS_CONTROL_LIST
|
from danswer.configs.constants import ACCESS_CONTROL_LIST
|
||||||
from danswer.datastores.interfaces import IndexFilter
|
from danswer.datastores.interfaces import IndexFilter
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
|
||||||
|
|
||||||
|
|
||||||
def build_access_filters_for_user(
|
def build_access_filters_for_user(
|
||||||
user: User | None, session: Session
|
user: User | None, session: Session
|
||||||
) -> list[IndexFilter]:
|
) -> list[IndexFilter]:
|
||||||
get_acl_for_user = cast(
|
|
||||||
Callable[[User | None, Session], set[str]],
|
|
||||||
fetch_versioned_implementation(
|
|
||||||
module="danswer.access.access", attribute="get_acl_for_user"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
user_acl = get_acl_for_user(user, session)
|
user_acl = get_acl_for_user(user, session)
|
||||||
return [{ACCESS_CONTROL_LIST: list(user_acl)}]
|
return [{ACCESS_CONTROL_LIST: list(user_acl)}]
|
||||||
|
Reference in New Issue
Block a user