Add top_documents to APIs (#2469)

* Add top_documents

* Fix test

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
This commit is contained in:
Chris Weaver
2024-09-16 16:48:33 -07:00
committed by GitHub
parent 8b2ecb4eab
commit 7ba829a585
3 changed files with 14 additions and 5 deletions

View File

@@ -28,6 +28,7 @@ from danswer.natural_language_processing.utils import get_tokenizer
from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
from danswer.search.models import SavedSearchDoc
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
@@ -65,18 +66,18 @@ def _translate_doc_response_to_simple_doc(
def _get_final_context_doc_indices(
final_context_docs: list[LlmDoc] | None,
simple_search_docs: list[SimpleDoc] | None,
top_docs: list[SavedSearchDoc] | None,
) -> list[int] | None:
"""
Return the indices of the search docs that were actually fed to the LLM.

Returns None when either the final context docs or the search docs are
None. Otherwise returns, in list order, the position of every search doc
whose document ID matches the `document_id` of a final context doc.
"""
if final_context_docs is None or simple_search_docs is None:
if final_context_docs is None or top_docs is None:
return None
final_context_doc_ids = {doc.document_id for doc in final_context_docs}
return [
i for i, doc in enumerate(simple_search_docs) if doc.id in final_context_doc_ids
i for i, doc in enumerate(top_docs) if doc.document_id in final_context_doc_ids
]
@@ -148,6 +149,7 @@ def handle_simplified_chat_message(
answer += packet.answer_piece
elif isinstance(packet, QADocsResponse):
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
response.top_documents = packet.top_documents
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
@@ -161,7 +163,7 @@ def handle_simplified_chat_message(
}
response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.simple_search_docs
final_context_docs, response.top_documents
)
response.answer = answer
@@ -296,6 +298,7 @@ def handle_send_message_simple_with_history(
answer += packet.answer_piece
elif isinstance(packet, QADocsResponse):
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
response.top_documents = packet.top_documents
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
@@ -311,7 +314,7 @@ def handle_send_message_simple_with_history(
}
response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.simple_search_docs
final_context_docs, response.top_documents
)
response.answer = answer

View File

@@ -8,6 +8,7 @@ from danswer.search.enums import SearchType
from danswer.search.models import ChunkContext
from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails
from danswer.search.models import SavedSearchDoc
from ee.danswer.server.manage.models import StandardAnswer
@@ -73,7 +74,11 @@ class ChatBasicResponse(BaseModel):
# This is built piece by piece; any of these can be None, as the flow could break
answer: str | None = None
answer_citationless: str | None = None
# TODO: deprecate `simple_search_docs`
simple_search_docs: list[SimpleDoc] | None = None
top_documents: list[SavedSearchDoc] | None = None
error_msg: str | None = None
message_id: int | None = None
llm_selected_doc_indices: list[int] | None = None

View File

@@ -51,6 +51,7 @@ def test_send_message_simple_with_history(reset: None) -> None:
# Check that the top document is the correct document
assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id
assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id
# assert that the metadata is correct
for doc in cc_pair_1.documents: