Add top_documents to APIs (#2469)

* Add top_documents

* Fix test

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
This commit is contained in:
Chris Weaver
2024-09-16 16:48:33 -07:00
committed by GitHub
parent 8b2ecb4eab
commit 7ba829a585
3 changed files with 14 additions and 5 deletions

View File

@@ -28,6 +28,7 @@ from danswer.natural_language_processing.utils import get_tokenizer
from danswer.one_shot_answer.qa_utils import combine_message_thread from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.search.models import OptionalSearchSetting from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails from danswer.search.models import RetrievalDetails
from danswer.search.models import SavedSearchDoc
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest from danswer.server.query_and_chat.models import CreateChatMessageRequest
@@ -65,18 +66,18 @@ def _translate_doc_response_to_simple_doc(
def _get_final_context_doc_indices(
    final_context_docs: list[LlmDoc] | None,
    top_docs: list[SavedSearchDoc] | None,
) -> list[int] | None:
    """
    Return the indices, within `top_docs`, of the documents that were
    actually fed to the LLM.

    Documents are matched on `document_id`. The indices are positions in
    `top_docs` (not in the deprecated simple search docs list) and are
    returned in ascending order; every entry of `top_docs` whose
    `document_id` appears in `final_context_docs` contributes its index.

    Returns None when either input is None, i.e. when that piece of the
    response was never produced.
    """
    if final_context_docs is None or top_docs is None:
        return None

    # Build the id set once so each membership test below is O(1).
    final_context_doc_ids = {doc.document_id for doc in final_context_docs}
    return [
        i for i, doc in enumerate(top_docs) if doc.document_id in final_context_doc_ids
    ]
@@ -148,6 +149,7 @@ def handle_simplified_chat_message(
answer += packet.answer_piece answer += packet.answer_piece
elif isinstance(packet, QADocsResponse): elif isinstance(packet, QADocsResponse):
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet) response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
response.top_documents = packet.top_documents
elif isinstance(packet, StreamingError): elif isinstance(packet, StreamingError):
response.error_msg = packet.error response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail): elif isinstance(packet, ChatMessageDetail):
@@ -161,7 +163,7 @@ def handle_simplified_chat_message(
} }
response.final_context_doc_indices = _get_final_context_doc_indices( response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.simple_search_docs final_context_docs, response.top_documents
) )
response.answer = answer response.answer = answer
@@ -296,6 +298,7 @@ def handle_send_message_simple_with_history(
answer += packet.answer_piece answer += packet.answer_piece
elif isinstance(packet, QADocsResponse): elif isinstance(packet, QADocsResponse):
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet) response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
response.top_documents = packet.top_documents
elif isinstance(packet, StreamingError): elif isinstance(packet, StreamingError):
response.error_msg = packet.error response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail): elif isinstance(packet, ChatMessageDetail):
@@ -311,7 +314,7 @@ def handle_send_message_simple_with_history(
} }
response.final_context_doc_indices = _get_final_context_doc_indices( response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.simple_search_docs final_context_docs, response.top_documents
) )
response.answer = answer response.answer = answer

View File

@@ -8,6 +8,7 @@ from danswer.search.enums import SearchType
from danswer.search.models import ChunkContext from danswer.search.models import ChunkContext
from danswer.search.models import RerankingDetails from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails from danswer.search.models import RetrievalDetails
from danswer.search.models import SavedSearchDoc
from ee.danswer.server.manage.models import StandardAnswer from ee.danswer.server.manage.models import StandardAnswer
@@ -73,7 +74,11 @@ class ChatBasicResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break # This is built piece by piece, any of these can be None as the flow could break
answer: str | None = None answer: str | None = None
answer_citationless: str | None = None answer_citationless: str | None = None
# TODO: deprecate `simple_search_docs`
simple_search_docs: list[SimpleDoc] | None = None simple_search_docs: list[SimpleDoc] | None = None
top_documents: list[SavedSearchDoc] | None = None
error_msg: str | None = None error_msg: str | None = None
message_id: int | None = None message_id: int | None = None
llm_selected_doc_indices: list[int] | None = None llm_selected_doc_indices: list[int] | None = None

View File

@@ -51,6 +51,7 @@ def test_send_message_simple_with_history(reset: None) -> None:
# Check that the top document is the correct document # Check that the top document is the correct document
assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id
assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id
# assert that the metadata is correct # assert that the metadata is correct
for doc in cc_pair_1.documents: for doc in cc_pair_1.documents: