From 3142e2eed20ce8b427615a182803a32d0b403910 Mon Sep 17 00:00:00 2001 From: Weves Date: Tue, 10 Oct 2023 13:49:42 -0700 Subject: [PATCH] Add user group prefix + access filter utility --- backend/danswer/access/access.py | 18 +++++++++++ backend/danswer/datastores/vespa/store.py | 32 ++++++++------------ backend/danswer/direct_qa/answer_question.py | 7 +++-- backend/danswer/search/access_filters.py | 22 ++++++++++++++ backend/danswer/server/search_backend.py | 15 ++++++--- 5 files changed, 68 insertions(+), 26 deletions(-) create mode 100644 backend/danswer/search/access_filters.py diff --git a/backend/danswer/access/access.py b/backend/danswer/access/access.py index e2c32a0bf319..507cf7104346 100644 --- a/backend/danswer/access/access.py +++ b/backend/danswer/access/access.py @@ -1,8 +1,10 @@ from sqlalchemy.orm import Session from danswer.access.models import DocumentAccess +from danswer.configs.constants import PUBLIC_DOC_PAT from danswer.db.document import get_acccess_info_for_documents from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.models import User from danswer.server.models import ConnectorCredentialPairIdentifier @@ -34,3 +36,19 @@ def get_access_for_documents( ) return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session) + + +def prefix_user(user_id: str) -> str: + """Prefixes a user ID to eliminate collision with group names. + This assumes that groups are prefixed with a different prefix.""" + return f"user_id:{user_id}" + + +def get_acl_for_user(user: User | None, db_session: Session) -> set[str]: + """Returns a list of ACL entries that the user has access to. This is meant to be + used downstream to filter out documents that the user does not have access to. The + user should have access to a document if at least one entry in the document's ACL + matches one entry in the returned set.""" + if user: + return {prefix_user(str(user.id)), PUBLIC_DOC_PAT} + return {PUBLIC_DOC_PAT} diff --git a/backend/danswer/datastores/vespa/store.py b/backend/danswer/datastores/vespa/store.py index fa787a710e3d..fd5a92df60dd 100644 --- a/backend/danswer/datastores/vespa/store.py +++ b/backend/danswer/datastores/vespa/store.py @@ -15,7 +15,6 @@ from requests import Response from danswer.chunking.models import DocMetadataAwareIndexChunk from danswer.chunking.models import InferenceChunk from danswer.configs.app_configs import DOCUMENT_INDEX_NAME -from danswer.configs.app_configs import EDIT_KEYWORD_QUERY from danswer.configs.app_configs import NUM_RETURNED_HITS from danswer.configs.app_configs import VESPA_DEPLOYMENT_ZIP from danswer.configs.app_configs import VESPA_HOST @@ -32,7 +31,6 @@ from danswer.configs.constants import DOCUMENT_SETS from danswer.configs.constants import EMBEDDINGS from danswer.configs.constants import MATCH_HIGHLIGHTS from danswer.configs.constants import METADATA -from danswer.configs.constants import PUBLIC_DOC_PAT from danswer.configs.constants import SCORE from danswer.configs.constants import SECTION_CONTINUATION from danswer.configs.constants import SEMANTIC_IDENTIFIER @@ -45,7 +43,6 @@ from danswer.datastores.interfaces import DocumentInsertionRecord from danswer.datastores.interfaces import IndexFilter from danswer.datastores.interfaces import UpdateRequest from danswer.datastores.vespa.utils import remove_invalid_unicode_chars -from danswer.search.keyword_search import remove_stop_words from danswer.search.semantic_search import embed_query from danswer.utils.batching import batch_generator from danswer.utils.logger import setup_logger @@ -252,18 +249,16 @@ def _index_vespa_chunks( return insertion_records -def _build_vespa_filters( - user_id: UUID | None, filters: list[IndexFilter] | None -) -> str: - # Permissions filters - acl_filter_stmts = [f'{ACCESS_CONTROL_LIST} contains "{PUBLIC_DOC_PAT}"'] - if user_id: - acl_filter_stmts.append(f'{ACCESS_CONTROL_LIST} contains "{user_id}"') - filter_str = "(" + " or ".join(acl_filter_stmts) + ") and" +def _build_vespa_filters(filters: list[IndexFilter] | None) -> str: + # NOTE: permissions filters are expected to be passed in directly via + # the `filters` arg, which is why they are not considered explicitly here - # TODO: have document sets passed in + add document set based filters + # NOTE: document-set filters are also expected to be passed in directly + # via the `filters` arg. These are set either in the Web UI or in the Slack + # listener - # Provided query filters + # Handle provided query filters + filter_str = "" if filters: for filter_dict in filters: valid_filters = { @@ -497,7 +492,7 @@ class VespaIndex(DocumentIndex): filters: list[IndexFilter] | None, num_to_retrieve: int = NUM_RETURNED_HITS, ) -> list[InferenceChunk]: - vespa_where_clauses = _build_vespa_filters(user_id, filters) + vespa_where_clauses = _build_vespa_filters(filters) yql = ( VespaIndex.yql_base + vespa_where_clauses @@ -527,7 +522,7 @@ class VespaIndex(DocumentIndex): num_to_retrieve: int, distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF, ) -> list[InferenceChunk]: - vespa_where_clauses = _build_vespa_filters(user_id, filters) + vespa_where_clauses = _build_vespa_filters(filters) yql = ( VespaIndex.yql_base + vespa_where_clauses @@ -540,13 +535,10 @@ class VespaIndex(DocumentIndex): ) query_embedding = embed_query(query) - query_keywords = ( - " ".join(remove_stop_words(query)) if EDIT_KEYWORD_QUERY else query - ) params = { "yql": yql, - "query": query_keywords, + "query": query, "input.query(query_embedding)": str(query_embedding), "ranking.profile": "semantic_search", } @@ -560,7 +552,7 @@ class VespaIndex(DocumentIndex): filters: list[IndexFilter] | None, num_to_retrieve: int, ) -> list[InferenceChunk]: - vespa_where_clauses = _build_vespa_filters(user_id, filters) + vespa_where_clauses = _build_vespa_filters(filters) yql = ( VespaIndex.yql_base + vespa_where_clauses diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py index edb5d5334f2a..878dd974cd05 100644 --- a/backend/danswer/direct_qa/answer_question.py +++ b/backend/danswer/direct_qa/answer_question.py @@ -16,6 +16,7 @@ from danswer.direct_qa.exceptions import UnknownModelError from danswer.direct_qa.llm_utils import get_default_qa_model from danswer.direct_qa.models import LLMMetricsContainer from danswer.direct_qa.qa_utils import get_usable_chunks +from danswer.search.access_filters import build_access_filters_for_user from danswer.search.danswer_helper import query_intent from danswer.search.keyword_search import retrieve_keyword_documents from danswer.search.models import QueryFlow @@ -66,11 +67,13 @@ def answer_qa_query( use_keyword = predicted_search == SearchType.KEYWORD user_id = None if user is None else user.id + user_acl_filters = build_access_filters_for_user(user, db_session) + final_filters = (filters or []) + user_acl_filters if use_keyword: ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents( query, user_id, - filters, + final_filters, get_default_document_index(), retrieval_metrics_callback=retrieval_metrics_callback, ) @@ -79,7 +82,7 @@ def answer_qa_query( ranked_chunks, unranked_chunks = retrieve_ranked_documents( query, user_id, - filters, + final_filters, get_default_document_index(), retrieval_metrics_callback=retrieval_metrics_callback, rerank_metrics_callback=rerank_metrics_callback, diff --git a/backend/danswer/search/access_filters.py b/backend/danswer/search/access_filters.py new file mode 100644 index 000000000000..eeac5cd23515 --- /dev/null +++ b/backend/danswer/search/access_filters.py @@ -0,0 +1,22 @@ +from collections.abc import Callable +from typing import cast + +from sqlalchemy.orm import Session + +from danswer.configs.constants import ACCESS_CONTROL_LIST +from danswer.datastores.interfaces import IndexFilter +from danswer.db.models import User +from danswer.utils.variable_functionality import fetch_versioned_implementation + + +def build_access_filters_for_user( + user: User | None, session: Session +) -> list[IndexFilter]: + get_acl_for_user = cast( + Callable[[User | None, Session], set[str]], + fetch_versioned_implementation( + module="danswer.access.access", attribute="get_acl_for_user" + ), + ) + user_acl = get_acl_for_user(user, session) + return [{ACCESS_CONTROL_LIST: list(user_acl)}] diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py index ea4e5a9e92dd..0ab9e02ceb02 100644 --- a/backend/danswer/server/search_backend.py +++ b/backend/danswer/server/search_backend.py @@ -23,6 +23,7 @@ from danswer.direct_qa.exceptions import UnknownModelError from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.llm_utils import get_default_qa_model from danswer.direct_qa.qa_utils import get_usable_chunks +from danswer.search.access_filters import build_access_filters_for_user from danswer.search.danswer_helper import query_intent from danswer.search.danswer_helper import recommend_search_flow from danswer.search.keyword_search import retrieve_keyword_documents @@ -95,8 +96,10 @@ def semantic_search( ) user_id = None if user is None else user.id + user_acl_filters = build_access_filters_for_user(user, db_session) + final_filters = (filters or []) + user_acl_filters ranked_chunks, unranked_chunks = retrieve_ranked_documents( - query, user_id, filters, get_default_document_index() + query, user_id, final_filters, get_default_document_index() ) if not ranked_chunks: return SearchResponse( @@ -132,8 +135,10 @@ def keyword_search( ) user_id = None if user is None else user.id + user_acl_filters = build_access_filters_for_user(user, db_session) + final_filters = (filters or []) + user_acl_filters ranked_chunks = retrieve_keyword_documents( - query, user_id, filters, get_default_document_index() + query, user_id, final_filters, get_default_document_index() ) if not ranked_chunks: return SearchResponse( @@ -188,11 +193,13 @@ def stream_direct_qa( use_keyword = predicted_search == SearchType.KEYWORD user_id = None if user is None else user.id + user_acl_filters = build_access_filters_for_user(user, db_session) + final_filters = (filters or []) + user_acl_filters if use_keyword: ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents( query, user_id, - filters, + final_filters, get_default_document_index(), ) unranked_chunks: list[InferenceChunk] | None = [] @@ -200,7 +207,7 @@ def stream_direct_qa( ranked_chunks, unranked_chunks = retrieve_ranked_documents( query, user_id, - filters, + final_filters, get_default_document_index(), ) if not ranked_chunks: