mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-17 13:22:42 +01:00
Compare commits
3 Commits
f45798b5dd
...
625936306f
Author | SHA1 | Date | |
---|---|---|---|
|
625936306f | ||
|
ab11bf6552 | ||
|
83d5b3b503 |
@ -357,6 +357,38 @@ def stream_chat_message_objects(
|
||||
|
||||
llm: LLM
|
||||
|
||||
test_questions = [
|
||||
"big bang vs steady state theory",
|
||||
"astronomy",
|
||||
"trace energy momentum tensor conformal field theory",
|
||||
"evidence Big Bang",
|
||||
"Neil Armstrong play tennis moon",
|
||||
"current temperature Hawaii New York Munich",
|
||||
"win quadradoodle",
|
||||
"best practices coding Java",
|
||||
"classes related software engineering",
|
||||
"current temperature Munich",
|
||||
"what is the most important concept in biology",
|
||||
"subfields of finance",
|
||||
"what is the overlap between finance and economics",
|
||||
"effects taking vitamin c pills vs eating veggies health outcomes",
|
||||
"professions people good math",
|
||||
"biomedical engineers design cutting-edge medical equipment important skill set",
|
||||
"How do biomedical engineers design cutting-edge medical equipment? And what is the most important skill set?",
|
||||
"average power output US nuclear power plant",
|
||||
"typical power range small modular reactors",
|
||||
"SMRs power industry",
|
||||
"best use case Onyx AI company",
|
||||
"techniques calculate square root",
|
||||
"daily vitamin C requirement adult women",
|
||||
"boil ocean",
|
||||
"best soccer player ever",
|
||||
]
|
||||
|
||||
for test_question_num, test_question in enumerate(test_questions):
|
||||
logger.info(
|
||||
f"------- Running test question {test_question_num + 1} of {len(test_questions)}"
|
||||
)
|
||||
try:
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
@ -366,7 +398,8 @@ def stream_chat_message_objects(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
message_text = new_msg_req.message
|
||||
# message_text = new_msg_req.message
|
||||
message_text = test_question
|
||||
chat_session_id = new_msg_req.chat_session_id
|
||||
parent_id = new_msg_req.parent_message_id
|
||||
reference_doc_ids = new_msg_req.search_doc_ids
|
||||
@ -375,7 +408,10 @@ def stream_chat_message_objects(
|
||||
|
||||
# permanent "log" store, used primarily for debugging
|
||||
long_term_logger = LongTermLogger(
|
||||
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
|
||||
metadata={
|
||||
"user_id": str(user_id),
|
||||
"chat_session_id": str(chat_session_id),
|
||||
}
|
||||
)
|
||||
|
||||
if alternate_assistant_id is not None:
|
||||
@ -536,7 +572,9 @@ def stream_chat_message_objects(
|
||||
history_msgs, new_msg_req.file_descriptors, db_session
|
||||
)
|
||||
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
|
||||
latest_query_files = [file for file in files if file.file_id in req_file_ids]
|
||||
latest_query_files = [
|
||||
file for file in files if file.file_id in req_file_ids
|
||||
]
|
||||
|
||||
if user_message:
|
||||
attach_files_to_chat_message(
|
||||
@ -609,7 +647,9 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
overridden_model = (
|
||||
new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
|
||||
new_msg_req.llm_override.model_version
|
||||
if new_msg_req.llm_override
|
||||
else None
|
||||
)
|
||||
|
||||
# Cannot determine these without the LLM step or breaking out early
|
||||
@ -621,7 +661,9 @@ def stream_chat_message_objects(
|
||||
# the latest. If we're creating a new assistant message, then the parent
|
||||
# should be the latest message (latest user message)
|
||||
parent_message=(
|
||||
final_msg if existing_assistant_message_id is None else parent_message
|
||||
final_msg
|
||||
if existing_assistant_message_id is None
|
||||
else parent_message
|
||||
),
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
@ -638,13 +680,17 @@ def stream_chat_message_objects(
|
||||
is_agentic=new_msg_req.use_agentic_search,
|
||||
)
|
||||
|
||||
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
|
||||
prompt_override = (
|
||||
new_msg_req.prompt_override or chat_session.prompt_override
|
||||
)
|
||||
if new_msg_req.persona_override_config:
|
||||
prompt_config = PromptConfig(
|
||||
system_prompt=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].system_prompt,
|
||||
task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt,
|
||||
task_prompt=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].task_prompt,
|
||||
datetime_aware=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].datetime_aware,
|
||||
@ -864,7 +910,9 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
file_ids = save_files(
|
||||
urls=[img.url for img in img_generation_response if img.url],
|
||||
urls=[
|
||||
img.url for img in img_generation_response if img.url
|
||||
],
|
||||
base64_files=[
|
||||
img.image_data
|
||||
for img in img_generation_response
|
||||
@ -890,7 +938,9 @@ def stream_chat_message_objects(
|
||||
)
|
||||
yield info.qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
custom_tool_response = cast(
|
||||
CustomToolCallSummary, packet.response
|
||||
)
|
||||
|
||||
if (
|
||||
custom_tool_response.response_type == "image"
|
||||
@ -903,7 +953,8 @@ def stream_chat_message_objects(
|
||||
id=str(file_id),
|
||||
type=(
|
||||
ChatFileType.IMAGE
|
||||
if custom_tool_response.response_type == "image"
|
||||
if custom_tool_response.response_type
|
||||
== "image"
|
||||
else ChatFileType.CSV
|
||||
),
|
||||
)
|
||||
@ -967,7 +1018,9 @@ def stream_chat_message_objects(
|
||||
llm.config.api_key, "[REDACTED_API_KEY]"
|
||||
)
|
||||
|
||||
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
|
||||
yield StreamingError(
|
||||
error=client_error_msg, stack_trace=stack_trace
|
||||
)
|
||||
|
||||
db_session.rollback()
|
||||
return
|
||||
|
@ -1,9 +1,12 @@
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import string
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@ -285,15 +288,62 @@ def parallel_visit_api_retrieval(
|
||||
return inference_chunks
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
def _append_ranking_stats_to_csv(
    ranking_stats: list[tuple[str, float, str, str, str, float]],
    csv_path: str = "/tmp/ranking_stats.csv",
) -> None:
    """
    Append ranking statistics to a CSV file, writing a header row first if
    the file does not exist yet.

    Args:
        ranking_stats: List of 6-tuples containing
            (category, query_alpha, query, hit_position, document_id, relevance)
        csv_path: Path to the CSV file to append to; parent directories are
            created if missing
    """
    file_exists = os.path.isfile(csv_path)

    # Create directory if it doesn't exist
    csv_dir = os.path.dirname(csv_path)
    if csv_dir and not os.path.exists(csv_dir):
        Path(csv_dir).mkdir(parents=True, exist_ok=True)

    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)

        # Write header if file is new. The header must have the same six
        # columns as the data rows below — the previous 5-column header
        # omitted the leading "category" field, misaligning every column.
        if not file_exists:
            writer.writerow(
                [
                    "category",
                    "query_alpha",
                    "query",
                    "hit_position",
                    "document_id",
                    "relevance",
                ]
            )

        # Write the ranking stats
        for cat, query_alpha, query, hit_pos, doc_chunk_id, relevance in ranking_stats:
            writer.writerow([cat, query_alpha, query, hit_pos, doc_chunk_id, relevance])

    logger.debug(f"Appended {len(ranking_stats)} ranking stats to {csv_path}")
|
||||
|
||||
|
||||
@retry(tries=1, delay=1, backoff=2)
|
||||
def query_vespa(
|
||||
query_params: Mapping[str, str | int | float]
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
if "query" in query_params and not cast(str, query_params["query"]).strip():
|
||||
raise ValueError("No/empty query received")
|
||||
|
||||
ranking_stats: list[tuple[str, float, str, str, str, float]] = []
|
||||
|
||||
search_time = 0.0
|
||||
|
||||
alphas: list[float] = [0.4, 0.7, 1.0]
|
||||
for query_alpha in alphas:
|
||||
date_time_start = datetime.now()
|
||||
|
||||
# Create a mutable copy of the query_params
|
||||
mutable_params = dict(query_params)
|
||||
# Now we can modify it without mypy errors
|
||||
mutable_params["input.query(alpha)"] = query_alpha
|
||||
|
||||
params = dict(
|
||||
**query_params,
|
||||
**mutable_params,
|
||||
**{
|
||||
"presentation.timing": True,
|
||||
}
|
||||
@ -342,10 +392,43 @@ def query_vespa(
|
||||
f"fetch this document"
|
||||
)
|
||||
|
||||
for hit_pos, hit in enumerate(hits):
|
||||
ranking_stats.append(
|
||||
(
|
||||
"Retrieval",
|
||||
query_alpha,
|
||||
cast(str, mutable_params["query"]),
|
||||
str(hit_pos),
|
||||
hit["fields"].get("document_id", "")
|
||||
+ "__"
|
||||
+ str(hit["fields"].get("chunk_id", "")),
|
||||
hit.get("relevance", 0),
|
||||
)
|
||||
)
|
||||
|
||||
date_time_end = datetime.now()
|
||||
search_time += (date_time_end - date_time_start).microseconds / 1000000
|
||||
|
||||
avg_search_time = search_time / len(alphas)
|
||||
ranking_stats.append(
|
||||
(
|
||||
"Timing",
|
||||
query_alpha,
|
||||
cast(str, query_params["query"]).strip(),
|
||||
"",
|
||||
"Avg:",
|
||||
avg_search_time,
|
||||
)
|
||||
)
|
||||
|
||||
if ranking_stats:
|
||||
_append_ranking_stats_to_csv(ranking_stats)
|
||||
|
||||
filtered_hits = [hit for hit in hits if hit["fields"].get(CONTENT) is not None]
|
||||
|
||||
inference_chunks = [_vespa_hit_to_inference_chunk(hit) for hit in filtered_hits]
|
||||
# Good Debugging Spot
|
||||
logger.info(f"Search done for all alphs - avg timing: {avg_search_time}")
|
||||
return inference_chunks
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user