Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-27 12:29:41 +02:00)
Standardize Chat Message Stream (#1098)
@@ -59,6 +59,7 @@ from danswer.search.search_runner import full_chunk_search_generator
 from danswer.search.search_runner import inference_documents_from_ids
 from danswer.secondary_llm_flows.choose_search import check_if_need_search
 from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
+from danswer.server.query_and_chat.models import ChatMessageDetail
 from danswer.server.query_and_chat.models import CreateChatMessageRequest
 from danswer.server.utils import get_json_line
 from danswer.utils.logger import setup_logger
@@ -153,8 +154,7 @@ def translate_citations(
     return citation_to_saved_doc_id_map
 
 
-@log_generator_function_time()
-def stream_chat_message(
+def stream_chat_message_objects(
     new_msg_req: CreateChatMessageRequest,
     user: User | None,
     db_session: Session,
@@ -164,7 +164,14 @@ def stream_chat_message(
     # For flow with search, don't include as many chunks as possible since we need to leave space
     # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
     max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
-) -> Iterator[str]:
+) -> Iterator[
+    StreamingError
+    | QADocsResponse
+    | LLMRelevanceFilterResponse
+    | ChatMessageDetail
+    | DanswerAnswerPiece
+    | CitationInfo
+]:
     """Streams in order:
     1. [conditional] Retrieved documents if a search needs to be run
     2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
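With the widened return annotation, stream_chat_message_objects yields typed Pydantic packets rather than pre-serialized JSON strings, so in-process callers can dispatch on packet type directly. A minimal consumer sketch follows; the danswer.chat.models import path and the answer_piece field are assumptions not shown in this diff, while the packet class names come from the annotation above.

# Sketch of an in-process consumer of stream_chat_message_objects.
# Assumption: the packet classes live in danswer.chat.models; only the class
# names themselves are confirmed by the return annotation in this diff.
from danswer.chat.models import DanswerAnswerPiece, StreamingError


def collect_answer(packets) -> str:
    """Concatenate streamed answer tokens, raising on a streamed error."""
    answer_parts: list[str] = []
    for packet in packets:
        if isinstance(packet, StreamingError):
            # The error field is confirmed by StreamingError(error=...) in this diff.
            raise RuntimeError(packet.error)
        if isinstance(packet, DanswerAnswerPiece):
            # answer_piece is an assumed field name for the streamed token.
            answer_parts.append(packet.answer_piece or "")
        # QADocsResponse, LLMRelevanceFilterResponse, ChatMessageDetail and
        # CitationInfo packets would be handled similarly; skipped for brevity.
    return "".join(answer_parts)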
@@ -313,10 +320,8 @@ def stream_chat_message(
         # only allow the final document to get truncated
         # if more than that, then the user message is too long
         if final_doc_ind != len(tokens_per_doc) - 1:
-            yield get_json_line(
-                StreamingError(
-                    error="LLM context window exceeded. Please de-select some documents or shorten your query."
-                ).dict()
+            yield StreamingError(
+                error="LLM context window exceeded. Please de-select some documents or shorten your query."
             )
             return
 
@@ -417,8 +422,8 @@ def stream_chat_message(
             applied_source_filters=retrieval_request.filters.source_type,
             applied_time_cutoff=time_cutoff,
             recency_bias_multiplier=recency_bias_multiplier,
-        ).dict()
-        yield get_json_line(initial_response)
+        )
+        yield initial_response
 
         # Get the final ordering of chunks for the LLM call
         llm_chunk_selection = cast(list[bool], next(documents_generator))
@@ -430,8 +435,8 @@ def stream_chat_message(
             ]
             if run_llm_chunk_filter
             else []
-        ).dict()
-        yield get_json_line(llm_relevance_filtering_response)
+        )
+        yield llm_relevance_filtering_response
 
         # Prep chunks to pass to LLM
         num_llm_chunks = (
@@ -497,7 +502,7 @@ def stream_chat_message(
                 gen_ai_response_message
             )
 
-            yield get_json_line(msg_detail_response.dict())
+            yield msg_detail_response
 
             # Stop here after saving message details, the above still needs to be sent for the
             # message id to send the next follow-up message
@@ -530,17 +535,13 @@ def stream_chat_message(
                 citations.append(packet)
                 continue
 
-            yield get_json_line(packet.dict())
+            yield packet
     except Exception as e:
         logger.exception(e)
 
         # Frontend will erase whatever answer and show this instead
         # This will be the issue 99% of the time
-        error_packet = StreamingError(
-            error="LLM failed to respond, have you set your API key?"
-        )
-
-        yield get_json_line(error_packet.dict())
+        yield StreamingError(error="LLM failed to respond, have you set your API key?")
         return
 
     # Post-LLM answer processing
@@ -564,11 +565,24 @@ def stream_chat_message(
             gen_ai_response_message
         )
 
-        yield get_json_line(msg_detail_response.dict())
+        yield msg_detail_response
     except Exception as e:
         logger.exception(e)
 
         # Frontend will erase whatever answer and show this instead
-        error_packet = StreamingError(error="Failed to parse LLM output")
+        yield StreamingError(error="Failed to parse LLM output")
 
-        yield get_json_line(error_packet.dict())
+
+@log_generator_function_time()
+def stream_chat_message(
+    new_msg_req: CreateChatMessageRequest,
+    user: User | None,
+    db_session: Session,
+) -> Iterator[str]:
+    objects = stream_chat_message_objects(
+        new_msg_req=new_msg_req,
+        user=user,
+        db_session=db_session,
+    )
+    for obj in objects:
+        yield get_json_line(obj.dict())
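The change above splits the old generator in two: stream_chat_message_objects yields packet objects, while the thin stream_chat_message wrapper keeps the previous HTTP contract by serializing each packet with get_json_line, i.e. one JSON document per line. Below is a client-side sketch of reading that stream against a locally running server; the route, port, and request payload are illustrative assumptions, not taken from this diff, and only the "error" field name is confirmed by the code above.

import json

import requests

# Hypothetical endpoint and payload: shown only to illustrate the line-delimited
# JSON framing produced by get_json_line in the wrapper above.
resp = requests.post(
    "http://localhost:8080/chat/send-message",
    json={"chat_session_id": 1, "message": "What changed in this commit?"},
    stream=True,
)
for raw_line in resp.iter_lines():
    if not raw_line:
        continue
    packet = json.loads(raw_line)
    if "answer_piece" in packet:  # assumed field name for streamed answer tokens
        print(packet["answer_piece"] or "", end="")
    elif "error" in packet:  # error field is confirmed by the diff
        print(f"\n[error] {packet['error']}")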
@@ -86,10 +86,10 @@ class RetrievalDetails(BaseModel):
     # Use LLM to determine whether to do a retrieval or only rely on existing history
     # If the Persona is configured to not run search (0 chunks), this is bypassed
     # If no Prompt is configured, the only search results are shown, this is bypassed
-    run_search: OptionalSearchSetting
+    run_search: OptionalSearchSetting = OptionalSearchSetting.ALWAYS
     # Is this a real-time/streaming call or a question where Danswer can take more time?
     # Used to determine reranking flow
-    real_time: bool
+    real_time: bool = True
     # The following have defaults in the Persona settings which can be overriden via
     # the query, if None, then use Persona settings
     filters: BaseFilters | None = None
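With defaults on run_search and real_time, a RetrievalDetails can now be constructed (or received in a request body) without either field. A small sketch follows; the danswer.search.models import path is an assumption, and only OptionalSearchSetting.ALWAYS is confirmed by this diff.

# Assumed module path; class and field names come from the hunk above.
from danswer.search.models import OptionalSearchSetting, RetrievalDetails

# Previously both fields were required; omitting them now means
# "always run search" and "treat the call as real-time".
details = RetrievalDetails()
assert details.run_search == OptionalSearchSetting.ALWAYS
assert details.real_time is True

# Explicit overrides still work as before.
background_details = RetrievalDetails(real_time=False)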