question loop

This commit is contained in:
joachim-danswer
2025-03-16 12:06:10 -07:00
parent f45798b5dd
commit 83d5b3b503

View File

@@ -357,6 +357,25 @@ def stream_chat_message_objects(
llm: LLM
test_questions = [
"weather in Munich",
"weather in New York",
# "what is the overlap between finance and economics",
# "effects taking vitamin c pills vs eating veggies health outcomes",
# "professions people good math",
# "biomedical engineers design cutting-edge medical equipment important skill set",
# "How do biomedical engineers design cutting-edge medical equipment? And what is the most important skill set?",
# "average power output US nuclear power plant",
# "typical power range small modular reactors",
# "SMRs power industry",
# "best use case Onyx AI company",
# "techniques calculate square root",
# "daily vitamin C requirement adult women",
# "boil ocean",
# "best soccer player ever"
]
for test_question in test_questions:
try:
user_id = user.id if user is not None else None
@@ -366,7 +385,8 @@ def stream_chat_message_objects(
db_session=db_session,
)
message_text = new_msg_req.message
# message_text = new_msg_req.message
message_text = test_question
chat_session_id = new_msg_req.chat_session_id
parent_id = new_msg_req.parent_message_id
reference_doc_ids = new_msg_req.search_doc_ids
@@ -375,7 +395,10 @@ def stream_chat_message_objects(
# permanent "log" store, used primarily for debugging
long_term_logger = LongTermLogger(
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
metadata={
"user_id": str(user_id),
"chat_session_id": str(chat_session_id),
}
)
if alternate_assistant_id is not None:
@@ -536,7 +559,9 @@ def stream_chat_message_objects(
history_msgs, new_msg_req.file_descriptors, db_session
)
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
latest_query_files = [file for file in files if file.file_id in req_file_ids]
latest_query_files = [
file for file in files if file.file_id in req_file_ids
]
if user_message:
attach_files_to_chat_message(
@@ -609,7 +634,9 @@ def stream_chat_message_objects(
)
overridden_model = (
new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
new_msg_req.llm_override.model_version
if new_msg_req.llm_override
else None
)
# Cannot determine these without the LLM step or breaking out early
@@ -621,7 +648,9 @@ def stream_chat_message_objects(
# the latest. If we're creating a new assistant message, then the parent
# should be the latest message (latest user message)
parent_message=(
final_msg if existing_assistant_message_id is None else parent_message
final_msg
if existing_assistant_message_id is None
else parent_message
),
prompt_id=prompt_id,
overridden_model=overridden_model,
@@ -638,13 +667,17 @@ def stream_chat_message_objects(
is_agentic=new_msg_req.use_agentic_search,
)
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
prompt_override = (
new_msg_req.prompt_override or chat_session.prompt_override
)
if new_msg_req.persona_override_config:
prompt_config = PromptConfig(
system_prompt=new_msg_req.persona_override_config.prompts[
0
].system_prompt,
task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt,
task_prompt=new_msg_req.persona_override_config.prompts[
0
].task_prompt,
datetime_aware=new_msg_req.persona_override_config.prompts[
0
].datetime_aware,
@@ -864,7 +897,9 @@ def stream_chat_message_objects(
)
file_ids = save_files(
urls=[img.url for img in img_generation_response if img.url],
urls=[
img.url for img in img_generation_response if img.url
],
base64_files=[
img.image_data
for img in img_generation_response
@@ -890,7 +925,9 @@ def stream_chat_message_objects(
)
yield info.qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
custom_tool_response = cast(
CustomToolCallSummary, packet.response
)
if (
custom_tool_response.response_type == "image"
@@ -903,7 +940,8 @@ def stream_chat_message_objects(
id=str(file_id),
type=(
ChatFileType.IMAGE
if custom_tool_response.response_type == "image"
if custom_tool_response.response_type
== "image"
else ChatFileType.CSV
),
)
@@ -967,7 +1005,9 @@ def stream_chat_message_objects(
llm.config.api_key, "[REDACTED_API_KEY]"
)
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
yield StreamingError(
error=client_error_msg, stack_trace=stack_trace
)
db_session.rollback()
return