From 48577bf0e4375075ffe046d40594d5434d70f297 Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Tue, 24 Sep 2024 14:37:35 -0700 Subject: [PATCH] Allow = in tag filter (#2548) * Allow = in tag filter * Rename func --- backend/danswer/db/tag.py | 11 ++++++++-- .../server/query_and_chat/query_backend.py | 21 +++++++++++++++---- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/backend/danswer/db/tag.py b/backend/danswer/db/tag.py index 688b8a112..6f1985908 100644 --- a/backend/danswer/db/tag.py +++ b/backend/danswer/db/tag.py @@ -1,3 +1,4 @@ +from sqlalchemy import and_ from sqlalchemy import delete from sqlalchemy import func from sqlalchemy import or_ @@ -107,12 +108,14 @@ def create_or_add_document_tag_list( return all_tags -def get_tags_by_value_prefix_for_source_types( +def find_tags( tag_key_prefix: str | None, tag_value_prefix: str | None, sources: list[DocumentSource] | None, limit: int | None, db_session: Session, + # if set, both tag_key_prefix and tag_value_prefix must be a match + require_both_to_match: bool = False, ) -> list[Tag]: query = select(Tag) @@ -122,7 +125,11 @@ def get_tags_by_value_prefix_for_source_types( conditions.append(Tag.tag_key.ilike(f"{tag_key_prefix}%")) if tag_value_prefix: conditions.append(Tag.tag_value.ilike(f"{tag_value_prefix}%")) - query = query.where(or_(*conditions)) + + final_prefix_condition = ( + and_(*conditions) if require_both_to_match else or_(*conditions) + ) + query = query.where(final_prefix_condition) if sources: query = query.where(Tag.source.in_(sources)) diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py index e20de5a30..96f674276 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -18,7 +18,7 @@ from danswer.db.chat import translate_db_search_doc_to_server_search_doc from danswer.db.engine import get_session from danswer.db.models import User from danswer.db.search_settings import get_current_search_settings -from danswer.db.tag import get_tags_by_value_prefix_for_source_types +from danswer.db.tag import find_tags from danswer.document_index.factory import get_default_document_index from danswer.document_index.vespa.index import VespaIndex from danswer.one_shot_answer.answer_question import stream_search_answer @@ -99,12 +99,25 @@ def get_tags( if not allow_prefix: raise NotImplementedError("Cannot disable prefix match for now") - db_tags = get_tags_by_value_prefix_for_source_types( - tag_key_prefix=match_pattern, - tag_value_prefix=match_pattern, + key_prefix = match_pattern + value_prefix = match_pattern + require_both_to_match = False + + # split on = to allow the user to type in "author=bob" + EQUAL_PAT = "=" + if match_pattern and EQUAL_PAT in match_pattern: + split_pattern = match_pattern.split(EQUAL_PAT) + key_prefix = split_pattern[0] + value_prefix = EQUAL_PAT.join(split_pattern[1:]) + require_both_to_match = True + + db_tags = find_tags( + tag_key_prefix=key_prefix, + tag_value_prefix=value_prefix, sources=sources, limit=limit, db_session=db_session, + require_both_to_match=require_both_to_match, ) server_tags = [ SourceTag(