mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 12:03:54 +02:00
Add top_documents to APIs (#2469)
* Add top_documents
* Fix test

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
This commit is contained in:
@@ -28,6 +28,7 @@ from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.one_shot_answer.qa_utils import combine_message_thread
|
||||
from danswer.search.models import OptionalSearchSetting
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.search.models import SavedSearchDoc
|
||||
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
@@ -65,18 +66,18 @@ def _translate_doc_response_to_simple_doc(
|
||||
|
||||
def _get_final_context_doc_indices(
|
||||
final_context_docs: list[LlmDoc] | None,
|
||||
simple_search_docs: list[SimpleDoc] | None,
|
||||
top_docs: list[SavedSearchDoc] | None,
|
||||
) -> list[int] | None:
|
||||
"""
|
||||
this function returns a list of indices of the simple search docs
|
||||
that were actually fed to the LLM.
|
||||
"""
|
||||
if final_context_docs is None or simple_search_docs is None:
|
||||
if final_context_docs is None or top_docs is None:
|
||||
return None
|
||||
|
||||
final_context_doc_ids = {doc.document_id for doc in final_context_docs}
|
||||
return [
|
||||
i for i, doc in enumerate(simple_search_docs) if doc.id in final_context_doc_ids
|
||||
i for i, doc in enumerate(top_docs) if doc.document_id in final_context_doc_ids
|
||||
]
|
||||
|
||||
|
||||
@@ -148,6 +149,7 @@ def handle_simplified_chat_message(
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
|
||||
response.top_documents = packet.top_documents
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
@@ -161,7 +163,7 @@ def handle_simplified_chat_message(
|
||||
}
|
||||
|
||||
response.final_context_doc_indices = _get_final_context_doc_indices(
|
||||
final_context_docs, response.simple_search_docs
|
||||
final_context_docs, response.top_documents
|
||||
)
|
||||
|
||||
response.answer = answer
|
||||
@@ -296,6 +298,7 @@ def handle_send_message_simple_with_history(
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
|
||||
response.top_documents = packet.top_documents
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
@@ -311,7 +314,7 @@ def handle_send_message_simple_with_history(
|
||||
}
|
||||
|
||||
response.final_context_doc_indices = _get_final_context_doc_indices(
|
||||
final_context_docs, response.simple_search_docs
|
||||
final_context_docs, response.top_documents
|
||||
)
|
||||
|
||||
response.answer = answer
|
||||
|
@@ -8,6 +8,7 @@ from danswer.search.enums import SearchType
|
||||
from danswer.search.models import ChunkContext
|
||||
from danswer.search.models import RerankingDetails
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.search.models import SavedSearchDoc
|
||||
from ee.danswer.server.manage.models import StandardAnswer
|
||||
|
||||
|
||||
@@ -73,7 +74,11 @@ class ChatBasicResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str | None = None
|
||||
answer_citationless: str | None = None
|
||||
|
||||
# TODO: deprecate `simple_search_docs`
|
||||
simple_search_docs: list[SimpleDoc] | None = None
|
||||
top_documents: list[SavedSearchDoc] | None = None
|
||||
|
||||
error_msg: str | None = None
|
||||
message_id: int | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
|
@@ -51,6 +51,7 @@ def test_send_message_simple_with_history(reset: None) -> None:
|
||||
|
||||
# Check that the top document is the correct document
|
||||
assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id
|
||||
assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id
|
||||
|
||||
# assert that the metadata is correct
|
||||
for doc in cc_pair_1.documents:
|
||||
|
Reference in New Issue
Block a user