question loop

joachim-danswer 2025-03-16 12:06:10 -07:00
parent f45798b5dd
commit 83d5b3b503


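Wraps the body of stream_chat_message_objects in a loop over a hard-coded test_questions list: each question replaces the incoming new_msg_req.message so the full chat pipeline can be driven repeatedly for debugging. The remaining hunks are formatting-only rewraps caused by the extra indentation level.
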
@@ -357,6 +357,25 @@ def stream_chat_message_objects(
     llm: LLM
+    test_questions = [
+        "weather in Munich",
+        "weather in New York",
+        # "what is the overlap between finance and economics",
+        # "effects taking vitamin c pills vs eating veggies health outcomes",
+        # "professions people good math",
+        # "biomedical engineers design cutting-edge medical equipment important skill set",
+        # "How do biomedical engineers design cutting-edge medical equipment? And what is the most important skill set?",
+        # "average power output US nuclear power plant",
+        # "typical power range small modular reactors",
+        # "SMRs power industry",
+        # "best use case Onyx AI company",
+        # "techniques calculate square root",
+        # "daily vitamin C requirement adult women",
+        # "boil ocean",
+        # "best soccer player ever"
+    ]
+    for test_question in test_questions:
         try:
             user_id = user.id if user is not None else None
@@ -366,7 +385,8 @@ def stream_chat_message_objects(
                 db_session=db_session,
             )
-            message_text = new_msg_req.message
+            # message_text = new_msg_req.message
+            message_text = test_question
             chat_session_id = new_msg_req.chat_session_id
             parent_id = new_msg_req.parent_message_id
             reference_doc_ids = new_msg_req.search_doc_ids
@@ -375,7 +395,10 @@ def stream_chat_message_objects(
             # permanent "log" store, used primarily for debugging
             long_term_logger = LongTermLogger(
-                metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
+                metadata={
+                    "user_id": str(user_id),
+                    "chat_session_id": str(chat_session_id),
+                }
             )
             if alternate_assistant_id is not None:
@@ -536,7 +559,9 @@ def stream_chat_message_objects(
                 history_msgs, new_msg_req.file_descriptors, db_session
             )
             req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
-            latest_query_files = [file for file in files if file.file_id in req_file_ids]
+            latest_query_files = [
+                file for file in files if file.file_id in req_file_ids
+            ]
             if user_message:
                 attach_files_to_chat_message(
@@ -609,7 +634,9 @@ def stream_chat_message_objects(
             )
             overridden_model = (
-                new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
+                new_msg_req.llm_override.model_version
+                if new_msg_req.llm_override
+                else None
             )
             # Cannot determine these without the LLM step or breaking out early
@@ -621,7 +648,9 @@ def stream_chat_message_objects(
                 # the latest. If we're creating a new assistant message, then the parent
                 # should be the latest message (latest user message)
                 parent_message=(
-                    final_msg if existing_assistant_message_id is None else parent_message
+                    final_msg
+                    if existing_assistant_message_id is None
+                    else parent_message
                 ),
                 prompt_id=prompt_id,
                 overridden_model=overridden_model,
@@ -638,13 +667,17 @@ def stream_chat_message_objects(
                 is_agentic=new_msg_req.use_agentic_search,
             )
-            prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
+            prompt_override = (
+                new_msg_req.prompt_override or chat_session.prompt_override
+            )
             if new_msg_req.persona_override_config:
                 prompt_config = PromptConfig(
                     system_prompt=new_msg_req.persona_override_config.prompts[
                         0
                     ].system_prompt,
-                    task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt,
+                    task_prompt=new_msg_req.persona_override_config.prompts[
+                        0
+                    ].task_prompt,
                     datetime_aware=new_msg_req.persona_override_config.prompts[
                         0
                     ].datetime_aware,
@@ -864,7 +897,9 @@ def stream_chat_message_objects(
                         )
                         file_ids = save_files(
-                            urls=[img.url for img in img_generation_response if img.url],
+                            urls=[
+                                img.url for img in img_generation_response if img.url
+                            ],
                             base64_files=[
                                 img.image_data
                                 for img in img_generation_response
@@ -890,7 +925,9 @@ def stream_chat_message_objects(
                         )
                         yield info.qa_docs_response
                     elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
-                        custom_tool_response = cast(CustomToolCallSummary, packet.response)
+                        custom_tool_response = cast(
+                            CustomToolCallSummary, packet.response
+                        )
                         if (
                             custom_tool_response.response_type == "image"
@@ -903,7 +940,8 @@ def stream_chat_message_objects(
                                     id=str(file_id),
                                     type=(
                                         ChatFileType.IMAGE
-                                        if custom_tool_response.response_type == "image"
+                                        if custom_tool_response.response_type
+                                        == "image"
                                         else ChatFileType.CSV
                                     ),
                                 )
@@ -967,7 +1005,9 @@ def stream_chat_message_objects(
                     llm.config.api_key, "[REDACTED_API_KEY]"
                 )
-            yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
+            yield StreamingError(
+                error=client_error_msg, stack_trace=stack_trace
+            )
             db_session.rollback()
             return
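
To make the control-flow change easier to see outside the diff, here is a minimal, self-contained sketch of the pattern the first two hunks introduce. ChatMessageRequest and handle_message are hypothetical stand-ins for the real Onyx request object and pipeline body, not its actual API:

from dataclasses import dataclass


@dataclass
class ChatMessageRequest:
    message: str


def handle_message(message_text: str) -> str:
    # Stand-in for the real message-handling body of the function.
    return f"answered: {message_text}"


new_msg_req = ChatMessageRequest(message="ignored while the test loop is active")

test_questions = [
    "weather in Munich",
    "weather in New York",
]

for test_question in test_questions:
    try:
        # message_text = new_msg_req.message  # disabled, as in the commit
        message_text = test_question
        print(handle_message(message_text))
    except Exception:
        # The real code yields a StreamingError and rolls back the DB session here.
        raise

Because the override sits inside the per-question try block, each hard-coded question runs through the same error handling as a real user message would.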