mirror of https://github.com/danswer-ai/danswer.git
synced 2025-03-29 03:01:59 +01:00

Compare commits: 3 commits (f45798b5dd ... 625936306f)

Author | SHA1 | Date
---|---|---
 | 625936306f |
 | ab11bf6552 |
 | 83d5b3b503 |
@@ -357,6 +357,38 @@ def stream_chat_message_objects(

     llm: LLM

+    test_questions = [
+        "big bang vs steady state theory",
+        "astronomy",
+        "trace energy momentum tensor conformal field theory",
+        "evidence Big Bang",
+        "Neil Armstrong play tennis moon",
+        "current temperature Hawaii New York Munich",
+        "win quadradoodle",
+        "best practices coding Java",
+        "classes related software engineering",
+        "current temperature Munich",
+        "what is the most important concept in biology",
+        "subfields of finance",
+        "what is the overlap between finance and economics",
+        "effects taking vitamin c pills vs eating veggies health outcomes",
+        "professions people good math",
+        "biomedical engineers design cutting-edge medical equipment important skill set",
+        "How do biomedical engineers design cutting-edge medical equipment? And what is the most important skill set?",
+        "average power output US nuclear power plant",
+        "typical power range small modular reactors",
+        "SMRs power industry",
+        "best use case Onyx AI company",
+        "techniques calculate square root",
+        "daily vitamin C requirement adult women",
+        "boil ocean",
+        "best soccer player ever",
+    ]
+
+    for test_question_num, test_question in enumerate(test_questions):
+        logger.info(
+            f"------- Running test question {test_question_num + 1} of {len(test_questions)}"
+        )
     try:
         user_id = user.id if user is not None else None

@@ -366,7 +398,8 @@ def stream_chat_message_objects(
             db_session=db_session,
         )

-        message_text = new_msg_req.message
+        # message_text = new_msg_req.message
+        message_text = test_question
         chat_session_id = new_msg_req.chat_session_id
         parent_id = new_msg_req.parent_message_id
         reference_doc_ids = new_msg_req.search_doc_ids
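Taken together, the two hunks above turn the endpoint into a self-driving benchmark: the real user message is commented out and the handler replays the hard-coded question battery instead, logging progress per question. A minimal, self-contained sketch of that pattern (`handle_request` and `answer` are hypothetical stand-ins for the actual request handler and answering pipeline, not names from the codebase):

    from collections.abc import Callable

    test_questions = ["big bang vs steady state theory", "astronomy"]

    def handle_request(user_message: str, answer: Callable[[str], str]) -> None:
        # The incoming user_message is ignored; each canned question runs instead.
        for i, question in enumerate(test_questions):
            print(f"------- Running test question {i + 1} of {len(test_questions)}")
            message_text = question  # overrides the real user message
            print(answer(message_text))

    handle_request("ignored", lambda q: f"answer to: {q}")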

@@ -375,7 +408,10 @@ def stream_chat_message_objects(

         # permanent "log" store, used primarily for debugging
         long_term_logger = LongTermLogger(
-            metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
+            metadata={
+                "user_id": str(user_id),
+                "chat_session_id": str(chat_session_id),
+            }
         )

         if alternate_assistant_id is not None:

@@ -536,7 +572,9 @@ def stream_chat_message_objects(
             history_msgs, new_msg_req.file_descriptors, db_session
         )
         req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
-        latest_query_files = [file for file in files if file.file_id in req_file_ids]
+        latest_query_files = [
+            file for file in files if file.file_id in req_file_ids
+        ]

         if user_message:
             attach_files_to_chat_message(

@@ -609,7 +647,9 @@ def stream_chat_message_objects(
         )

         overridden_model = (
-            new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
+            new_msg_req.llm_override.model_version
+            if new_msg_req.llm_override
+            else None
         )

         # Cannot determine these without the LLM step or breaking out early

@@ -621,7 +661,9 @@ def stream_chat_message_objects(
             # the latest. If we're creating a new assistant message, then the parent
             # should be the latest message (latest user message)
             parent_message=(
-                final_msg if existing_assistant_message_id is None else parent_message
+                final_msg
+                if existing_assistant_message_id is None
+                else parent_message
             ),
             prompt_id=prompt_id,
             overridden_model=overridden_model,

@@ -638,13 +680,17 @@ def stream_chat_message_objects(
             is_agentic=new_msg_req.use_agentic_search,
         )

-        prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
+        prompt_override = (
+            new_msg_req.prompt_override or chat_session.prompt_override
+        )
         if new_msg_req.persona_override_config:
             prompt_config = PromptConfig(
                 system_prompt=new_msg_req.persona_override_config.prompts[
                     0
                 ].system_prompt,
-                task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt,
+                task_prompt=new_msg_req.persona_override_config.prompts[
+                    0
+                ].task_prompt,
                 datetime_aware=new_msg_req.persona_override_config.prompts[
                     0
                 ].datetime_aware,

@@ -864,7 +910,9 @@ def stream_chat_message_objects(
                     )

                     file_ids = save_files(
-                        urls=[img.url for img in img_generation_response if img.url],
+                        urls=[
+                            img.url for img in img_generation_response if img.url
+                        ],
                         base64_files=[
                             img.image_data
                             for img in img_generation_response

@@ -890,7 +938,9 @@ def stream_chat_message_objects(
                     )
                     yield info.qa_docs_response
                 elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
-                    custom_tool_response = cast(CustomToolCallSummary, packet.response)
+                    custom_tool_response = cast(
+                        CustomToolCallSummary, packet.response
+                    )

                     if (
                         custom_tool_response.response_type == "image"

@@ -903,7 +953,8 @@ def stream_chat_message_objects(
                                 id=str(file_id),
                                 type=(
                                     ChatFileType.IMAGE
-                                    if custom_tool_response.response_type == "image"
+                                    if custom_tool_response.response_type
+                                    == "image"
                                     else ChatFileType.CSV
                                 ),
                             )

@@ -967,7 +1018,9 @@ def stream_chat_message_objects(
                 llm.config.api_key, "[REDACTED_API_KEY]"
             )

-        yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
+        yield StreamingError(
+            error=client_error_msg, stack_trace=stack_trace
+        )

         db_session.rollback()
         return

The remaining hunks are in a second file (the Vespa chunk-retrieval module):

@@ -1,9 +1,12 @@
+import csv
 import json
+import os
 import string
 from collections.abc import Callable
 from collections.abc import Mapping
 from datetime import datetime
 from datetime import timezone
+from pathlib import Path
 from typing import Any
 from typing import cast

@@ -285,15 +288,62 @@ def parallel_visit_api_retrieval(
     return inference_chunks


 @retry(tries=3, delay=1, backoff=2)
+def _append_ranking_stats_to_csv(
+    ranking_stats: list[tuple[str, float, str, str, str, float]],
+    csv_path: str = "/tmp/ranking_stats.csv",
+) -> None:
+    """
+    Append ranking statistics to a CSV file.
+
+    Args:
+        ranking_stats: list of (category, query_alpha, query, hit_position,
+            document_id, relevance) tuples
+        csv_path: path to the CSV file to append to
+    """
+    file_exists = os.path.isfile(csv_path)
+
+    # Create directory if it doesn't exist
+    csv_dir = os.path.dirname(csv_path)
+    if csv_dir and not os.path.exists(csv_dir):
+        Path(csv_dir).mkdir(parents=True, exist_ok=True)
+
+    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
+        writer = csv.writer(file)
+
+        # Write header if file is new
+        if not file_exists:
+            writer.writerow(
+                ["category", "query_alpha", "query", "hit_position", "document_id", "relevance"]
+            )
+
+        # Write the ranking stats
+        for cat, query_alpha, query, hit_pos, doc_chunk_id, relevance in ranking_stats:
+            writer.writerow([cat, query_alpha, query, hit_pos, doc_chunk_id, relevance])
+
+    logger.debug(f"Appended {len(ranking_stats)} ranking stats to {csv_path}")
+
+
+@retry(tries=1, delay=1, backoff=2)
 def query_vespa(
     query_params: Mapping[str, str | int | float]
 ) -> list[InferenceChunkUncleaned]:
     if "query" in query_params and not cast(str, query_params["query"]).strip():
         raise ValueError("No/empty query received")
+
+    ranking_stats: list[tuple[str, float, str, str, str, float]] = []
+
+    search_time = 0.0
+
+    alphas: list[float] = [0.4, 0.7, 1.0]
+    for query_alpha in alphas:
+        date_time_start = datetime.now()
+
+        # Create a mutable copy of the query_params
+        mutable_params = dict(query_params)
+        # Now we can modify it without mypy errors
+        mutable_params["input.query(alpha)"] = query_alpha
+
     params = dict(
-        **query_params,
+        **mutable_params,
         **{
             "presentation.timing": True,
         }
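The hunk above adds a CSV sink for ranking statistics and sweeps the hybrid-search weight: in the Vespa rank profile, `input.query(alpha)` blends the dense embedding score against the keyword (BM25) score, so alphas of 0.4, 0.7, and 1.0 compare keyword-leaning, balanced, and purely semantic retrieval on the same query. Dropping query_vespa's retries from 3 to 1 presumably keeps a failing benchmark query from retrying and inflating the timing numbers. A rough sketch of the blend, assuming a linear combination (the actual Vespa ranking expression may differ):

    def hybrid_score(vector_score: float, keyword_score: float, alpha: float) -> float:
        # alpha = 1.0 -> purely semantic; lower alpha mixes in more keyword signal.
        # Linear blend shown for illustration only.
        return alpha * vector_score + (1 - alpha) * keyword_score

    for alpha in (0.4, 0.7, 1.0):
        print(alpha, hybrid_score(vector_score=0.82, keyword_score=0.65, alpha=alpha))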

@@ -342,10 +392,43 @@ def query_vespa(
             f"fetch this document"
         )

+        for hit_pos, hit in enumerate(hits):
+            ranking_stats.append(
+                (
+                    "Retrieval",
+                    query_alpha,
+                    cast(str, mutable_params["query"]),
+                    str(hit_pos),
+                    hit["fields"].get("document_id", "")
+                    + "__"
+                    + str(hit["fields"].get("chunk_id", "")),
+                    hit.get("relevance", 0),
+                )
+            )
+
+        date_time_end = datetime.now()
+        # total_seconds() captures the full elapsed time; bare .microseconds would drop whole seconds
+        search_time += (date_time_end - date_time_start).total_seconds()
+
+    avg_search_time = search_time / len(alphas)
+    ranking_stats.append(
+        (
+            "Timing",
+            query_alpha,
+            cast(str, query_params["query"]).strip(),
+            "",
+            "Avg:",
+            avg_search_time,
+        )
+    )
+
+    if ranking_stats:
+        _append_ranking_stats_to_csv(ranking_stats)
+
     filtered_hits = [hit for hit in hits if hit["fields"].get(CONTENT) is not None]

     inference_chunks = [_vespa_hit_to_inference_chunk(hit) for hit in filtered_hits]
     # Good Debugging Spot
+    logger.info(f"Search done for all alphas - avg timing: {avg_search_time}")
     return inference_chunks
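Once a few test batteries have run, /tmp/ranking_stats.csv can be aggregated offline to compare how each alpha scores the same queries. A small, self-contained sketch assuming the column layout written above (only "Retrieval" rows carry per-hit relevance; "Timing" rows are skipped):

    import csv
    from collections import defaultdict

    # Mean relevance per alpha from the stats file written by query_vespa.
    relevance_by_alpha: dict[str, list[float]] = defaultdict(list)
    with open("/tmp/ranking_stats.csv", newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            if row["category"] == "Retrieval":
                relevance_by_alpha[row["query_alpha"]].append(float(row["relevance"]))

    for alpha, scores in sorted(relevance_by_alpha.items()):
        mean = sum(scores) / len(scores)
        print(f"alpha={alpha}: mean relevance {mean:.4f} over {len(scores)} hits")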