Compare commits

...

3 Commits

Author SHA1 Message Date
joachim-danswer
625936306f final examples and logging 2025-03-16 13:06:19 -07:00
joachim-danswer
ab11bf6552 writing data 2025-03-16 12:40:09 -07:00
joachim-danswer
83d5b3b503 question loop 2025-03-16 12:06:10 -07:00
2 changed files with 742 additions and 606 deletions

View File

@ -357,6 +357,38 @@ def stream_chat_message_objects(
llm: LLM
test_questions = [
"big bang vs steady state theory",
"astronomy",
"trace energy momentum tensor conformal field theory",
"evidence Big Bang",
"Neil Armstrong play tennis moon",
"current temperature Hawaii New York Munich",
"win quadradoodle",
"best practices coding Java",
"classes related software engineering",
"current temperature Munich",
"what is the most important concept in biology",
"subfields of finance",
"what is the overlap between finance and economics",
"effects taking vitamin c pills vs eating veggies health outcomes",
"professions people good math",
"biomedical engineers design cutting-edge medical equipment important skill set",
"How do biomedical engineers design cutting-edge medical equipment? And what is the most important skill set?",
"average power output US nuclear power plant",
"typical power range small modular reactors",
"SMRs power industry",
"best use case Onyx AI company",
"techniques calculate square root",
"daily vitamin C requirement adult women",
"boil ocean",
"best soccer player ever",
]
for test_question_num, test_question in enumerate(test_questions):
logger.info(
f"------- Running test question {test_question_num + 1} of {len(test_questions)}"
)
try:
user_id = user.id if user is not None else None
@ -366,7 +398,8 @@ def stream_chat_message_objects(
db_session=db_session,
)
message_text = new_msg_req.message
# message_text = new_msg_req.message
message_text = test_question
chat_session_id = new_msg_req.chat_session_id
parent_id = new_msg_req.parent_message_id
reference_doc_ids = new_msg_req.search_doc_ids
@ -375,7 +408,10 @@ def stream_chat_message_objects(
# permanent "log" store, used primarily for debugging
long_term_logger = LongTermLogger(
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
metadata={
"user_id": str(user_id),
"chat_session_id": str(chat_session_id),
}
)
if alternate_assistant_id is not None:
@ -536,7 +572,9 @@ def stream_chat_message_objects(
history_msgs, new_msg_req.file_descriptors, db_session
)
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
latest_query_files = [file for file in files if file.file_id in req_file_ids]
latest_query_files = [
file for file in files if file.file_id in req_file_ids
]
if user_message:
attach_files_to_chat_message(
@ -609,7 +647,9 @@ def stream_chat_message_objects(
)
overridden_model = (
new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
new_msg_req.llm_override.model_version
if new_msg_req.llm_override
else None
)
# Cannot determine these without the LLM step or breaking out early
@ -621,7 +661,9 @@ def stream_chat_message_objects(
# the latest. If we're creating a new assistant message, then the parent
# should be the latest message (latest user message)
parent_message=(
final_msg if existing_assistant_message_id is None else parent_message
final_msg
if existing_assistant_message_id is None
else parent_message
),
prompt_id=prompt_id,
overridden_model=overridden_model,
@ -638,13 +680,17 @@ def stream_chat_message_objects(
is_agentic=new_msg_req.use_agentic_search,
)
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
prompt_override = (
new_msg_req.prompt_override or chat_session.prompt_override
)
if new_msg_req.persona_override_config:
prompt_config = PromptConfig(
system_prompt=new_msg_req.persona_override_config.prompts[
0
].system_prompt,
task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt,
task_prompt=new_msg_req.persona_override_config.prompts[
0
].task_prompt,
datetime_aware=new_msg_req.persona_override_config.prompts[
0
].datetime_aware,
@ -864,7 +910,9 @@ def stream_chat_message_objects(
)
file_ids = save_files(
urls=[img.url for img in img_generation_response if img.url],
urls=[
img.url for img in img_generation_response if img.url
],
base64_files=[
img.image_data
for img in img_generation_response
@ -890,7 +938,9 @@ def stream_chat_message_objects(
)
yield info.qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
custom_tool_response = cast(
CustomToolCallSummary, packet.response
)
if (
custom_tool_response.response_type == "image"
@ -903,7 +953,8 @@ def stream_chat_message_objects(
id=str(file_id),
type=(
ChatFileType.IMAGE
if custom_tool_response.response_type == "image"
if custom_tool_response.response_type
== "image"
else ChatFileType.CSV
),
)
@ -967,7 +1018,9 @@ def stream_chat_message_objects(
llm.config.api_key, "[REDACTED_API_KEY]"
)
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
yield StreamingError(
error=client_error_msg, stack_trace=stack_trace
)
db_session.rollback()
return

View File

@ -1,9 +1,12 @@
import csv
import json
import os
import string
from collections.abc import Callable
from collections.abc import Mapping
from datetime import datetime
from datetime import timezone
from pathlib import Path
from typing import Any
from typing import cast
@ -285,15 +288,62 @@ def parallel_visit_api_retrieval(
return inference_chunks
@retry(tries=3, delay=1, backoff=2)
def _append_ranking_stats_to_csv(
    ranking_stats: list[tuple[str, float, str, str, str, float]],
    csv_path: str = "/tmp/ranking_stats.csv",
) -> None:
    """
    Append ranking statistics to a CSV file, creating it (with a header row)
    if it does not yet exist.

    Args:
        ranking_stats: List of 6-tuples of
            (category, query_alpha, query, hit_position, doc_chunk_id, relevance),
            matching the columns written below. `category` is a label such as
            "Retrieval" or "Timing" set by the caller.
        csv_path: Path of the CSV file to append to.
    """
    file_exists = os.path.isfile(csv_path)

    # Create the parent directory if it doesn't exist.
    # Path.mkdir(exist_ok=True) already tolerates an existing directory,
    # so no separate os.path.exists() pre-check is needed.
    csv_dir = os.path.dirname(csv_path)
    if csv_dir:
        Path(csv_dir).mkdir(parents=True, exist_ok=True)

    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)

        # Write header if file is new.
        # BUG FIX: the header previously listed only 5 columns while every
        # data row below writes 6 values; add the leading "category" column
        # so the header aligns with the rows.
        if not file_exists:
            writer.writerow(
                [
                    "category",
                    "query_alpha",
                    "query",
                    "hit_position",
                    "document_id",
                    "relevance",
                ]
            )

        # Write the ranking stats, one row per tuple.
        for cat, query_alpha, query, hit_pos, doc_chunk_id, relevance in ranking_stats:
            writer.writerow([cat, query_alpha, query, hit_pos, doc_chunk_id, relevance])

    logger.debug(f"Appended {len(ranking_stats)} ranking stats to {csv_path}")
@retry(tries=1, delay=1, backoff=2)
def query_vespa(
query_params: Mapping[str, str | int | float]
) -> list[InferenceChunkUncleaned]:
if "query" in query_params and not cast(str, query_params["query"]).strip():
raise ValueError("No/empty query received")
ranking_stats: list[tuple[str, float, str, str, str, float]] = []
search_time = 0.0
alphas: list[float] = [0.4, 0.7, 1.0]
for query_alpha in alphas:
date_time_start = datetime.now()
# Create a mutable copy of the query_params
mutable_params = dict(query_params)
# Now we can modify it without mypy errors
mutable_params["input.query(alpha)"] = query_alpha
params = dict(
**query_params,
**mutable_params,
**{
"presentation.timing": True,
}
@ -342,10 +392,43 @@ def query_vespa(
f"fetch this document"
)
for hit_pos, hit in enumerate(hits):
ranking_stats.append(
(
"Retrieval",
query_alpha,
cast(str, mutable_params["query"]),
str(hit_pos),
hit["fields"].get("document_id", "")
+ "__"
+ str(hit["fields"].get("chunk_id", "")),
hit.get("relevance", 0),
)
)
date_time_end = datetime.now()
search_time += (date_time_end - date_time_start).microseconds / 1000000
avg_search_time = search_time / len(alphas)
ranking_stats.append(
(
"Timing",
query_alpha,
cast(str, query_params["query"]).strip(),
"",
"Avg:",
avg_search_time,
)
)
if ranking_stats:
_append_ranking_stats_to_csv(ranking_stats)
filtered_hits = [hit for hit in hits if hit["fields"].get(CONTENT) is not None]
inference_chunks = [_vespa_hit_to_inference_chunk(hit) for hit in filtered_hits]
# Good Debugging Spot
logger.info(f"Search done for all alphs - avg timing: {avg_search_time}")
return inference_chunks