mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 12:03:54 +02:00
Add user group prefix + access filter utility
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
from danswer.db.document import get_acccess_info_for_documents
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import User
|
||||
from danswer.server.models import ConnectorCredentialPairIdentifier
|
||||
|
||||
|
||||
@@ -34,3 +36,19 @@ def get_access_for_documents(
|
||||
)
|
||||
|
||||
return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session)
|
||||
|
||||
|
||||
def prefix_user(user_id: str) -> str:
    """Prefix a user ID to eliminate collision with group names.

    This assumes that groups are prefixed with a different prefix.

    Args:
        user_id: The user's ID, already converted to a string.

    Returns:
        The ID with the ``user_id:`` prefix prepended.
    """
    return f"user_id:{user_id}"
|
||||
|
||||
|
||||
def get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
    """Return the set of ACL entries that the user has access to.

    This is meant to be used downstream to filter out documents that the user
    does not have access to. The user should have access to a document if at
    least one entry in the document's ACL matches one entry in the returned
    set.

    Args:
        user: The user to build the ACL for, or ``None`` when there is no
            user.
        db_session: Not used by this implementation.
            NOTE(review): presumably kept so that alternative implementations
            sharing this signature can query the database - confirm.

    Returns:
        A set that always contains ``PUBLIC_DOC_PAT`` and, when a user is
        given, additionally the prefixed ID of that user.
    """
    # Explicit None check rather than relying on the truthiness of the model
    # object.
    if user is None:
        return {PUBLIC_DOC_PAT}
    return {prefix_user(str(user.id)), PUBLIC_DOC_PAT}
|
||||
|
@@ -15,7 +15,6 @@ from requests import Response
|
||||
from danswer.chunking.models import DocMetadataAwareIndexChunk
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
|
||||
from danswer.configs.app_configs import EDIT_KEYWORD_QUERY
|
||||
from danswer.configs.app_configs import NUM_RETURNED_HITS
|
||||
from danswer.configs.app_configs import VESPA_DEPLOYMENT_ZIP
|
||||
from danswer.configs.app_configs import VESPA_HOST
|
||||
@@ -32,7 +31,6 @@ from danswer.configs.constants import DOCUMENT_SETS
|
||||
from danswer.configs.constants import EMBEDDINGS
|
||||
from danswer.configs.constants import MATCH_HIGHLIGHTS
|
||||
from danswer.configs.constants import METADATA
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
from danswer.configs.constants import SCORE
|
||||
from danswer.configs.constants import SECTION_CONTINUATION
|
||||
from danswer.configs.constants import SEMANTIC_IDENTIFIER
|
||||
@@ -45,7 +43,6 @@ from danswer.datastores.interfaces import DocumentInsertionRecord
|
||||
from danswer.datastores.interfaces import IndexFilter
|
||||
from danswer.datastores.interfaces import UpdateRequest
|
||||
from danswer.datastores.vespa.utils import remove_invalid_unicode_chars
|
||||
from danswer.search.keyword_search import remove_stop_words
|
||||
from danswer.search.semantic_search import embed_query
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -252,18 +249,16 @@ def _index_vespa_chunks(
|
||||
return insertion_records
|
||||
|
||||
|
||||
def _build_vespa_filters(
|
||||
user_id: UUID | None, filters: list[IndexFilter] | None
|
||||
) -> str:
|
||||
# Permissions filters
|
||||
acl_filter_stmts = [f'{ACCESS_CONTROL_LIST} contains "{PUBLIC_DOC_PAT}"']
|
||||
if user_id:
|
||||
acl_filter_stmts.append(f'{ACCESS_CONTROL_LIST} contains "{user_id}"')
|
||||
filter_str = "(" + " or ".join(acl_filter_stmts) + ") and"
|
||||
def _build_vespa_filters(filters: list[IndexFilter] | None) -> str:
|
||||
# NOTE: permissions filters are expected to be passed in directly via
|
||||
# the `filters` arg, which is why they are not considered explicitly here
|
||||
|
||||
# TODO: have document sets passed in + add document set based filters
|
||||
# NOTE: document-set filters are also expected to be passed in directly
|
||||
# via the `filters` arg. These are set either in the Web UI or in the Slack
|
||||
# listener
|
||||
|
||||
# Provided query filters
|
||||
# Handle provided query filters
|
||||
filter_str = ""
|
||||
if filters:
|
||||
for filter_dict in filters:
|
||||
valid_filters = {
|
||||
@@ -497,7 +492,7 @@ class VespaIndex(DocumentIndex):
|
||||
filters: list[IndexFilter] | None,
|
||||
num_to_retrieve: int = NUM_RETURNED_HITS,
|
||||
) -> list[InferenceChunk]:
|
||||
vespa_where_clauses = _build_vespa_filters(user_id, filters)
|
||||
vespa_where_clauses = _build_vespa_filters(filters)
|
||||
yql = (
|
||||
VespaIndex.yql_base
|
||||
+ vespa_where_clauses
|
||||
@@ -527,7 +522,7 @@ class VespaIndex(DocumentIndex):
|
||||
num_to_retrieve: int,
|
||||
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
|
||||
) -> list[InferenceChunk]:
|
||||
vespa_where_clauses = _build_vespa_filters(user_id, filters)
|
||||
vespa_where_clauses = _build_vespa_filters(filters)
|
||||
yql = (
|
||||
VespaIndex.yql_base
|
||||
+ vespa_where_clauses
|
||||
@@ -540,13 +535,10 @@ class VespaIndex(DocumentIndex):
|
||||
)
|
||||
|
||||
query_embedding = embed_query(query)
|
||||
query_keywords = (
|
||||
" ".join(remove_stop_words(query)) if EDIT_KEYWORD_QUERY else query
|
||||
)
|
||||
|
||||
params = {
|
||||
"yql": yql,
|
||||
"query": query_keywords,
|
||||
"query": query,
|
||||
"input.query(query_embedding)": str(query_embedding),
|
||||
"ranking.profile": "semantic_search",
|
||||
}
|
||||
@@ -560,7 +552,7 @@ class VespaIndex(DocumentIndex):
|
||||
filters: list[IndexFilter] | None,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
vespa_where_clauses = _build_vespa_filters(user_id, filters)
|
||||
vespa_where_clauses = _build_vespa_filters(filters)
|
||||
yql = (
|
||||
VespaIndex.yql_base
|
||||
+ vespa_where_clauses
|
||||
|
@@ -16,6 +16,7 @@ from danswer.direct_qa.exceptions import UnknownModelError
|
||||
from danswer.direct_qa.llm_utils import get_default_qa_model
|
||||
from danswer.direct_qa.models import LLMMetricsContainer
|
||||
from danswer.direct_qa.qa_utils import get_usable_chunks
|
||||
from danswer.search.access_filters import build_access_filters_for_user
|
||||
from danswer.search.danswer_helper import query_intent
|
||||
from danswer.search.keyword_search import retrieve_keyword_documents
|
||||
from danswer.search.models import QueryFlow
|
||||
@@ -66,11 +67,13 @@ def answer_qa_query(
|
||||
use_keyword = predicted_search == SearchType.KEYWORD
|
||||
|
||||
user_id = None if user is None else user.id
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
final_filters = (filters or []) + user_acl_filters
|
||||
if use_keyword:
|
||||
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
|
||||
query,
|
||||
user_id,
|
||||
filters,
|
||||
final_filters,
|
||||
get_default_document_index(),
|
||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
||||
)
|
||||
@@ -79,7 +82,7 @@ def answer_qa_query(
|
||||
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
|
||||
query,
|
||||
user_id,
|
||||
filters,
|
||||
final_filters,
|
||||
get_default_document_index(),
|
||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
||||
rerank_metrics_callback=rerank_metrics_callback,
|
||||
|
22
backend/danswer/search/access_filters.py
Normal file
22
backend/danswer/search/access_filters.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import ACCESS_CONTROL_LIST
|
||||
from danswer.datastores.interfaces import IndexFilter
|
||||
from danswer.db.models import User
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
def build_access_filters_for_user(
    user: User | None, session: Session
) -> list[IndexFilter]:
    """Build the index filters that restrict results to what `user` may see.

    The ACL lookup is resolved at call time through
    ``fetch_versioned_implementation`` (module ``danswer.access.access``,
    attribute ``get_acl_for_user``) rather than imported directly.

    Args:
        user: The requesting user, or ``None`` when there is no user.
        session: Database session forwarded to the ACL lookup.

    Returns:
        A single-element list holding one filter that maps
        ``ACCESS_CONTROL_LIST`` to the user's ACL entries.
    """
    # `fetch_versioned_implementation` is cast to the expected callable type
    # so the call below is type-checked.
    get_acl_for_user = cast(
        Callable[[User | None, Session], set[str]],
        fetch_versioned_implementation(
            module="danswer.access.access", attribute="get_acl_for_user"
        ),
    )
    user_acl = get_acl_for_user(user, session)
    # Sets iterate in arbitrary order; sort so the generated filter is
    # deterministic across calls and processes.
    return [{ACCESS_CONTROL_LIST: sorted(user_acl)}]
|
@@ -23,6 +23,7 @@ from danswer.direct_qa.exceptions import UnknownModelError
|
||||
from danswer.direct_qa.interfaces import DanswerAnswerPiece
|
||||
from danswer.direct_qa.llm_utils import get_default_qa_model
|
||||
from danswer.direct_qa.qa_utils import get_usable_chunks
|
||||
from danswer.search.access_filters import build_access_filters_for_user
|
||||
from danswer.search.danswer_helper import query_intent
|
||||
from danswer.search.danswer_helper import recommend_search_flow
|
||||
from danswer.search.keyword_search import retrieve_keyword_documents
|
||||
@@ -95,8 +96,10 @@ def semantic_search(
|
||||
)
|
||||
|
||||
user_id = None if user is None else user.id
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
final_filters = (filters or []) + user_acl_filters
|
||||
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
|
||||
query, user_id, filters, get_default_document_index()
|
||||
query, user_id, final_filters, get_default_document_index()
|
||||
)
|
||||
if not ranked_chunks:
|
||||
return SearchResponse(
|
||||
@@ -132,8 +135,10 @@ def keyword_search(
|
||||
)
|
||||
|
||||
user_id = None if user is None else user.id
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
final_filters = (filters or []) + user_acl_filters
|
||||
ranked_chunks = retrieve_keyword_documents(
|
||||
query, user_id, filters, get_default_document_index()
|
||||
query, user_id, final_filters, get_default_document_index()
|
||||
)
|
||||
if not ranked_chunks:
|
||||
return SearchResponse(
|
||||
@@ -188,11 +193,13 @@ def stream_direct_qa(
|
||||
use_keyword = predicted_search == SearchType.KEYWORD
|
||||
|
||||
user_id = None if user is None else user.id
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
final_filters = (filters or []) + user_acl_filters
|
||||
if use_keyword:
|
||||
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
|
||||
query,
|
||||
user_id,
|
||||
filters,
|
||||
final_filters,
|
||||
get_default_document_index(),
|
||||
)
|
||||
unranked_chunks: list[InferenceChunk] | None = []
|
||||
@@ -200,7 +207,7 @@ def stream_direct_qa(
|
||||
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
|
||||
query,
|
||||
user_id,
|
||||
filters,
|
||||
final_filters,
|
||||
get_default_document_index(),
|
||||
)
|
||||
if not ranked_chunks:
|
||||
|
Reference in New Issue
Block a user