mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-26 07:50:56 +02:00
Dedupe Flag (#1576)
This commit is contained in:
parent
adcbd354f4
commit
da43bac456
@ -49,6 +49,8 @@ from danswer.llm.utils import get_default_llm_tokenizer
|
|||||||
from danswer.search.enums import OptionalSearchSetting
|
from danswer.search.enums import OptionalSearchSetting
|
||||||
from danswer.search.retrieval.search_runner import inference_documents_from_ids
|
from danswer.search.retrieval.search_runner import inference_documents_from_ids
|
||||||
from danswer.search.utils import chunks_or_sections_to_search_docs
|
from danswer.search.utils import chunks_or_sections_to_search_docs
|
||||||
|
from danswer.search.utils import dedupe_documents
|
||||||
|
from danswer.search.utils import drop_llm_indices
|
||||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||||
from danswer.server.utils import get_json_line
|
from danswer.server.utils import get_json_line
|
||||||
@ -95,14 +97,20 @@ def _handle_search_tool_response_summary(
|
|||||||
packet: ToolResponse,
|
packet: ToolResponse,
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
selected_search_docs: list[DbSearchDoc] | None,
|
selected_search_docs: list[DbSearchDoc] | None,
|
||||||
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
|
dedupe_docs: bool = False,
|
||||||
|
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
|
||||||
response_sumary = cast(SearchResponseSummary, packet.response)
|
response_sumary = cast(SearchResponseSummary, packet.response)
|
||||||
|
|
||||||
|
dropped_inds = None
|
||||||
if not selected_search_docs:
|
if not selected_search_docs:
|
||||||
top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections)
|
top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections)
|
||||||
|
|
||||||
|
if dedupe_docs:
|
||||||
|
deduped_docs, dropped_inds = dedupe_documents(top_docs)
|
||||||
|
|
||||||
reference_db_search_docs = [
|
reference_db_search_docs = [
|
||||||
create_db_search_doc(server_search_doc=top_doc, db_session=db_session)
|
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||||
for top_doc in top_docs
|
for doc in deduped_docs
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
reference_db_search_docs = selected_search_docs
|
reference_db_search_docs = selected_search_docs
|
||||||
@ -122,6 +130,7 @@ def _handle_search_tool_response_summary(
|
|||||||
recency_bias_multiplier=response_sumary.recency_bias_multiplier,
|
recency_bias_multiplier=response_sumary.recency_bias_multiplier,
|
||||||
),
|
),
|
||||||
reference_db_search_docs,
|
reference_db_search_docs,
|
||||||
|
dropped_inds,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -460,19 +469,36 @@ def stream_chat_message_objects(
|
|||||||
reference_db_search_docs = None
|
reference_db_search_docs = None
|
||||||
qa_docs_response = None
|
qa_docs_response = None
|
||||||
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
|
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
|
||||||
|
dropped_indices = None
|
||||||
for packet in answer.processed_streamed_output:
|
for packet in answer.processed_streamed_output:
|
||||||
if isinstance(packet, ToolResponse):
|
if isinstance(packet, ToolResponse):
|
||||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||||
(
|
(
|
||||||
qa_docs_response,
|
qa_docs_response,
|
||||||
reference_db_search_docs,
|
reference_db_search_docs,
|
||||||
|
dropped_indices,
|
||||||
) = _handle_search_tool_response_summary(
|
) = _handle_search_tool_response_summary(
|
||||||
packet, db_session, selected_db_search_docs
|
packet=packet,
|
||||||
|
db_session=db_session,
|
||||||
|
selected_search_docs=selected_db_search_docs,
|
||||||
|
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||||
|
dedupe_docs=retrieval_options.dedupe_docs
|
||||||
|
if retrieval_options
|
||||||
|
else False,
|
||||||
)
|
)
|
||||||
yield qa_docs_response
|
yield qa_docs_response
|
||||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||||
|
chunk_indices = packet.response
|
||||||
|
|
||||||
|
if reference_db_search_docs is not None and dropped_indices:
|
||||||
|
chunk_indices = drop_llm_indices(
|
||||||
|
llm_indices=chunk_indices,
|
||||||
|
search_docs=reference_db_search_docs,
|
||||||
|
dropped_indices=dropped_indices,
|
||||||
|
)
|
||||||
|
|
||||||
yield LLMRelevanceFilterResponse(
|
yield LLMRelevanceFilterResponse(
|
||||||
relevant_chunk_indices=packet.response
|
relevant_chunk_indices=chunk_indices
|
||||||
)
|
)
|
||||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||||
img_generation_response = cast(
|
img_generation_response = cast(
|
||||||
|
@ -39,6 +39,8 @@ from danswer.one_shot_answer.qa_utils import combine_message_thread
|
|||||||
from danswer.search.models import RerankMetricsContainer
|
from danswer.search.models import RerankMetricsContainer
|
||||||
from danswer.search.models import RetrievalMetricsContainer
|
from danswer.search.models import RetrievalMetricsContainer
|
||||||
from danswer.search.utils import chunks_or_sections_to_search_docs
|
from danswer.search.utils import chunks_or_sections_to_search_docs
|
||||||
|
from danswer.search.utils import dedupe_documents
|
||||||
|
from danswer.search.utils import drop_llm_indices
|
||||||
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
|
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
|
||||||
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
|
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
|
||||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||||
@ -195,6 +197,7 @@ def stream_answer_objects(
|
|||||||
skip_explicit_tool_calling=True,
|
skip_explicit_tool_calling=True,
|
||||||
)
|
)
|
||||||
# won't be any ImageGenerationDisplay responses since that tool is never passed in
|
# won't be any ImageGenerationDisplay responses since that tool is never passed in
|
||||||
|
dropped_inds: list[int] = []
|
||||||
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
|
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
|
||||||
# for one-shot flow, don't currently do anything with these
|
# for one-shot flow, don't currently do anything with these
|
||||||
if isinstance(packet, ToolResponse):
|
if isinstance(packet, ToolResponse):
|
||||||
@ -205,11 +208,15 @@ def stream_answer_objects(
|
|||||||
search_response_summary.top_sections
|
search_response_summary.top_sections
|
||||||
)
|
)
|
||||||
|
|
||||||
|
deduped_docs = top_docs
|
||||||
|
if query_req.retrieval_options.dedupe_docs:
|
||||||
|
deduped_docs, dropped_inds = dedupe_documents(top_docs)
|
||||||
|
|
||||||
reference_db_search_docs = [
|
reference_db_search_docs = [
|
||||||
create_db_search_doc(
|
create_db_search_doc(
|
||||||
server_search_doc=top_doc, db_session=db_session
|
server_search_doc=top_doc, db_session=db_session
|
||||||
)
|
)
|
||||||
for top_doc in top_docs
|
for top_doc in deduped_docs
|
||||||
]
|
]
|
||||||
|
|
||||||
response_docs = [
|
response_docs = [
|
||||||
@ -228,6 +235,15 @@ def stream_answer_objects(
|
|||||||
)
|
)
|
||||||
yield initial_response
|
yield initial_response
|
||||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||||
|
chunk_indices = packet.response
|
||||||
|
|
||||||
|
if reference_db_search_docs is not None and dropped_inds:
|
||||||
|
chunk_indices = drop_llm_indices(
|
||||||
|
llm_indices=chunk_indices,
|
||||||
|
search_docs=reference_db_search_docs,
|
||||||
|
dropped_indices=dropped_inds,
|
||||||
|
)
|
||||||
|
|
||||||
yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response)
|
yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response)
|
||||||
else:
|
else:
|
||||||
yield packet
|
yield packet
|
||||||
|
@ -78,6 +78,9 @@ class SearchRequest(ChunkContext):
|
|||||||
skip_rerank: bool | None = None
|
skip_rerank: bool | None = None
|
||||||
skip_llm_chunk_filter: bool | None = None
|
skip_llm_chunk_filter: bool | None = None
|
||||||
|
|
||||||
|
# If this is set, only the highest matching chunk (or merged chunks) is returned
|
||||||
|
dedupe_docs: bool = False
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@ -115,6 +118,7 @@ class RetrievalDetails(ChunkContext):
|
|||||||
# if None, no offset / limit
|
# if None, no offset / limit
|
||||||
offset: int | None = None
|
offset: int | None = None
|
||||||
limit: int | None = None
|
limit: int | None = None
|
||||||
|
dedupe_docs: bool = False
|
||||||
|
|
||||||
|
|
||||||
class InferenceChunk(BaseChunk):
|
class InferenceChunk(BaseChunk):
|
||||||
|
@ -1,10 +1,39 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
from danswer.db.models import SearchDoc as DBSearchDoc
|
||||||
from danswer.search.models import InferenceChunk
|
from danswer.search.models import InferenceChunk
|
||||||
from danswer.search.models import InferenceSection
|
from danswer.search.models import InferenceSection
|
||||||
from danswer.search.models import SearchDoc
|
from danswer.search.models import SearchDoc
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", InferenceSection, InferenceChunk, SearchDoc)
|
||||||
|
|
||||||
|
|
||||||
|
def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]:
    """Drop items whose document_id was already seen earlier in the list.

    Order of the surviving items is preserved (first occurrence wins).

    Returns:
        A tuple of (deduplicated items, positional indices of the items
        that were removed as duplicates).
    """
    kept: list[T] = []
    dropped_indices: list[int] = []
    observed_doc_ids: set = set()
    for position, candidate in enumerate(items):
        doc_id = candidate.document_id
        if doc_id in observed_doc_ids:
            # Duplicate of an earlier item — record where it sat so callers
            # can remap any index-based references (e.g. LLM selections).
            dropped_indices.append(position)
        else:
            observed_doc_ids.add(doc_id)
            kept.append(candidate)
    return kept, dropped_indices
|
||||||
|
|
||||||
|
|
||||||
|
def drop_llm_indices(
    llm_indices: list[int], search_docs: list[DBSearchDoc], dropped_indices: list[int]
) -> list[int]:
    """Remap LLM-selected doc indices after deduplication removed some docs.

    `llm_indices` are positions into the pre-dedupe `search_docs` list; once
    the entries at `dropped_indices` are removed, every surviving doc shifts
    left. This re-expresses the selection against the post-drop ordering.

    Returns the remapped indices, sorted ascending and deduplicated.
    """
    # Sets give O(1) membership tests; the original list scans made this
    # O(len(search_docs) * len(llm_indices)).
    llm_index_set = set(llm_indices)
    llm_bools = [i in llm_index_set for i in range(len(search_docs))]
    if dropped_indices:
        dropped_set = set(dropped_indices)
        llm_bools = [
            val for ind, val in enumerate(llm_bools) if ind not in dropped_set
        ]
    return [i for i, val in enumerate(llm_bools) if val]
|
||||||
|
|
||||||
|
|
||||||
def chunks_or_sections_to_search_docs(
|
def chunks_or_sections_to_search_docs(
|
||||||
chunks: Sequence[InferenceChunk | InferenceSection] | None,
|
chunks: Sequence[InferenceChunk | InferenceSection] | None,
|
||||||
) -> list[SearchDoc]:
|
) -> list[SearchDoc]:
|
||||||
|
@ -202,6 +202,9 @@ class SearchTool(Tool):
|
|||||||
chunks_above=self.chunks_above,
|
chunks_above=self.chunks_above,
|
||||||
chunks_below=self.chunks_below,
|
chunks_below=self.chunks_below,
|
||||||
full_doc=self.full_doc,
|
full_doc=self.full_doc,
|
||||||
|
dedupe_docs=self.retrieval_options.dedupe_docs
|
||||||
|
if self.retrieval_options
|
||||||
|
else False,
|
||||||
),
|
),
|
||||||
user=self.user,
|
user=self.user,
|
||||||
db_session=self.db_session,
|
db_session=self.db_session,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user