Dedupe Flag (#1576)

This commit is contained in:
Yuhong Sun 2024-06-06 14:10:40 -07:00 committed by GitHub
parent adcbd354f4
commit da43bac456
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 84 additions and 6 deletions

View File

@ -49,6 +49,8 @@ from danswer.llm.utils import get_default_llm_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.retrieval.search_runner import inference_documents_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
@ -95,14 +97,20 @@ def _handle_search_tool_response_summary(
packet: ToolResponse,
db_session: Session,
selected_search_docs: list[DbSearchDoc] | None,
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
dedupe_docs: bool = False,
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
response_sumary = cast(SearchResponseSummary, packet.response)
dropped_inds = None
if not selected_search_docs:
top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections)
if dedupe_docs:
deduped_docs, dropped_inds = dedupe_documents(top_docs)
reference_db_search_docs = [
create_db_search_doc(server_search_doc=top_doc, db_session=db_session)
for top_doc in top_docs
create_db_search_doc(server_search_doc=doc, db_session=db_session)
for doc in deduped_docs
]
else:
reference_db_search_docs = selected_search_docs
@ -122,6 +130,7 @@ def _handle_search_tool_response_summary(
recency_bias_multiplier=response_sumary.recency_bias_multiplier,
),
reference_db_search_docs,
dropped_inds,
)
@ -460,19 +469,36 @@ def stream_chat_message_objects(
reference_db_search_docs = None
qa_docs_response = None
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
dropped_indices = None
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
reference_db_search_docs,
dropped_indices,
) = _handle_search_tool_response_summary(
packet, db_session, selected_db_search_docs
packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
chunk_indices = packet.response
if reference_db_search_docs is not None and dropped_indices:
chunk_indices = drop_llm_indices(
llm_indices=chunk_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
relevant_chunk_indices=packet.response
relevant_chunk_indices=chunk_indices
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(

View File

@ -39,6 +39,8 @@ from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail
@ -195,6 +197,7 @@ def stream_answer_objects(
skip_explicit_tool_calling=True,
)
# won't be any ImageGenerationDisplay responses since that tool is never passed in
dropped_inds: list[int] = []
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
# for one-shot flow, don't currently do anything with these
if isinstance(packet, ToolResponse):
@ -205,11 +208,15 @@ def stream_answer_objects(
search_response_summary.top_sections
)
deduped_docs = top_docs
if query_req.retrieval_options.dedupe_docs:
deduped_docs, dropped_inds = dedupe_documents(top_docs)
reference_db_search_docs = [
create_db_search_doc(
server_search_doc=top_doc, db_session=db_session
)
for top_doc in top_docs
for top_doc in deduped_docs
]
response_docs = [
@ -228,6 +235,15 @@ def stream_answer_objects(
)
yield initial_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
chunk_indices = packet.response
if reference_db_search_docs is not None and dropped_inds:
chunk_indices = drop_llm_indices(
llm_indices=chunk_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_inds,
)
yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response)
else:
yield packet

View File

@ -78,6 +78,9 @@ class SearchRequest(ChunkContext):
skip_rerank: bool | None = None
skip_llm_chunk_filter: bool | None = None
# If this is set, only the highest matching chunk (or merged chunks) is returned
dedupe_docs: bool = False
class Config:
arbitrary_types_allowed = True
@ -115,6 +118,7 @@ class RetrievalDetails(ChunkContext):
# if None, no offset / limit
offset: int | None = None
limit: int | None = None
dedupe_docs: bool = False
class InferenceChunk(BaseChunk):

View File

@ -1,10 +1,39 @@
from collections.abc import Sequence
from typing import TypeVar
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.search.models import SearchDoc
T = TypeVar("T", InferenceSection, InferenceChunk, SearchDoc)
def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]:
    """Drop items whose document_id was already seen, keeping first occurrences.

    Returns a tuple of (kept items in original order, positions in the input
    list that were dropped). The dropped positions let callers remap any
    index-based references (e.g. LLM relevance indices) afterwards.
    """
    kept: list[T] = []
    dropped_positions: list[int] = []
    seen_document_ids: set = set()
    for position, candidate in enumerate(items):
        if candidate.document_id in seen_document_ids:
            dropped_positions.append(position)
            continue
        seen_document_ids.add(candidate.document_id)
        kept.append(candidate)
    return kept, dropped_positions
def drop_llm_indices(
    llm_indices: list[int], search_docs: list[DBSearchDoc], dropped_indices: list[int]
) -> list[int]:
    """Remap LLM-selected document indices after some docs were deduped away.

    `llm_indices` refers to positions in the original (pre-dedupe) list of
    `search_docs`. After removing the positions in `dropped_indices`, the
    surviving docs shift left, so the selections must be renumbered to match
    the deduped list. Selections that point at a dropped position disappear.

    Returns the renumbered indices, in ascending order.
    """
    # Sets give O(1) membership tests; the original list-based `in` checks
    # made this accidentally O(n * m).
    selected = set(llm_indices)
    dropped = set(dropped_indices)
    # One flag per original position: was this doc selected by the LLM?
    llm_bools = [ind in selected for ind in range(len(search_docs))]
    if dropped:
        # Remove flags for dropped positions; remaining flags renumber themselves.
        llm_bools = [val for ind, val in enumerate(llm_bools) if ind not in dropped]
    return [ind for ind, val in enumerate(llm_bools) if val]
def chunks_or_sections_to_search_docs(
chunks: Sequence[InferenceChunk | InferenceSection] | None,
) -> list[SearchDoc]:

View File

@ -202,6 +202,9 @@ class SearchTool(Tool):
chunks_above=self.chunks_above,
chunks_below=self.chunks_below,
full_doc=self.full_doc,
dedupe_docs=self.retrieval_options.dedupe_docs
if self.retrieval_options
else False,
),
user=self.user,
db_session=self.db_session,