Add return contexts (#1018)

This commit is contained in:
Szymon Planeta 2024-02-02 07:22:22 +01:00 committed by GitHub
parent 0060a1dd58
commit dc2f4297b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 37 additions and 4 deletions

View File

@ -74,12 +74,24 @@ class DanswerQuotes(BaseModel):
quotes: list[DanswerQuote]
class DanswerContext(BaseModel):
    """A single retrieved document chunk returned to the caller as context."""

    # Raw chunk text that was provided to the LLM.
    content: str
    # ID of the source document this chunk belongs to.
    document_id: str
    # Human-readable identifier for the document (presumably title/link text
    # — confirm against the chunk model).
    semantic_identifier: str
    # Short preview/summary snippet of the document.
    blurb: str
class DanswerContexts(BaseModel):
    """Container for the ordered list of contexts used to build an answer."""

    contexts: list[DanswerContext]
class DanswerAnswer(BaseModel):
    """The generated answer text; None when no answer was produced."""

    answer: str | None
class QAResponse(SearchResponse, DanswerAnswer):
    """One-shot QA result: answer plus supporting quotes, raw contexts, and
    the predicted query flow / search type."""

    quotes: list[DanswerQuote] | None
    # Fix: get_search_answer assigns a single DanswerContexts packet directly
    # (`qa_response.contexts = packet`), and OneShotQAResponse declares
    # `contexts: DanswerContexts | None` — so `list[DanswerContexts]` was a
    # type error that pydantic validation would reject at assignment.
    contexts: DanswerContexts | None
    predicted_flow: QueryFlow
    predicted_search: SearchType
    # Reflexion validity result; None when Reflexion was not run.
    eval_res_valid: bool | None = None
@ -87,11 +99,8 @@ class QAResponse(SearchResponse, DanswerAnswer):
error_msg: str | None = None
# Non-streaming result: the final answer together with its supporting quotes.
AnswerQuestionReturn = tuple[DanswerAnswer, DanswerQuotes]

# Streaming result: yields answer pieces, quote packets, retrieved contexts,
# or an error packet as they become available.
# Fix: the superseded pre-change union line (`... | StreamingError` without
# DanswerContexts) was left in place above the new one — two expressions
# inside the brackets with no comma is a syntax error; only the updated
# union belongs here.
AnswerQuestionStreamReturn = Iterator[
    DanswerAnswerPiece | DanswerQuotes | DanswerContexts | StreamingError
]

View File

@ -1,3 +1,4 @@
import itertools
from collections.abc import Callable
from collections.abc import Iterator
from typing import cast
@ -6,6 +7,8 @@ from sqlalchemy.orm import Session
from danswer.chat.chat_utils import get_chunks_for_qa
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerContext
from danswer.chat.models import DanswerContexts
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import LLMMetricsContainer
from danswer.chat.models import LLMRelevanceFilterResponse
@ -67,6 +70,7 @@ def stream_answer_objects(
| LLMRelevanceFilterResponse
| DanswerAnswerPiece
| DanswerQuotes
| DanswerContexts
| StreamingError
| ChatMessageDetail
]:
@ -229,6 +233,21 @@ def stream_answer_objects(
else no_gen_ai_response()
)
# When the caller requested raw contexts, append a DanswerContexts packet
# (built from the LLM-selected chunks) to the end of the response stream.
if qa_model is not None and query_req.return_contexts:
    contexts = DanswerContexts(
        contexts=[
            DanswerContext(
                content=context_doc.content,
                document_id=context_doc.document_id,
                semantic_identifier=context_doc.semantic_identifier,
                # Fix: blurb was populated from semantic_identifier; the
                # DanswerContext model declares a distinct `blurb` field.
                # NOTE(review): assumes context_doc exposes `.blurb` like
                # its other mirrored attributes — confirm on the chunk type.
                blurb=context_doc.blurb,
            )
            for context_doc in llm_chunks
        ]
    )
    response_packets = itertools.chain(response_packets, [contexts])
# Capture outputs and errors
llm_output = ""
error: str | None = None
@ -316,6 +335,8 @@ def get_search_answer(
qa_response.llm_chunks_indices = packet.relevant_chunk_indices
elif isinstance(packet, DanswerQuotes):
qa_response.quotes = packet
elif isinstance(packet, DanswerContexts):
qa_response.contexts = packet
elif isinstance(packet, StreamingError):
qa_response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):

View File

@ -3,6 +3,7 @@ from typing import Any
from pydantic import BaseModel
from pydantic import root_validator
from danswer.chat.models import DanswerContexts
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import QADocsResponse
from danswer.configs.constants import MessageType
@ -25,6 +26,7 @@ class DirectQARequest(BaseModel):
persona_id: int
retrieval_options: RetrievalDetails
chain_of_thought: bool = False
return_contexts: bool = False
@root_validator
def check_chain_of_thought_and_prompt_id(
@ -53,3 +55,4 @@ class OneShotQAResponse(BaseModel):
error_msg: str | None = None
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
chat_message_id: int | None = None
contexts: DanswerContexts | None = None