diff --git a/backend/ee/danswer/server/query_and_chat/chat_backend.py b/backend/ee/danswer/server/query_and_chat/chat_backend.py index 555619823250..1e1639425332 100644 --- a/backend/ee/danswer/server/query_and_chat/chat_backend.py +++ b/backend/ee/danswer/server/query_and_chat/chat_backend.py @@ -28,6 +28,7 @@ from danswer.natural_language_processing.utils import get_tokenizer from danswer.one_shot_answer.qa_utils import combine_message_thread from danswer.search.models import OptionalSearchSetting from danswer.search.models import RetrievalDetails +from danswer.search.models import SavedSearchDoc from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase from danswer.server.query_and_chat.models import ChatMessageDetail from danswer.server.query_and_chat.models import CreateChatMessageRequest @@ -65,18 +66,18 @@ def _translate_doc_response_to_simple_doc( def _get_final_context_doc_indices( final_context_docs: list[LlmDoc] | None, - simple_search_docs: list[SimpleDoc] | None, + top_docs: list[SavedSearchDoc] | None, ) -> list[int] | None: """ this function returns a list of indices of the simple search docs that were actually fed to the LLM. """ - if final_context_docs is None or simple_search_docs is None: + if final_context_docs is None or top_docs is None: return None final_context_doc_ids = {doc.document_id for doc in final_context_docs} return [ - i for i, doc in enumerate(simple_search_docs) if doc.id in final_context_doc_ids + i for i, doc in enumerate(top_docs) if doc.document_id in final_context_doc_ids ] @@ -148,6 +149,7 @@ def handle_simplified_chat_message( answer += packet.answer_piece elif isinstance(packet, QADocsResponse): response.simple_search_docs = _translate_doc_response_to_simple_doc(packet) + response.top_documents = packet.top_documents elif isinstance(packet, StreamingError): response.error_msg = packet.error elif isinstance(packet, ChatMessageDetail): @@ -161,7 +163,7 @@ def handle_simplified_chat_message( } response.final_context_doc_indices = _get_final_context_doc_indices( - final_context_docs, response.simple_search_docs + final_context_docs, response.top_documents ) response.answer = answer @@ -296,6 +298,7 @@ def handle_send_message_simple_with_history( answer += packet.answer_piece elif isinstance(packet, QADocsResponse): response.simple_search_docs = _translate_doc_response_to_simple_doc(packet) + response.top_documents = packet.top_documents elif isinstance(packet, StreamingError): response.error_msg = packet.error elif isinstance(packet, ChatMessageDetail): @@ -311,7 +314,7 @@ def handle_send_message_simple_with_history( } response.final_context_doc_indices = _get_final_context_doc_indices( - final_context_docs, response.simple_search_docs + final_context_docs, response.top_documents ) response.answer = answer diff --git a/backend/ee/danswer/server/query_and_chat/models.py b/backend/ee/danswer/server/query_and_chat/models.py index cc66c0efab91..be1cd3c6ef6e 100644 --- a/backend/ee/danswer/server/query_and_chat/models.py +++ b/backend/ee/danswer/server/query_and_chat/models.py @@ -8,6 +8,7 @@ from danswer.search.enums import SearchType from danswer.search.models import ChunkContext from danswer.search.models import RerankingDetails from danswer.search.models import RetrievalDetails +from danswer.search.models import SavedSearchDoc from ee.danswer.server.manage.models import StandardAnswer @@ -73,7 +74,11 @@ class ChatBasicResponse(BaseModel): # This is built piece by piece, any of these can be None as the flow could break answer: str | None = None answer_citationless: str | None = None + + # TODO: deprecate `simple_search_docs` simple_search_docs: list[SimpleDoc] | None = None + top_documents: list[SavedSearchDoc] | None = None + error_msg: str | None = None message_id: int | None = None llm_selected_doc_indices: list[int] | None = None diff --git a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py index 981a9cbd026a..d4edcc583aae 100644 --- a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py +++ b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py @@ -51,6 +51,7 @@ def test_send_message_simple_with_history(reset: None) -> None: # Check that the top document is the correct document assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id + assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id # assert that the metadata is correct for doc in cc_pair_1.documents: