Add user group prefix + access filter utility

This commit is contained in:
Weves
2023-10-10 13:49:42 -07:00
committed by Chris Weaver
parent 5deb12523e
commit 3142e2eed2
5 changed files with 68 additions and 26 deletions

View File

@@ -1,8 +1,10 @@
from sqlalchemy.orm import Session
from danswer.access.models import DocumentAccess
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import User
from danswer.server.models import ConnectorCredentialPairIdentifier
@@ -34,3 +36,19 @@ def get_access_for_documents(
)
return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session)
def prefix_user(user_id: str) -> str:
"""Prefixes a user ID to eliminate collision with group names.
This assumes that groups are prefixed with a different prefix."""
return f"user_id:{user_id}"
def get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
"""Returns a list of ACL entries that the user has access to. This is meant to be
used downstream to filter out documents that the user does not have access to. The
user should have access to a document if at least one entry in the document's ACL
matches one entry in the returned set."""
if user:
return {prefix_user(str(user.id)), PUBLIC_DOC_PAT}
return {PUBLIC_DOC_PAT}

View File

@@ -15,7 +15,6 @@ from requests import Response
from danswer.chunking.models import DocMetadataAwareIndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import EDIT_KEYWORD_QUERY
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.app_configs import VESPA_DEPLOYMENT_ZIP
from danswer.configs.app_configs import VESPA_HOST
@@ -32,7 +31,6 @@ from danswer.configs.constants import DOCUMENT_SETS
from danswer.configs.constants import EMBEDDINGS
from danswer.configs.constants import MATCH_HIGHLIGHTS
from danswer.configs.constants import METADATA
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.configs.constants import SCORE
from danswer.configs.constants import SECTION_CONTINUATION
from danswer.configs.constants import SEMANTIC_IDENTIFIER
@@ -45,7 +43,6 @@ from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import UpdateRequest
from danswer.datastores.vespa.utils import remove_invalid_unicode_chars
from danswer.search.keyword_search import remove_stop_words
from danswer.search.semantic_search import embed_query
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
@@ -252,18 +249,16 @@ def _index_vespa_chunks(
return insertion_records
def _build_vespa_filters(
user_id: UUID | None, filters: list[IndexFilter] | None
) -> str:
# Permissions filters
acl_filter_stmts = [f'{ACCESS_CONTROL_LIST} contains "{PUBLIC_DOC_PAT}"']
if user_id:
acl_filter_stmts.append(f'{ACCESS_CONTROL_LIST} contains "{user_id}"')
filter_str = "(" + " or ".join(acl_filter_stmts) + ") and"
def _build_vespa_filters(filters: list[IndexFilter] | None) -> str:
# NOTE: permissions filters are expected to be passed in directly via
# the `filters` arg, which is why they are not considered explicitly here
# TODO: have document sets passed in + add document set based filters
# NOTE: document-set filters are also expected to be passed in directly
# via the `filters` arg. These are set either in the Web UI or in the Slack
# listener
# Provided query filters
# Handle provided query filters
filter_str = ""
if filters:
for filter_dict in filters:
valid_filters = {
@@ -497,7 +492,7 @@ class VespaIndex(DocumentIndex):
filters: list[IndexFilter] | None,
num_to_retrieve: int = NUM_RETURNED_HITS,
) -> list[InferenceChunk]:
vespa_where_clauses = _build_vespa_filters(user_id, filters)
vespa_where_clauses = _build_vespa_filters(filters)
yql = (
VespaIndex.yql_base
+ vespa_where_clauses
@@ -527,7 +522,7 @@ class VespaIndex(DocumentIndex):
num_to_retrieve: int,
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
) -> list[InferenceChunk]:
vespa_where_clauses = _build_vespa_filters(user_id, filters)
vespa_where_clauses = _build_vespa_filters(filters)
yql = (
VespaIndex.yql_base
+ vespa_where_clauses
@@ -540,13 +535,10 @@ class VespaIndex(DocumentIndex):
)
query_embedding = embed_query(query)
query_keywords = (
" ".join(remove_stop_words(query)) if EDIT_KEYWORD_QUERY else query
)
params = {
"yql": yql,
"query": query_keywords,
"query": query,
"input.query(query_embedding)": str(query_embedding),
"ranking.profile": "semantic_search",
}
@@ -560,7 +552,7 @@ class VespaIndex(DocumentIndex):
filters: list[IndexFilter] | None,
num_to_retrieve: int,
) -> list[InferenceChunk]:
vespa_where_clauses = _build_vespa_filters(user_id, filters)
vespa_where_clauses = _build_vespa_filters(filters)
yql = (
VespaIndex.yql_base
+ vespa_where_clauses

View File

@@ -16,6 +16,7 @@ from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.danswer_helper import query_intent
from danswer.search.keyword_search import retrieve_keyword_documents
from danswer.search.models import QueryFlow
@@ -66,11 +67,13 @@ def answer_qa_query(
use_keyword = predicted_search == SearchType.KEYWORD
user_id = None if user is None else user.id
user_acl_filters = build_access_filters_for_user(user, db_session)
final_filters = (filters or []) + user_acl_filters
if use_keyword:
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
query,
user_id,
filters,
final_filters,
get_default_document_index(),
retrieval_metrics_callback=retrieval_metrics_callback,
)
@@ -79,7 +82,7 @@ def answer_qa_query(
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query,
user_id,
filters,
final_filters,
get_default_document_index(),
retrieval_metrics_callback=retrieval_metrics_callback,
rerank_metrics_callback=rerank_metrics_callback,

View File

@@ -0,0 +1,22 @@
from collections.abc import Callable
from typing import cast
from sqlalchemy.orm import Session
from danswer.configs.constants import ACCESS_CONTROL_LIST
from danswer.datastores.interfaces import IndexFilter
from danswer.db.models import User
from danswer.utils.variable_functionality import fetch_versioned_implementation
def build_access_filters_for_user(
user: User | None, session: Session
) -> list[IndexFilter]:
get_acl_for_user = cast(
Callable[[User | None, Session], set[str]],
fetch_versioned_implementation(
module="danswer.access.access", attribute="get_acl_for_user"
),
)
user_acl = get_acl_for_user(user, session)
return [{ACCESS_CONTROL_LIST: list(user_acl)}]

View File

@@ -23,6 +23,7 @@ from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.danswer_helper import query_intent
from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.keyword_search import retrieve_keyword_documents
@@ -95,8 +96,10 @@ def semantic_search(
)
user_id = None if user is None else user.id
user_acl_filters = build_access_filters_for_user(user, db_session)
final_filters = (filters or []) + user_acl_filters
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query, user_id, filters, get_default_document_index()
query, user_id, final_filters, get_default_document_index()
)
if not ranked_chunks:
return SearchResponse(
@@ -132,8 +135,10 @@ def keyword_search(
)
user_id = None if user is None else user.id
user_acl_filters = build_access_filters_for_user(user, db_session)
final_filters = (filters or []) + user_acl_filters
ranked_chunks = retrieve_keyword_documents(
query, user_id, filters, get_default_document_index()
query, user_id, final_filters, get_default_document_index()
)
if not ranked_chunks:
return SearchResponse(
@@ -188,11 +193,13 @@ def stream_direct_qa(
use_keyword = predicted_search == SearchType.KEYWORD
user_id = None if user is None else user.id
user_acl_filters = build_access_filters_for_user(user, db_session)
final_filters = (filters or []) + user_acl_filters
if use_keyword:
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
query,
user_id,
filters,
final_filters,
get_default_document_index(),
)
unranked_chunks: list[InferenceChunk] | None = []
@@ -200,7 +207,7 @@ def stream_direct_qa(
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query,
user_id,
filters,
final_filters,
get_default_document_index(),
)
if not ranked_chunks: