Mirror of https://github.com/danswer-ai/danswer.git
Dedupe Flag (#1576)

parent adcbd354f4
commit da43bac456
@@ -49,6 +49,8 @@ from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.search.enums import OptionalSearchSetting
 from danswer.search.retrieval.search_runner import inference_documents_from_ids
 from danswer.search.utils import chunks_or_sections_to_search_docs
+from danswer.search.utils import dedupe_documents
+from danswer.search.utils import drop_llm_indices
 from danswer.server.query_and_chat.models import ChatMessageDetail
 from danswer.server.query_and_chat.models import CreateChatMessageRequest
 from danswer.server.utils import get_json_line
@@ -95,14 +97,20 @@ def _handle_search_tool_response_summary(
     packet: ToolResponse,
     db_session: Session,
     selected_search_docs: list[DbSearchDoc] | None,
-) -> tuple[QADocsResponse, list[DbSearchDoc]]:
+    dedupe_docs: bool = False,
+) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
     response_sumary = cast(SearchResponseSummary, packet.response)
 
+    dropped_inds = None
     if not selected_search_docs:
         top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections)
+
+        if dedupe_docs:
+            deduped_docs, dropped_inds = dedupe_documents(top_docs)
+
         reference_db_search_docs = [
-            create_db_search_doc(server_search_doc=top_doc, db_session=db_session)
-            for top_doc in top_docs
+            create_db_search_doc(server_search_doc=doc, db_session=db_session)
+            for doc in deduped_docs
         ]
     else:
         reference_db_search_docs = selected_search_docs
@@ -122,6 +130,7 @@ def _handle_search_tool_response_summary(
             recency_bias_multiplier=response_sumary.recency_bias_multiplier,
         ),
         reference_db_search_docs,
+        dropped_inds,
     )
 
 
@@ -460,19 +469,36 @@ def stream_chat_message_objects(
     reference_db_search_docs = None
     qa_docs_response = None
     ai_message_files = None  # any files to associate with the AI message e.g. dall-e generated images
+    dropped_indices = None
     for packet in answer.processed_streamed_output:
         if isinstance(packet, ToolResponse):
             if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
                 (
                     qa_docs_response,
                     reference_db_search_docs,
+                    dropped_indices,
                 ) = _handle_search_tool_response_summary(
-                    packet, db_session, selected_db_search_docs
+                    packet=packet,
+                    db_session=db_session,
+                    selected_search_docs=selected_db_search_docs,
+                    # Deduping happens at the last step to avoid harming quality by dropping content early on
+                    dedupe_docs=retrieval_options.dedupe_docs
+                    if retrieval_options
+                    else False,
                 )
                 yield qa_docs_response
             elif packet.id == SECTION_RELEVANCE_LIST_ID:
+                chunk_indices = packet.response
+
+                if reference_db_search_docs is not None and dropped_indices:
+                    chunk_indices = drop_llm_indices(
+                        llm_indices=chunk_indices,
+                        search_docs=reference_db_search_docs,
+                        dropped_indices=dropped_indices,
+                    )
+
                 yield LLMRelevanceFilterResponse(
-                    relevant_chunk_indices=packet.response
+                    relevant_chunk_indices=chunk_indices
                 )
             elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
                 img_generation_response = cast(
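The inline comment in the hunk above captures the design: the LLM relevance filter sees the full, un-deduped document list, and duplicates are only collapsed at the final step, so any positions removed by the dedupe must afterwards be subtracted from the LLM's chunk indices. A toy trace of that bookkeeping (a minimal sketch: FakeDoc and the IDs are illustrative stand-ins; the two helpers are the ones added in the danswer.search.utils hunk further down):

from dataclasses import dataclass

from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices


@dataclass
class FakeDoc:
    # Illustrative stand-in -- dedupe_documents only reads document_id.
    document_id: str


# Retrieval returned four sections; positions 0 and 2 share a document.
top_docs = [FakeDoc("a"), FakeDoc("b"), FakeDoc("a"), FakeDoc("c")]

deduped_docs, dropped_inds = dedupe_documents(top_docs)
print([d.document_id for d in deduped_docs])  # ['a', 'b', 'c']
print(dropped_inds)  # [2] -- the duplicate of 'a' was dropped

# The LLM marked positions 0 and 2 as relevant; position 2 was the dropped
# duplicate, so only index 0 survives the remapping.
print(
    drop_llm_indices(
        llm_indices=[0, 2],
        search_docs=deduped_docs,
        dropped_indices=dropped_inds,
    )
)  # [0]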
@@ -39,6 +39,8 @@ from danswer.one_shot_answer.qa_utils import combine_message_thread
 from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
 from danswer.search.utils import chunks_or_sections_to_search_docs
+from danswer.search.utils import dedupe_documents
+from danswer.search.utils import drop_llm_indices
 from danswer.secondary_llm_flows.answer_validation import get_answer_validity
 from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
 from danswer.server.query_and_chat.models import ChatMessageDetail
@@ -195,6 +197,7 @@ def stream_answer_objects(
         skip_explicit_tool_calling=True,
     )
     # won't be any ImageGenerationDisplay responses since that tool is never passed in
+    dropped_inds: list[int] = []
     for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
         # for one-shot flow, don't currently do anything with these
         if isinstance(packet, ToolResponse):
@@ -205,11 +208,15 @@ def stream_answer_objects(
                     search_response_summary.top_sections
                 )
 
+                deduped_docs = top_docs
+                if query_req.retrieval_options.dedupe_docs:
+                    deduped_docs, dropped_inds = dedupe_documents(top_docs)
+
                 reference_db_search_docs = [
                     create_db_search_doc(
                         server_search_doc=top_doc, db_session=db_session
                     )
-                    for top_doc in top_docs
+                    for top_doc in deduped_docs
                 ]
 
                 response_docs = [
@@ -228,6 +235,15 @@ def stream_answer_objects(
                 )
                 yield initial_response
             elif packet.id == SECTION_RELEVANCE_LIST_ID:
+                chunk_indices = packet.response
+
+                if reference_db_search_docs is not None and dropped_inds:
+                    chunk_indices = drop_llm_indices(
+                        llm_indices=chunk_indices,
+                        search_docs=reference_db_search_docs,
+                        dropped_indices=dropped_inds,
+                    )
+
                 yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response)
             else:
                 yield packet
@@ -78,6 +78,9 @@ class SearchRequest(ChunkContext):
     skip_rerank: bool | None = None
     skip_llm_chunk_filter: bool | None = None
 
+    # If this is set, only the highest matching chunk (or merged chunks) is returned
+    dedupe_docs: bool = False
+
     class Config:
         arbitrary_types_allowed = True
 
@@ -115,6 +118,7 @@ class RetrievalDetails(ChunkContext):
     # if None, no offset / limit
     offset: int | None = None
     limit: int | None = None
+    dedupe_docs: bool = False
 
 
 class InferenceChunk(BaseChunk):
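Both request models gain the same opt-in flag, defaulting to off so existing behavior is unchanged. A minimal sketch of enabling it, assuming the remaining RetrievalDetails fields carry defaults the way offset and limit do above:

from danswer.search.models import RetrievalDetails

# Opt in to document-level deduping for a chat or one-shot request; the flag
# rides along to SearchTool and the streaming handlers shown above.
retrieval_options = RetrievalDetails(dedupe_docs=True)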
@@ -1,10 +1,39 @@
 from collections.abc import Sequence
+from typing import TypeVar
 
+from danswer.db.models import SearchDoc as DBSearchDoc
 from danswer.search.models import InferenceChunk
 from danswer.search.models import InferenceSection
 from danswer.search.models import SearchDoc
 
 
+T = TypeVar("T", InferenceSection, InferenceChunk, SearchDoc)
+
+
+def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]:
+    seen_ids = set()
+    deduped_items = []
+    dropped_indices = []
+    for index, item in enumerate(items):
+        if item.document_id not in seen_ids:
+            seen_ids.add(item.document_id)
+            deduped_items.append(item)
+        else:
+            dropped_indices.append(index)
+    return deduped_items, dropped_indices
+
+
+def drop_llm_indices(
+    llm_indices: list[int], search_docs: list[DBSearchDoc], dropped_indices: list[int]
+) -> list[int]:
+    llm_bools = [True if i in llm_indices else False for i in range(len(search_docs))]
+    if dropped_indices:
+        llm_bools = [
+            val for ind, val in enumerate(llm_bools) if ind not in dropped_indices
+        ]
+    return [i for i, val in enumerate(llm_bools) if val]
+
+
 def chunks_or_sections_to_search_docs(
     chunks: Sequence[InferenceChunk | InferenceSection] | None,
 ) -> list[SearchDoc]:
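dedupe_documents keeps the first item per document_id and records the positions it dropped; drop_llm_indices rebuilds a boolean mask over whatever is passed as search_docs, removes the dropped positions, and re-enumerates the surviving indices. Two spot-checks (a sketch: the namedtuple Doc is an illustrative stand-in, since only document_id is read at runtime):

from collections import namedtuple

from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices

Doc = namedtuple("Doc", "document_id")

docs = [Doc("x"), Doc("x"), Doc("y")]
deduped, dropped = dedupe_documents(docs)
assert [d.document_id for d in deduped] == ["x", "y"]
assert dropped == [1]

# Mask over the three original positions: [True, False, True]; dropping
# position 1 leaves [True, True], which re-enumerates to [0, 1].
assert drop_llm_indices(llm_indices=[0, 2], search_docs=docs, dropped_indices=dropped) == [0, 1]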
@@ -202,6 +202,9 @@ class SearchTool(Tool):
                 chunks_above=self.chunks_above,
                 chunks_below=self.chunks_below,
                 full_doc=self.full_doc,
+                dedupe_docs=self.retrieval_options.dedupe_docs
+                if self.retrieval_options
+                else False,
             ),
             user=self.user,
             db_session=self.db_session,
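The None-guard on retrieval_options appears both here and in stream_chat_message_objects; pulled out as a standalone expression (resolve_dedupe_flag is a hypothetical helper named for illustration):

from danswer.search.models import RetrievalDetails


def resolve_dedupe_flag(retrieval_options: RetrievalDetails | None) -> bool:
    # Absent retrieval options mean "no deduping" rather than an AttributeError.
    return retrieval_options.dedupe_docs if retrieval_options else False


assert resolve_dedupe_flag(None) is False
assert resolve_dedupe_flag(RetrievalDetails(dedupe_docs=True)) is True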