diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py
index cd9f21b19b0..39524bb056c 100644
--- a/backend/danswer/chat/process_message.py
+++ b/backend/danswer/chat/process_message.py
@@ -61,7 +61,7 @@ from danswer.search.retrieval.search_runner import inference_sections_from_ids
 from danswer.search.utils import chunks_or_sections_to_search_docs
 from danswer.search.utils import dedupe_documents
 from danswer.search.utils import drop_llm_indices
-from danswer.search.utils import relevant_documents_to_indices
+from danswer.search.utils import relevant_sections_to_indices
 from danswer.server.query_and_chat.models import ChatMessageDetail
 from danswer.server.query_and_chat.models import CreateChatMessageRequest
 from danswer.server.utils import get_json_line
@@ -637,9 +637,9 @@ def stream_chat_message_objects(
                 relevance_sections = packet.response
 
                 if reference_db_search_docs is not None:
-                    llm_indices = relevant_documents_to_indices(
+                    llm_indices = relevant_sections_to_indices(
                         relevance_sections=relevance_sections,
-                        search_docs=[
+                        items=[
                             translate_db_search_doc_to_server_search_doc(doc)
                             for doc in reference_db_search_docs
                         ],
diff --git a/backend/danswer/search/pipeline.py b/backend/danswer/search/pipeline.py
index 5c8067d96f8..117dde9b25e 100644
--- a/backend/danswer/search/pipeline.py
+++ b/backend/danswer/search/pipeline.py
@@ -402,6 +402,6 @@ class SearchPipeline:
     def section_relevance_list(self) -> list[bool]:
         llm_indices = relevant_sections_to_indices(
             relevance_sections=self.section_relevance,
-            inference_sections=self.final_context_sections,
+            items=self.final_context_sections,
         )
         return [ind in llm_indices for ind in range(len(self.final_context_sections))]
diff --git a/backend/danswer/search/utils.py b/backend/danswer/search/utils.py
index 38ca2559ec5..21a95320ef5 100644
--- a/backend/danswer/search/utils.py
+++ b/backend/danswer/search/utils.py
@@ -19,6 +19,14 @@ T = TypeVar(
     SavedSearchDocWithContent,
 )
 
+TSection = TypeVar(
+    "TSection",
+    InferenceSection,
+    SearchDoc,
+    SavedSearchDoc,
+    SavedSearchDocWithContent,
+)
+
 
 def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]:
     seen_ids = set()
@@ -39,30 +47,9 @@ def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]:
 
 
 def relevant_sections_to_indices(
-    relevance_sections: list[SectionRelevancePiece] | None,
-    inference_sections: list[InferenceSection],
+    relevance_sections: list[SectionRelevancePiece] | None, items: list[TSection]
 ) -> list[int]:
-    if relevance_sections is None:
-        return []
-
-    relevant_set = {
-        (chunk.document_id, chunk.chunk_id)
-        for chunk in relevance_sections
-        if chunk.relevant
-    }
-    relevant_indices = [
-        index
-        for index, section in enumerate(inference_sections)
-        if (section.center_chunk.document_id, section.center_chunk.chunk_id)
-        in relevant_set
-    ]
-    return relevant_indices
-
-
-def relevant_documents_to_indices(
-    relevance_sections: list[SectionRelevancePiece] | None, search_docs: list[SearchDoc]
-) -> list[int]:
-    if relevance_sections is None:
+    if not relevance_sections:
         return []
 
     relevant_set = {
@@ -73,8 +60,18 @@
 
     return [
         index
-        for index, section in enumerate(search_docs)
-        if (section.document_id, section.chunk_ind) in relevant_set
+        for index, item in enumerate(items)
+        if (
+            (
+                isinstance(item, InferenceSection)
+                and (item.center_chunk.document_id, item.center_chunk.chunk_id)
+                in relevant_set
+            )
+            or (
+                not isinstance(item, (InferenceSection))
+                and (item.document_id, item.chunk_ind) in relevant_set
+            )
+        )
     ]
diff --git a/backend/ee/danswer/server/query_and_chat/query_backend.py b/backend/ee/danswer/server/query_and_chat/query_backend.py
index 8c0c23286a3..6ef2a121d7a 100644
--- a/backend/ee/danswer/server/query_and_chat/query_backend.py
+++ b/backend/ee/danswer/server/query_and_chat/query_backend.py
@@ -1,5 +1,3 @@
-from typing import cast
-
 from fastapi import APIRouter
 from fastapi import Depends
 from fastapi import HTTPException
@@ -11,9 +9,7 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTA
 from danswer.danswerbot.slack.handlers.handle_standard_answers import (
     oneoff_standard_answers,
 )
-from danswer.db.chat import translate_db_search_doc_to_server_search_doc
 from danswer.db.engine import get_session
-from danswer.db.models import SearchDoc
 from danswer.db.models import User
 from danswer.db.persona import get_persona_by_id
 from danswer.llm.answering.prompts.citations_prompt import (
@@ -31,7 +27,7 @@ from danswer.search.models import SearchRequest
 from danswer.search.pipeline import SearchPipeline
 from danswer.search.utils import dedupe_documents
 from danswer.search.utils import drop_llm_indices
-from danswer.search.utils import relevant_documents_to_indices
+from danswer.search.utils import relevant_sections_to_indices
 from danswer.utils.logger import setup_logger
 from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
 from ee.danswer.server.query_and_chat.models import StandardAnswerRequest
@@ -113,12 +109,8 @@ def handle_search_request(
         if search_request.retrieval_options.dedupe_docs:
             deduped_docs, dropped_inds = dedupe_documents(top_docs)
 
-    llm_indices = relevant_documents_to_indices(
-        relevance_sections=relevance_sections,
-        search_docs=[
-            translate_db_search_doc_to_server_search_doc(cast(SearchDoc, doc))
-            for doc in deduped_docs
-        ],
+    llm_indices = relevant_sections_to_indices(
+        relevance_sections=relevance_sections, items=deduped_docs
     )
 
     if dropped_inds: