mirror of https://github.com/danswer-ai/danswer.git
synced 2025-05-12 21:00:07 +02:00
239 lines · 9.2 KiB · Python
import csv
|
|
import json
|
|
import os
|
|
from collections import defaultdict
|
|
from datetime import datetime
|
|
from datetime import timedelta
|
|
from typing import Any
|
|
|
|
import yaml
|
|
|
|
from onyx.agents.agent_search.deep_search.main.graph_builder import (
|
|
main_graph_builder,
|
|
)
|
|
from onyx.agents.agent_search.deep_search.main.graph_builder import (
|
|
main_graph_builder as main_graph_builder_a,
|
|
)
|
|
from onyx.agents.agent_search.deep_search.main.states import (
|
|
MainInput as MainInput_a,
|
|
)
|
|
from onyx.agents.agent_search.run_graph import run_basic_graph
|
|
from onyx.agents.agent_search.run_graph import run_main_graph
|
|
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
|
from onyx.chat.models import AgentAnswerPiece
|
|
from onyx.chat.models import OnyxAnswerPiece
|
|
from onyx.chat.models import RefinedAnswerImprovement
|
|
from onyx.chat.models import StreamStopInfo
|
|
from onyx.chat.models import StreamType
|
|
from onyx.chat.models import SubQuestionPiece
|
|
from onyx.context.search.models import SearchRequest
|
|
from onyx.db.engine import get_session_context_manager
|
|
from onyx.llm.factory import get_default_llms
|
|
from onyx.tools.force import ForceUseTool
|
|
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
|
from onyx.utils.logger import setup_logger
|
|
|
|
logger = setup_logger()


cwd = os.getcwd()

# Load the regression-test configuration. A context manager is used so the
# config file handle is closed promptly — the original `yaml.safe_load(open(...))`
# leaked the open file object.
with open(
    f"{cwd}/backend/tests/regression/answer_quality/search_test_config.yaml"
) as _config_file:
    CONFIG = yaml.safe_load(_config_file)

# Folders for the agent test input JSON and the output CSV, as configured
# in search_test_config.yaml.
INPUT_DIR = CONFIG["agent_test_input_folder"]
OUTPUT_DIR = CONFIG["agent_test_output_folder"]


# Build the agent search graph once up front and fetch the default LLMs.
# NOTE(review): both `compiled_graph` and the LLM pair are re-created later
# inside the main loop, so these top-level values appear unused — kept for
# compatibility with any import-time side effects; confirm before removing.
graph = main_graph_builder(test_mode=True)
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
|
|
|
|
# create a local json test data file and use it here

# Read the test examples. A context manager guarantees the input file is
# closed — the original kept `input_file_object` open for the whole run.
with open(f"{INPUT_DIR}/agent_test_data.json") as input_file_object:
    test_data = json.load(input_file_object)

# Destination for the tab-separated results written at the end of the run.
output_file = f"{OUTPUT_DIR}/agent_test_output.csv"

# One row per successfully processed example; written out at the end.
csv_output_data: list[list[str]] = []

# `examples` is the list of test cases; `example_ids` optionally restricts
# the run to a subset of example ids (empty list means "run everything").
example_data = test_data["examples"]
example_ids = test_data["example_ids"]

# Ids of examples that raised during processing (logged, then skipped).
failed_example_ids: list[int] = []
|
|
|
|
# Main evaluation loop: for each example, run the basic (non-agentic) flow and
# the agentic flow, collect streamed answer tokens and timings, and append one
# CSV row per example. Failures are logged and the example is skipped.
with get_session_context_manager() as db_session:
    # Reused across iterations — on an exception mid-example, values from a
    # previous example may persist in this dict (e.g.
    # "refined_answer_improves_on_initial_answer"). NOTE(review): consider
    # resetting per example; kept as-is here.
    output_data: dict[str, Any] = {}

    primary_llm, fast_llm = get_default_llms()

    for example in example_data:
        query_start_time: datetime = datetime.now()
        example_id: int = int(example.get("id"))
        example_question: str = example.get("question")
        # Skip malformed examples. NOTE(review): `not example_id` also skips a
        # legitimate id of 0, and `int(None)` above would raise before this
        # check if "id" is missing — confirm input data always has an id.
        if not example_question or not example_id:
            continue
        # An empty example_ids list means "run all examples".
        if len(example_ids) > 0 and example_id not in example_ids:
            continue

        logger.info(f"{query_start_time} -- Processing example {example_id}")

        try:
            example_question = example["question"]
            # Expected sub-questions, if the test data provides them.
            target_sub_questions = example.get("target_sub_questions", [])
            num_target_sub_questions = len(target_sub_questions)
            search_request = SearchRequest(query=example_question)

            # Durations are filled in as each flow completes.
            initial_answer_duration: timedelta | None = None
            refined_answer_duration: timedelta | None = None
            base_answer_duration: timedelta | None = None

            logger.debug("\n\nTEST QUERY START\n\n")

            # Fresh graph per example (rebuilds the top-level one).
            graph = main_graph_builder_a()
            compiled_graph = graph.compile()
            # NOTE(review): `query_end_time` is assigned but never used.
            query_end_time = datetime.now()

            # Re-created with the same query; the commented-out queries are
            # ad-hoc manual test inputs left in place.
            search_request = SearchRequest(
                # query="what can you do with gitlab?",
                # query="What are the guiding principles behind the development of cockroachDB",
                # query="What are the temperatures in Munich, Hawaii, and New York?",
                # query="When was Washington born?",
                # query="What is Onyx?",
                # query="What is the difference between astronomy and astrology?",
                query=example_question,
            )

            # Streamed answer tokens keyed by flow: "base_answer", "initial",
            # "refined".
            answer_tokens: dict[str, list[str]] = defaultdict(list)

            # Inner session shadows the outer `db_session` for this example.
            with get_session_context_manager() as db_session:
                config = get_test_config(
                    db_session, primary_llm, fast_llm, search_request
                )
                assert (
                    config.persistence is not None
                ), "set a chat session id to run this test"

                # search_request.persona = get_persona_by_id(1, None, db_session)
                # config.perform_initial_search_path_decision = False
                config.behavior.perform_initial_search_decomposition = True
                # NOTE(review): shadows the `input` builtin and is never used
                # afterwards — confirm whether constructing MainInput_a has a
                # needed side effect.
                input = MainInput_a()

                # Base Flow: run the basic graph and accumulate answer pieces.
                base_flow_start_time: datetime = datetime.now()
                for output in run_basic_graph(config):
                    if isinstance(output, OnyxAnswerPiece):
                        answer_tokens["base_answer"].append(output.answer_piece or "")

                output_data["base_answer"] = "".join(answer_tokens["base_answer"])
                output_data["base_answer_duration"] = (
                    datetime.now() - base_flow_start_time
                )

                # Agent Flow: re-configure for agentic search and force the
                # search tool so the agent always retrieves.
                agent_flow_start_time: datetime = datetime.now()
                config = get_test_config(
                    db_session,
                    primary_llm,
                    fast_llm,
                    search_request,
                    use_agentic_search=True,
                )

                config.tooling.force_use_tool = ForceUseTool(
                    force_use=True, tool_name=SearchTool._NAME
                )

                # NOTE(review): `tool_responses` is never appended to or read.
                tool_responses: list = []

                # level -> question number -> accumulated sub-question text.
                sub_question_dict_tokens: dict[int, dict[int, str]] = defaultdict(
                    lambda: defaultdict(str)
                )

                # Dispatch each streamed output by type:
                #  - AgentAnswerPiece: level 0 => initial answer tokens,
                #    level 1 => refined answer tokens (question_num 0 only)
                #  - SubQuestionPiece: accumulate sub-question text per level
                #  - StreamStopInfo: end of the level-0 main answer marks the
                #    initial-answer duration
                #  - RefinedAnswerImprovement: whether refinement helped
                for output in run_main_graph(config):
                    if isinstance(output, AgentAnswerPiece):
                        if output.level == 0 and output.level_question_num == 0:
                            answer_tokens["initial"].append(output.answer_piece)
                        elif output.level == 1 and output.level_question_num == 0:
                            answer_tokens["refined"].append(output.answer_piece)
                    elif isinstance(output, SubQuestionPiece):
                        if (
                            output.level is not None
                            and output.level_question_num is not None
                        ):
                            sub_question_dict_tokens[output.level][
                                output.level_question_num
                            ] += output.sub_question
                    elif isinstance(output, StreamStopInfo):
                        if (
                            output.stream_type == StreamType.MAIN_ANSWER
                            and output.level == 0
                        ):
                            initial_answer_duration = (
                                datetime.now() - agent_flow_start_time
                            )
                    elif isinstance(output, RefinedAnswerImprovement):
                        output_data["refined_answer_improves_on_initial_answer"] = str(
                            output.refined_answer_improvement
                        )

                refined_answer_duration = datetime.now() - agent_flow_start_time

                # Assemble this example's results.
                output_data["example_id"] = example_id
                output_data["question"] = example_question
                output_data["initial_answer"] = "".join(answer_tokens["initial"])
                output_data["refined_answer"] = "".join(answer_tokens["refined"])
                # Empty string when no MAIN_ANSWER stop signal arrived.
                output_data["initial_answer_duration"] = initial_answer_duration or ""
                output_data["refined_answer_duration"] = refined_answer_duration

                # Level 0 = initial decomposition, level 1 = refined.
                output_data["initial_sub_questions"] = "\n---\n".join(
                    [x for x in sub_question_dict_tokens[0].values()]
                )
                output_data["refined_sub_questions"] = "\n---\n".join(
                    [x for x in sub_question_dict_tokens[1].values()]
                )

                # Row order must match the header written at the end of the
                # script.
                csv_output_data.append(
                    [
                        str(example_id),
                        example_question,
                        output_data["base_answer"],
                        output_data["base_answer_duration"],
                        output_data["initial_sub_questions"],
                        output_data["initial_answer"],
                        output_data["initial_answer_duration"],
                        output_data["refined_sub_questions"],
                        output_data["refined_answer"],
                        output_data["refined_answer_duration"],
                        output_data["refined_answer_improves_on_initial_answer"],
                    ]
                )
        except Exception as e:
            # Best-effort: record the failure and continue with the next
            # example rather than aborting the whole run.
            logger.error(f"Error processing example {example_id}: {e}")
            failed_example_ids.append(example_id)
            continue
|
|
|
|
|
|
# Column headers for the results file; order matches the rows appended to
# csv_output_data in the main loop.
_header_row = [
    "example_id",
    "question",
    "base_answer",
    "base_answer_duration",
    "initial_sub_questions",
    "initial_answer",
    "initial_answer_duration",
    "refined_sub_questions",
    "refined_answer",
    "refined_answer_duration",
    "refined_answer_improves_on_initial_answer",
]

# Emit all collected results as a tab-delimited file.
with open(output_file, "w", newline="") as csvfile:
    writer = csv.writer(csvfile, delimiter="\t")
    writer.writerow(_header_row)
    writer.writerows(csv_output_data)

print("DONE")
|