Dedupe Flag (#1576)

This commit is contained in:
Yuhong Sun 2024-06-06 14:10:40 -07:00 committed by GitHub
parent adcbd354f4
commit da43bac456
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 84 additions and 6 deletions

View File

@ -49,6 +49,8 @@ from danswer.llm.utils import get_default_llm_tokenizer
from danswer.search.enums import OptionalSearchSetting from danswer.search.enums import OptionalSearchSetting
from danswer.search.retrieval.search_runner import inference_documents_from_ids from danswer.search.retrieval.search_runner import inference_documents_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.server.query_and_chat.models import ChatMessageDetail from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line from danswer.server.utils import get_json_line
@ -95,14 +97,20 @@ def _handle_search_tool_response_summary(
packet: ToolResponse, packet: ToolResponse,
db_session: Session, db_session: Session,
selected_search_docs: list[DbSearchDoc] | None, selected_search_docs: list[DbSearchDoc] | None,
) -> tuple[QADocsResponse, list[DbSearchDoc]]: dedupe_docs: bool = False,
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
response_sumary = cast(SearchResponseSummary, packet.response) response_sumary = cast(SearchResponseSummary, packet.response)
dropped_inds = None
if not selected_search_docs: if not selected_search_docs:
top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections) top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections)
if dedupe_docs:
deduped_docs, dropped_inds = dedupe_documents(top_docs)
reference_db_search_docs = [ reference_db_search_docs = [
create_db_search_doc(server_search_doc=top_doc, db_session=db_session) create_db_search_doc(server_search_doc=doc, db_session=db_session)
for top_doc in top_docs for doc in deduped_docs
] ]
else: else:
reference_db_search_docs = selected_search_docs reference_db_search_docs = selected_search_docs
@ -122,6 +130,7 @@ def _handle_search_tool_response_summary(
recency_bias_multiplier=response_sumary.recency_bias_multiplier, recency_bias_multiplier=response_sumary.recency_bias_multiplier,
), ),
reference_db_search_docs, reference_db_search_docs,
dropped_inds,
) )
@ -460,19 +469,36 @@ def stream_chat_message_objects(
reference_db_search_docs = None reference_db_search_docs = None
qa_docs_response = None qa_docs_response = None
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
dropped_indices = None
for packet in answer.processed_streamed_output: for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse): if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID: if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
( (
qa_docs_response, qa_docs_response,
reference_db_search_docs, reference_db_search_docs,
dropped_indices,
) = _handle_search_tool_response_summary( ) = _handle_search_tool_response_summary(
packet, db_session, selected_db_search_docs packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
) )
yield qa_docs_response yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID: elif packet.id == SECTION_RELEVANCE_LIST_ID:
chunk_indices = packet.response
if reference_db_search_docs is not None and dropped_indices:
chunk_indices = drop_llm_indices(
llm_indices=chunk_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse( yield LLMRelevanceFilterResponse(
relevant_chunk_indices=packet.response relevant_chunk_indices=chunk_indices
) )
elif packet.id == IMAGE_GENERATION_RESPONSE_ID: elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast( img_generation_response = cast(

View File

@ -39,6 +39,8 @@ from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.search.models import RerankMetricsContainer from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer from danswer.search.models import RetrievalMetricsContainer
from danswer.search.utils import chunks_or_sections_to_search_docs from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.secondary_llm_flows.answer_validation import get_answer_validity from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail from danswer.server.query_and_chat.models import ChatMessageDetail
@ -195,6 +197,7 @@ def stream_answer_objects(
skip_explicit_tool_calling=True, skip_explicit_tool_calling=True,
) )
# won't be any ImageGenerationDisplay responses since that tool is never passed in # won't be any ImageGenerationDisplay responses since that tool is never passed in
dropped_inds: list[int] = []
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output): for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
# for one-shot flow, don't currently do anything with these # for one-shot flow, don't currently do anything with these
if isinstance(packet, ToolResponse): if isinstance(packet, ToolResponse):
@ -205,11 +208,15 @@ def stream_answer_objects(
search_response_summary.top_sections search_response_summary.top_sections
) )
deduped_docs = top_docs
if query_req.retrieval_options.dedupe_docs:
deduped_docs, dropped_inds = dedupe_documents(top_docs)
reference_db_search_docs = [ reference_db_search_docs = [
create_db_search_doc( create_db_search_doc(
server_search_doc=top_doc, db_session=db_session server_search_doc=top_doc, db_session=db_session
) )
for top_doc in top_docs for top_doc in deduped_docs
] ]
response_docs = [ response_docs = [
@ -228,6 +235,15 @@ def stream_answer_objects(
) )
yield initial_response yield initial_response
elif packet.id == SECTION_RELEVANCE_LIST_ID: elif packet.id == SECTION_RELEVANCE_LIST_ID:
chunk_indices = packet.response
if reference_db_search_docs is not None and dropped_inds:
chunk_indices = drop_llm_indices(
llm_indices=chunk_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_inds,
)
yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response) yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response)
else: else:
yield packet yield packet

View File

@ -78,6 +78,9 @@ class SearchRequest(ChunkContext):
skip_rerank: bool | None = None skip_rerank: bool | None = None
skip_llm_chunk_filter: bool | None = None skip_llm_chunk_filter: bool | None = None
# If this is set, only the highest matching chunk (or merged chunks) is returned
dedupe_docs: bool = False
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -115,6 +118,7 @@ class RetrievalDetails(ChunkContext):
# if None, no offset / limit # if None, no offset / limit
offset: int | None = None offset: int | None = None
limit: int | None = None limit: int | None = None
dedupe_docs: bool = False
class InferenceChunk(BaseChunk): class InferenceChunk(BaseChunk):

View File

@ -1,10 +1,39 @@
from collections.abc import Sequence from collections.abc import Sequence
from typing import TypeVar
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.search.models import InferenceChunk from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection from danswer.search.models import InferenceSection
from danswer.search.models import SearchDoc from danswer.search.models import SearchDoc
T = TypeVar("T", InferenceSection, InferenceChunk, SearchDoc)
def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]:
    """Remove items that share a document_id with an earlier item.

    Keeps the first occurrence of each document_id, preserving input order.
    Returns the surviving items plus the indices (into the original list)
    of every item that was dropped, so callers can remap index-based
    references afterwards.
    """
    kept: list[T] = []
    dropped: list[int] = []
    seen_doc_ids: set = set()
    for idx, item in enumerate(items):
        doc_id = item.document_id
        if doc_id in seen_doc_ids:
            dropped.append(idx)
        else:
            seen_doc_ids.add(doc_id)
            kept.append(item)
    return kept, dropped
def drop_llm_indices(
    llm_indices: list[int], search_docs: list[DBSearchDoc], dropped_indices: list[int]
) -> list[int]:
    """Remap LLM-relevance indices after documents were dropped by deduplication.

    llm_indices are positions into the pre-dedupe doc list; dropped_indices are
    the positions removed by dedupe_documents. Returns the equivalent indices
    into the post-dedupe list (docs that were both LLM-selected and dropped
    simply disappear from the result).
    """
    # Sets give O(1) membership tests instead of scanning the lists per index.
    llm_index_set = set(llm_indices)
    dropped_index_set = set(dropped_indices)
    # One bool per pre-dedupe doc: was it selected by the LLM?
    llm_bools = [i in llm_index_set for i in range(len(search_docs))]
    if dropped_index_set:
        # Remove the dropped positions; remaining bools line up with the
        # post-dedupe doc list.
        llm_bools = [
            val for ind, val in enumerate(llm_bools) if ind not in dropped_index_set
        ]
    return [i for i, val in enumerate(llm_bools) if val]
def chunks_or_sections_to_search_docs( def chunks_or_sections_to_search_docs(
chunks: Sequence[InferenceChunk | InferenceSection] | None, chunks: Sequence[InferenceChunk | InferenceSection] | None,
) -> list[SearchDoc]: ) -> list[SearchDoc]:

View File

@ -202,6 +202,9 @@ class SearchTool(Tool):
chunks_above=self.chunks_above, chunks_above=self.chunks_above,
chunks_below=self.chunks_below, chunks_below=self.chunks_below,
full_doc=self.full_doc, full_doc=self.full_doc,
dedupe_docs=self.retrieval_options.dedupe_docs
if self.retrieval_options
else False,
), ),
user=self.user, user=self.user,
db_session=self.db_session, db_session=self.db_session,