From da43bac456ad7cbcea23161784dc4f97a45323b2 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Thu, 6 Jun 2024 14:10:40 -0700 Subject: [PATCH] Dedupe Flag (#1576) --- backend/danswer/chat/process_message.py | 36 ++++++++++++++++--- .../one_shot_answer/answer_question.py | 18 +++++++++- backend/danswer/search/models.py | 4 +++ backend/danswer/search/utils.py | 29 +++++++++++++++ backend/danswer/tools/search/search_tool.py | 3 ++ 5 files changed, 84 insertions(+), 6 deletions(-) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 7022c06a6..95144ddab 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -49,6 +49,8 @@ from danswer.llm.utils import get_default_llm_tokenizer from danswer.search.enums import OptionalSearchSetting from danswer.search.retrieval.search_runner import inference_documents_from_ids from danswer.search.utils import chunks_or_sections_to_search_docs +from danswer.search.utils import dedupe_documents +from danswer.search.utils import drop_llm_indices from danswer.server.query_and_chat.models import ChatMessageDetail from danswer.server.query_and_chat.models import CreateChatMessageRequest from danswer.server.utils import get_json_line @@ -95,14 +97,21 @@ def _handle_search_tool_response_summary( packet: ToolResponse, db_session: Session, selected_search_docs: list[DbSearchDoc] | None, -) -> tuple[QADocsResponse, list[DbSearchDoc]]: + dedupe_docs: bool = False, +) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]: response_sumary = cast(SearchResponseSummary, packet.response) + dropped_inds = None if not selected_search_docs: top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections) + + deduped_docs = top_docs + if dedupe_docs: + deduped_docs, dropped_inds = dedupe_documents(top_docs) + reference_db_search_docs = [ - create_db_search_doc(server_search_doc=top_doc, db_session=db_session) - for top_doc in top_docs + 
create_db_search_doc(server_search_doc=doc, db_session=db_session) + for doc in deduped_docs ] else: reference_db_search_docs = selected_search_docs @@ -122,6 +130,7 @@ def _handle_search_tool_response_summary( recency_bias_multiplier=response_sumary.recency_bias_multiplier, ), reference_db_search_docs, + dropped_inds, ) @@ -460,19 +469,36 @@ def stream_chat_message_objects( reference_db_search_docs = None qa_docs_response = None ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images + dropped_indices = None for packet in answer.processed_streamed_output: if isinstance(packet, ToolResponse): if packet.id == SEARCH_RESPONSE_SUMMARY_ID: ( qa_docs_response, reference_db_search_docs, + dropped_indices, ) = _handle_search_tool_response_summary( - packet, db_session, selected_db_search_docs + packet=packet, + db_session=db_session, + selected_search_docs=selected_db_search_docs, + # Deduping happens at the last step to avoid harming quality by dropping content early on + dedupe_docs=retrieval_options.dedupe_docs + if retrieval_options + else False, ) yield qa_docs_response elif packet.id == SECTION_RELEVANCE_LIST_ID: + chunk_indices = packet.response + + if reference_db_search_docs is not None and dropped_indices: + chunk_indices = drop_llm_indices( + llm_indices=chunk_indices, + search_docs=reference_db_search_docs, + dropped_indices=dropped_indices, + ) + yield LLMRelevanceFilterResponse( - relevant_chunk_indices=packet.response + relevant_chunk_indices=chunk_indices ) elif packet.id == IMAGE_GENERATION_RESPONSE_ID: img_generation_response = cast( diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 39e5d6c94..81dad2c65 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -39,6 +39,8 @@ from danswer.one_shot_answer.qa_utils import combine_message_thread from danswer.search.models import 
RerankMetricsContainer from danswer.search.models import RetrievalMetricsContainer from danswer.search.utils import chunks_or_sections_to_search_docs +from danswer.search.utils import dedupe_documents +from danswer.search.utils import drop_llm_indices from danswer.secondary_llm_flows.answer_validation import get_answer_validity from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase from danswer.server.query_and_chat.models import ChatMessageDetail @@ -195,6 +197,7 @@ def stream_answer_objects( skip_explicit_tool_calling=True, ) # won't be any ImageGenerationDisplay responses since that tool is never passed in + dropped_inds: list[int] = [] for packet in cast(AnswerObjectIterator, answer.processed_streamed_output): # for one-shot flow, don't currently do anything with these if isinstance(packet, ToolResponse): @@ -205,11 +208,15 @@ search_response_summary.top_sections ) + deduped_docs = top_docs + if query_req.retrieval_options.dedupe_docs: + deduped_docs, dropped_inds = dedupe_documents(top_docs) + reference_db_search_docs = [ create_db_search_doc( server_search_doc=top_doc, db_session=db_session ) - for top_doc in top_docs + for top_doc in deduped_docs ] response_docs = [ @@ -228,6 +235,15 @@ ) yield initial_response elif packet.id == SECTION_RELEVANCE_LIST_ID: + chunk_indices = packet.response + + if reference_db_search_docs is not None and dropped_inds: + chunk_indices = drop_llm_indices( + llm_indices=chunk_indices, + search_docs=reference_db_search_docs, + dropped_indices=dropped_inds, + ) + - yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response) + yield LLMRelevanceFilterResponse(relevant_chunk_indices=chunk_indices) else: yield packet diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index 16a64d820..792043c31 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -78,6 +78,9 @@ class SearchRequest(ChunkContext): skip_rerank: bool | None = None skip_llm_chunk_filter: 
bool | None = None + # If this is set, only the highest matching chunk (or merged chunks) is returned + dedupe_docs: bool = False + class Config: arbitrary_types_allowed = True @@ -115,6 +118,7 @@ class RetrievalDetails(ChunkContext): # if None, no offset / limit offset: int | None = None limit: int | None = None + dedupe_docs: bool = False class InferenceChunk(BaseChunk): diff --git a/backend/danswer/search/utils.py b/backend/danswer/search/utils.py index fbcb205e3..5b5a6464a 100644 --- a/backend/danswer/search/utils.py +++ b/backend/danswer/search/utils.py @@ -1,10 +1,39 @@ from collections.abc import Sequence +from typing import TypeVar +from danswer.db.models import SearchDoc as DBSearchDoc from danswer.search.models import InferenceChunk from danswer.search.models import InferenceSection from danswer.search.models import SearchDoc +T = TypeVar("T", InferenceSection, InferenceChunk, SearchDoc) + + +def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]: + seen_ids = set() + deduped_items = [] + dropped_indices = [] + for index, item in enumerate(items): + if item.document_id not in seen_ids: + seen_ids.add(item.document_id) + deduped_items.append(item) + else: + dropped_indices.append(index) + return deduped_items, dropped_indices + + +def drop_llm_indices( + llm_indices: list[int], search_docs: list[DBSearchDoc], dropped_indices: list[int] +) -> list[int]: + llm_bools = [True if i in llm_indices else False for i in range(len(search_docs))] + if dropped_indices: + llm_bools = [ + val for ind, val in enumerate(llm_bools) if ind not in dropped_indices + ] + return [i for i, val in enumerate(llm_bools) if val] + + def chunks_or_sections_to_search_docs( chunks: Sequence[InferenceChunk | InferenceSection] | None, ) -> list[SearchDoc]: diff --git a/backend/danswer/tools/search/search_tool.py b/backend/danswer/tools/search/search_tool.py index 968c17f5a..e77a8dea9 100644 --- a/backend/danswer/tools/search/search_tool.py +++ 
b/backend/danswer/tools/search/search_tool.py @@ -202,6 +202,9 @@ class SearchTool(Tool): chunks_above=self.chunks_above, chunks_below=self.chunks_below, full_doc=self.full_doc, + dedupe_docs=self.retrieval_options.dedupe_docs + if self.retrieval_options + else False, ), user=self.user, db_session=self.db_session,