Compare commits

...

3 Commits

Author           SHA1        Message                     Date
joachim-danswer  625936306f  final examples and logging  2025-03-16 13:06:19 -07:00
joachim-danswer  ab11bf6552  writing data                2025-03-16 12:40:09 -07:00
joachim-danswer  83d5b3b503  question loop               2025-03-16 12:06:10 -07:00
2 changed files with 742 additions and 606 deletions

View File

@@ -357,6 +357,38 @@ def stream_chat_message_objects(
     llm: LLM

+    test_questions = [
+        "big bang vs steady state theory",
+        "astronomy",
+        "trace energy momentum tensor conformal field theory",
+        "evidence Big Bang",
+        "Neil Armstrong play tennis moon",
+        "current temperature Hawaii New York Munich",
+        "win quadradoodle",
+        "best practices coding Java",
+        "classes related software engineering",
+        "current temperature Munich",
+        "what is the most important concept in biology",
+        "subfields of finance",
+        "what is the overlap between finance and economics",
+        "effects taking vitamin c pills vs eating veggies health outcomes",
+        "professions people good math",
+        "biomedical engineers design cutting-edge medical equipment important skill set",
+        "How do biomedical engineers design cutting-edge medical equipment? And what is the most important skill set?",
+        "average power output US nuclear power plant",
+        "typical power range small modular reactors",
+        "SMRs power industry",
+        "best use case Onyx AI company",
+        "techniques calculate square root",
+        "daily vitamin C requirement adult women",
+        "boil ocean",
+        "best soccer player ever",
+    ]
+
+    for test_question_num, test_question in enumerate(test_questions):
+        logger.info(
+            f"------- Running test question {test_question_num + 1} of {len(test_questions)}"
+        )
         try:
             user_id = user.id if user is not None else None
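Reviewer note: the hunk above hard-wires a benchmark loop into stream_chat_message_objects, replaying a fixed question list in place of the incoming request message. A minimal sketch of the same pattern in isolation (run_chat is a hypothetical stand-in for the real chat entrypoint, not the repo's API):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Hypothetical stand-in for the real chat entrypoint; illustrative only.
    def run_chat(message: str) -> str:
        return f"answer to: {message}"

    test_questions = [
        "big bang vs steady state theory",
        "best practices coding Java",
    ]

    for test_question_num, test_question in enumerate(test_questions):
        logger.info(
            f"------- Running test question {test_question_num + 1} of {len(test_questions)}"
        )
        # Each iteration overrides the request message, mirroring
        # "message_text = test_question" in the diff below.
        answer = run_chat(test_question)
        logger.info(f"Answer: {answer[:80]}")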
@@ -366,7 +398,8 @@ def stream_chat_message_objects(
                 db_session=db_session,
             )

-            message_text = new_msg_req.message
+            # message_text = new_msg_req.message
+            message_text = test_question
             chat_session_id = new_msg_req.chat_session_id
             parent_id = new_msg_req.parent_message_id
             reference_doc_ids = new_msg_req.search_doc_ids
@@ -375,7 +408,10 @@ def stream_chat_message_objects(
             # permanent "log" store, used primarily for debugging
             long_term_logger = LongTermLogger(
-                metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
+                metadata={
+                    "user_id": str(user_id),
+                    "chat_session_id": str(chat_session_id),
+                }
             )

             if alternate_assistant_id is not None:
@@ -536,7 +572,9 @@ def stream_chat_message_objects(
                 history_msgs, new_msg_req.file_descriptors, db_session
             )
             req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
-            latest_query_files = [file for file in files if file.file_id in req_file_ids]
+            latest_query_files = [
+                file for file in files if file.file_id in req_file_ids
+            ]

             if user_message:
                 attach_files_to_chat_message(
@@ -609,7 +647,9 @@ def stream_chat_message_objects(
             )

             overridden_model = (
-                new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
+                new_msg_req.llm_override.model_version
+                if new_msg_req.llm_override
+                else None
             )

             # Cannot determine these without the LLM step or breaking out early
@@ -621,7 +661,9 @@ def stream_chat_message_objects(
                 # the latest. If we're creating a new assistant message, then the parent
                 # should be the latest message (latest user message)
                 parent_message=(
-                    final_msg if existing_assistant_message_id is None else parent_message
+                    final_msg
+                    if existing_assistant_message_id is None
+                    else parent_message
                 ),
                 prompt_id=prompt_id,
                 overridden_model=overridden_model,
@@ -638,13 +680,17 @@ def stream_chat_message_objects(
                 is_agentic=new_msg_req.use_agentic_search,
             )

-            prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
+            prompt_override = (
+                new_msg_req.prompt_override or chat_session.prompt_override
+            )
             if new_msg_req.persona_override_config:
                 prompt_config = PromptConfig(
                     system_prompt=new_msg_req.persona_override_config.prompts[
                         0
                     ].system_prompt,
-                    task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt,
+                    task_prompt=new_msg_req.persona_override_config.prompts[
+                        0
+                    ].task_prompt,
                     datetime_aware=new_msg_req.persona_override_config.prompts[
                         0
                     ].datetime_aware,
@@ -864,7 +910,9 @@ def stream_chat_message_objects(
                             )
                             file_ids = save_files(
-                                urls=[img.url for img in img_generation_response if img.url],
+                                urls=[
+                                    img.url for img in img_generation_response if img.url
+                                ],
                                 base64_files=[
                                     img.image_data
                                     for img in img_generation_response
@@ -890,7 +938,9 @@ def stream_chat_message_objects(
                             )
                             yield info.qa_docs_response
                         elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
-                            custom_tool_response = cast(CustomToolCallSummary, packet.response)
+                            custom_tool_response = cast(
+                                CustomToolCallSummary, packet.response
+                            )

                             if (
                                 custom_tool_response.response_type == "image"
@@ -903,7 +953,8 @@ def stream_chat_message_objects(
                                         id=str(file_id),
                                         type=(
                                             ChatFileType.IMAGE
-                                            if custom_tool_response.response_type == "image"
+                                            if custom_tool_response.response_type
+                                            == "image"
                                             else ChatFileType.CSV
                                         ),
                                     )
@@ -967,7 +1018,9 @@ def stream_chat_message_objects(
                     llm.config.api_key, "[REDACTED_API_KEY]"
                 )

-            yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
+            yield StreamingError(
+                error=client_error_msg, stack_trace=stack_trace
+            )
             db_session.rollback()
             return
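Reviewer note: the last hunk reflows the StreamingError yield; the surrounding code scrubs the LLM API key from error text before it reaches the client. The underlying pattern, sketched in isolation (redact_secret is a hypothetical helper, and the key string here is made up):

    import traceback

    def redact_secret(text: str, secret: str | None) -> str:
        # Replace any occurrence of the secret with a fixed placeholder.
        return text.replace(secret, "[REDACTED_API_KEY]") if secret else text

    try:
        raise RuntimeError("401 from provider, key=sk-abc123")
    except Exception as e:
        stack_trace = traceback.format_exc()
        client_error_msg = redact_secret(str(e), "sk-abc123")
        print(client_error_msg)  # 401 from provider, key=[REDACTED_API_KEY]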

View File

@@ -1,9 +1,12 @@
+import csv
 import json
+import os
 import string
 from collections.abc import Callable
 from collections.abc import Mapping
 from datetime import datetime
 from datetime import timezone
+from pathlib import Path
 from typing import Any
 from typing import cast
@@ -285,15 +288,62 @@ def parallel_visit_api_retrieval(
     return inference_chunks


-@retry(tries=3, delay=1, backoff=2)
+def _append_ranking_stats_to_csv(
+    ranking_stats: list[tuple[str, float, str, str, str, float]],
+    csv_path: str = "/tmp/ranking_stats.csv",
+) -> None:
+    """
+    Append ranking statistics to a CSV file.
+
+    Args:
+        ranking_stats: list of (category, query_alpha, query, hit_position,
+            document_chunk_id, relevance) tuples
+        csv_path: path to the CSV file to append to
+    """
+    file_exists = os.path.isfile(csv_path)
+
+    # Create the directory if it doesn't exist
+    csv_dir = os.path.dirname(csv_path)
+    if csv_dir and not os.path.exists(csv_dir):
+        Path(csv_dir).mkdir(parents=True, exist_ok=True)
+
+    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
+        writer = csv.writer(file)
+        # Write header if file is new
+        if not file_exists:
+            writer.writerow(
+                [
+                    "category",
+                    "query_alpha",
+                    "query",
+                    "hit_position",
+                    "document_id",
+                    "relevance",
+                ]
+            )
+        # Write the ranking stats
+        for cat, query_alpha, query, hit_pos, doc_chunk_id, relevance in ranking_stats:
+            writer.writerow([cat, query_alpha, query, hit_pos, doc_chunk_id, relevance])
+
+    logger.debug(f"Appended {len(ranking_stats)} ranking stats to {csv_path}")
+
+
+@retry(tries=1, delay=1, backoff=2)
 def query_vespa(
     query_params: Mapping[str, str | int | float]
 ) -> list[InferenceChunkUncleaned]:
     if "query" in query_params and not cast(str, query_params["query"]).strip():
         raise ValueError("No/empty query received")

+    ranking_stats: list[tuple[str, float, str, str, str, float]] = []
+
+    search_time = 0.0
+    alphas: list[float] = [0.4, 0.7, 1.0]
+    for query_alpha in alphas:
+        date_time_start = datetime.now()
+
+        # Create a mutable copy of query_params so alpha can be set without
+        # violating the immutable Mapping type
+        mutable_params = dict(query_params)
+        mutable_params["input.query(alpha)"] = query_alpha
+
         params = dict(
-            **query_params,
+            **mutable_params,
             **{
                 "presentation.timing": True,
             }
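Reviewer note: the loop above re-issues the same Vespa query with the hybrid-search weight input.query(alpha) set to 0.4, 0.7, and 1.0 in turn. A minimal sketch of the sweep pattern, assuming a query_fn callable that takes the final params dict and returns a list of hits (names are illustrative, not the repo's API):

    from typing import Any, Callable, Mapping

    def sweep_alphas(
        query_params: Mapping[str, Any],
        query_fn: Callable[[dict[str, Any]], list[dict[str, Any]]],
        alphas: tuple[float, ...] = (0.4, 0.7, 1.0),
    ) -> dict[float, list[dict[str, Any]]]:
        """Run the same query once per alpha value and collect each hit list."""
        results: dict[float, list[dict[str, Any]]] = {}
        for alpha in alphas:
            # Copy so the caller's (possibly immutable) mapping stays untouched.
            mutable_params = dict(query_params)
            mutable_params["input.query(alpha)"] = alpha
            results[alpha] = query_fn(mutable_params)
        return results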
@@ -342,10 +392,43 @@ def query_vespa(
                 f"fetch this document"
             )

+        for hit_pos, hit in enumerate(hits):
+            ranking_stats.append(
+                (
+                    "Retrieval",
+                    query_alpha,
+                    cast(str, mutable_params["query"]),
+                    str(hit_pos),
+                    hit["fields"].get("document_id", "")
+                    + "__"
+                    + str(hit["fields"].get("chunk_id", "")),
+                    hit.get("relevance", 0),
+                )
+            )
+
+        date_time_end = datetime.now()
+        search_time += (date_time_end - date_time_start).total_seconds()
+
+    avg_search_time = search_time / len(alphas)
+
+    ranking_stats.append(
+        (
+            "Timing",
+            query_alpha,
+            cast(str, query_params["query"]).strip(),
+            "",
+            "Avg:",
+            avg_search_time,
+        )
+    )
+
+    if ranking_stats:
+        _append_ranking_stats_to_csv(ranking_stats)
+
     filtered_hits = [hit for hit in hits if hit["fields"].get(CONTENT) is not None]
     inference_chunks = [_vespa_hit_to_inference_chunk(hit) for hit in filtered_hits]

     # Good Debugging Spot
+    logger.info(f"Search done for all alphas - avg timing: {avg_search_time}")
+
     return inference_chunks
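Reviewer note: once a few runs have accumulated in /tmp/ranking_stats.csv, the per-alpha rankings can be compared offline. A standard-library sketch, assuming the header row matches the one written above (Timing rows are skipped):

    import csv
    from collections import defaultdict

    # Mean relevance per alpha from the CSV written by _append_ranking_stats_to_csv.
    scores: dict[str, list[float]] = defaultdict(list)
    with open("/tmp/ranking_stats.csv", newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            if row["category"] == "Retrieval":
                scores[row["query_alpha"]].append(float(row["relevance"]))

    for alpha, relevances in sorted(scores.items()):
        print(
            f"alpha={alpha}: mean relevance "
            f"{sum(relevances) / len(relevances):.4f} over {len(relevances)} hits"
        )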