Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-25 19:37:29 +02:00)
Add top_documents to APIs (#2469)
* Add top_documents
* Fix test

Co-authored-by: hagen-danswer <hagen@danswer.ai>
@@ -28,6 +28,7 @@ from danswer.natural_language_processing.utils import get_tokenizer
 from danswer.one_shot_answer.qa_utils import combine_message_thread
 from danswer.search.models import OptionalSearchSetting
 from danswer.search.models import RetrievalDetails
+from danswer.search.models import SavedSearchDoc
 from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
 from danswer.server.query_and_chat.models import ChatMessageDetail
 from danswer.server.query_and_chat.models import CreateChatMessageRequest
@@ -65,18 +66,18 @@ def _translate_doc_response_to_simple_doc(

 def _get_final_context_doc_indices(
     final_context_docs: list[LlmDoc] | None,
-    simple_search_docs: list[SimpleDoc] | None,
+    top_docs: list[SavedSearchDoc] | None,
 ) -> list[int] | None:
     """
     this function returns a list of indices of the simple search docs
     that were actually fed to the LLM.
     """
-    if final_context_docs is None or simple_search_docs is None:
+    if final_context_docs is None or top_docs is None:
         return None

     final_context_doc_ids = {doc.document_id for doc in final_context_docs}
     return [
-        i for i, doc in enumerate(simple_search_docs) if doc.id in final_context_doc_ids
+        i for i, doc in enumerate(top_docs) if doc.document_id in final_context_doc_ids
     ]

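As a worked illustration of the helper changed above: it maps the documents that actually reached the LLM back to their positions in the retrieved list, now keyed on `document_id` of `SavedSearchDoc` instead of `id` of `SimpleDoc`. The sketch below is self-contained and uses a stand-in `Doc` dataclass rather than the real `LlmDoc`/`SavedSearchDoc` models; only the `document_id` attribute matters for the logic.

from dataclasses import dataclass


@dataclass
class Doc:
    document_id: str


def get_final_context_doc_indices(
    final_context_docs: list[Doc] | None,
    top_docs: list[Doc] | None,
) -> list[int] | None:
    # Same mapping as the helper above: positions (in the retrieved list) of
    # the documents that were actually fed to the LLM.
    if final_context_docs is None or top_docs is None:
        return None
    final_context_doc_ids = {doc.document_id for doc in final_context_docs}
    return [
        i for i, doc in enumerate(top_docs) if doc.document_id in final_context_doc_ids
    ]


# Example: of three retrieved docs, only "a" and "c" reached the LLM context.
retrieved = [Doc("a"), Doc("b"), Doc("c")]
fed_to_llm = [Doc("c"), Doc("a")]
assert get_final_context_doc_indices(fed_to_llm, retrieved) == [0, 2]
assert get_final_context_doc_indices(None, retrieved) is None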
@@ -148,6 +149,7 @@ def handle_simplified_chat_message(
             answer += packet.answer_piece
         elif isinstance(packet, QADocsResponse):
             response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
+            response.top_documents = packet.top_documents
         elif isinstance(packet, StreamingError):
             response.error_msg = packet.error
         elif isinstance(packet, ChatMessageDetail):
@@ -161,7 +163,7 @@ def handle_simplified_chat_message(
     }

     response.final_context_doc_indices = _get_final_context_doc_indices(
-        final_context_docs, response.simple_search_docs
+        final_context_docs, response.top_documents
     )

     response.answer = answer
@@ -296,6 +298,7 @@ def handle_send_message_simple_with_history(
             answer += packet.answer_piece
         elif isinstance(packet, QADocsResponse):
             response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
+            response.top_documents = packet.top_documents
         elif isinstance(packet, StreamingError):
             response.error_msg = packet.error
         elif isinstance(packet, ChatMessageDetail):
@@ -311,7 +314,7 @@ def handle_send_message_simple_with_history(
     }

     response.final_context_doc_indices = _get_final_context_doc_indices(
-        final_context_docs, response.simple_search_docs
+        final_context_docs, response.top_documents
    )

     response.answer = answer
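Tying the four hunks above together: both handlers consume a stream of packets, the `QADocsResponse` packet now also populates `top_documents` alongside the older `simple_search_docs`, and the index helper is then fed `top_documents`. The following is a simplified, self-contained sketch of that flow using stand-in dataclasses; the real handlers process more packet types and use the actual Danswer models.

from dataclasses import dataclass


@dataclass
class FakeSearchDoc:                 # stand-in for SavedSearchDoc
    document_id: str


@dataclass
class FakeQADocsResponse:            # stand-in for QADocsResponse
    top_documents: list[FakeSearchDoc]


@dataclass
class FakeAnswerPiece:               # stand-in for the streamed answer packets
    answer_piece: str


@dataclass
class FakeBasicResponse:             # stand-in for ChatBasicResponse
    answer: str | None = None
    top_documents: list[FakeSearchDoc] | None = None
    final_context_doc_indices: list[int] | None = None


def handle_packets(packets, final_context_docs) -> FakeBasicResponse:
    response = FakeBasicResponse()
    answer = ""
    for packet in packets:
        if isinstance(packet, FakeAnswerPiece):
            answer += packet.answer_piece
        elif isinstance(packet, FakeQADocsResponse):
            # mirrors the added assignment: retrieved docs are exposed as-is
            response.top_documents = packet.top_documents
    if response.top_documents is not None:
        # mirrors the updated helper call: indices computed against top_documents
        fed_ids = {d.document_id for d in final_context_docs}
        response.final_context_doc_indices = [
            i for i, d in enumerate(response.top_documents) if d.document_id in fed_ids
        ]
    response.answer = answer
    return response


packets = [FakeQADocsResponse([FakeSearchDoc("a"), FakeSearchDoc("b")]), FakeAnswerPiece("Hi")]
resp = handle_packets(packets, final_context_docs=[FakeSearchDoc("b")])
assert resp.answer == "Hi" and resp.final_context_doc_indices == [1]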
@@ -8,6 +8,7 @@ from danswer.search.enums import SearchType
 from danswer.search.models import ChunkContext
 from danswer.search.models import RerankingDetails
 from danswer.search.models import RetrievalDetails
+from danswer.search.models import SavedSearchDoc
 from ee.danswer.server.manage.models import StandardAnswer

@@ -73,7 +74,11 @@ class ChatBasicResponse(BaseModel):
     # This is built piece by piece, any of these can be None as the flow could break
     answer: str | None = None
     answer_citationless: str | None = None

+    # TODO: deprecate `simple_search_docs`
     simple_search_docs: list[SimpleDoc] | None = None
+    top_documents: list[SavedSearchDoc] | None = None

     error_msg: str | None = None
     message_id: int | None = None
     llm_selected_doc_indices: list[int] | None = None
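The model change above adds `top_documents` next to the soon-to-be-deprecated `simple_search_docs` on `ChatBasicResponse`. As a hypothetical example of how a client of the simple chat endpoints might consume the serialized response after this change (the payload values below are invented; only the field names come from this diff):

# Hypothetical serialized ChatBasicResponse; values are invented for illustration.
response_json = {
    "answer": "Danswer is an open-source question answering tool.",
    "simple_search_docs": [{"id": "doc-1"}],          # slated for deprecation
    "top_documents": [{"document_id": "doc-1"}],      # new SavedSearchDoc entries
    "final_context_doc_indices": [0],
    "error_msg": None,
}

# Prefer the new field, but fall back when talking to a server that predates it.
docs = response_json.get("top_documents") or response_json.get("simple_search_docs") or []
doc_ids = [d.get("document_id") or d.get("id") for d in docs]
print(doc_ids)  # ['doc-1']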
@@ -51,6 +51,7 @@ def test_send_message_simple_with_history(reset: None) -> None:

     # Check that the top document is the correct document
     assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id
+    assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id

     # assert that the metadata is correct
     for doc in cc_pair_1.documents: