# Snapshot metadata carried over from the hosting page (not code):
# 2025-02-14 19:40:21 +00:00 · 239 lines · 9.2 KiB · Python

import csv
import json
import os
from collections import defaultdict
from datetime import datetime
from datetime import timedelta
from typing import Any
import yaml
from onyx.agents.agent_search.deep_search.main.graph_builder import (
main_graph_builder,
)
from onyx.agents.agent_search.deep_search.main.graph_builder import (
main_graph_builder as main_graph_builder_a,
)
from onyx.agents.agent_search.deep_search.main.states import (
MainInput as MainInput_a,
)
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.tools.force import ForceUseTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
logger = setup_logger()

# Paths in the config are resolved relative to the current working directory,
# so the script is expected to be launched from the repository root.
cwd = os.getcwd()
# Read the config inside ``with`` so the file handle is closed immediately
# (the previous bare ``open(...)`` leaked it for the life of the process).
with open(
    f"{cwd}/backend/tests/regression/answer_quality/search_test_config.yaml",
    encoding="utf-8",
) as config_file:
    CONFIG = yaml.safe_load(config_file)

INPUT_DIR = CONFIG["agent_test_input_folder"]
OUTPUT_DIR = CONFIG["agent_test_output_folder"]

# NOTE(review): ``graph`` / ``compiled_graph`` are not passed to the flows run
# below (``run_basic_graph`` / ``run_main_graph`` receive only a config);
# confirm nothing else relies on them before removing.
graph = main_graph_builder(test_mode=True)
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()

# The input JSON is a locally created test-data file. This script reads its
# "examples" list and its "example_ids" selection list.
with open(f"{INPUT_DIR}/agent_test_data.json", encoding="utf-8") as input_file_object:
    test_data = json.load(input_file_object)

output_file = f"{OUTPUT_DIR}/agent_test_output.csv"
csv_output_data: list[list[str]] = []
example_data = test_data["examples"]
example_ids = test_data["example_ids"]
failed_example_ids: list[int] = []
def _format_duration(duration: timedelta | None) -> str:
    """Render an optional duration as a CSV cell; ``None`` becomes ``""``."""
    return "" if duration is None else str(duration)


# Set lookup keeps the per-example filter O(1). An empty ``example_ids`` list
# means "run every example".
selected_example_ids = set(example_ids)

for example in example_data:
    query_start_time: datetime = datetime.now()

    # Validate the id *before* converting it. ``int(example.get("id"))`` raised
    # TypeError for a missing id outside the ``try`` below, aborting the whole
    # run. A truthiness check on the converted id also skipped a valid id of 0.
    raw_example_id = example.get("id")
    example_question: str = example.get("question")
    if not example_question or raw_example_id is None:
        continue
    try:
        example_id: int = int(raw_example_id)
    except (TypeError, ValueError):
        logger.warning(f"Skipping example with non-integer id {raw_example_id!r}")
        continue
    if selected_example_ids and example_id not in selected_example_ids:
        continue

    logger.info(f"{query_start_time} -- Processing example {example_id}")

    # Built fresh for every example. Otherwise a value recorded for one example
    # (e.g. the refined-answer-improvement flag) leaks into later rows.
    output_data: dict[str, Any] = {}
    try:
        search_request = SearchRequest(query=example_question)
        initial_answer_duration: timedelta | None = None
        refined_answer_duration: timedelta | None = None
        answer_tokens: dict[str, list[str]] = defaultdict(list)
        logger.debug("\n\nTEST QUERY START\n\n")

        # One DB session per example; the run needs no other session.
        with get_session_context_manager() as db_session:
            config = get_test_config(db_session, primary_llm, fast_llm, search_request)
            assert (
                config.persistence is not None
            ), "set a chat session id to run this test"
            # NOTE(review): this override applies only to the base-flow config;
            # the agent-flow config below is rebuilt without it. Confirm intended.
            config.behavior.perform_initial_search_decomposition = True

            # --- Base flow: plain answer collected from OnyxAnswerPiece ------
            base_flow_start_time: datetime = datetime.now()
            for output in run_basic_graph(config):
                if isinstance(output, OnyxAnswerPiece):
                    answer_tokens["base_answer"].append(output.answer_piece or "")
            output_data["base_answer"] = "".join(answer_tokens["base_answer"])
            output_data["base_answer_duration"] = datetime.now() - base_flow_start_time

            # --- Agent flow: agentic search with the search tool forced ------
            agent_flow_start_time: datetime = datetime.now()
            config = get_test_config(
                db_session,
                primary_llm,
                fast_llm,
                search_request,
                use_agentic_search=True,
            )
            config.tooling.force_use_tool = ForceUseTool(
                force_use=True, tool_name=SearchTool._NAME
            )
            # level -> question number within that level -> accumulated text
            sub_question_dict_tokens: dict[int, dict[int, str]] = defaultdict(
                lambda: defaultdict(str)
            )
            for output in run_main_graph(config):
                if isinstance(output, AgentAnswerPiece):
                    # Pieces for level 0 / question 0 are bucketed as the initial
                    # answer, level 1 / question 0 as the refined answer.
                    if output.level == 0 and output.level_question_num == 0:
                        answer_tokens["initial"].append(output.answer_piece or "")
                    elif output.level == 1 and output.level_question_num == 0:
                        answer_tokens["refined"].append(output.answer_piece or "")
                elif isinstance(output, SubQuestionPiece):
                    if (
                        output.level is not None
                        and output.level_question_num is not None
                    ):
                        sub_question_dict_tokens[output.level][
                            output.level_question_num
                        ] += output.sub_question
                elif isinstance(output, StreamStopInfo):
                    # The level-0 main-answer stop marks the initial answer as done.
                    if (
                        output.stream_type == StreamType.MAIN_ANSWER
                        and output.level == 0
                    ):
                        initial_answer_duration = (
                            datetime.now() - agent_flow_start_time
                        )
                elif isinstance(output, RefinedAnswerImprovement):
                    output_data["refined_answer_improves_on_initial_answer"] = str(
                        output.refined_answer_improvement
                    )
            # Total agent-flow time, measured once the stream is exhausted.
            refined_answer_duration = datetime.now() - agent_flow_start_time

        output_data["example_id"] = example_id
        output_data["question"] = example_question
        output_data["initial_answer"] = "".join(answer_tokens["initial"])
        output_data["refined_answer"] = "".join(answer_tokens["refined"])
        output_data["initial_answer_duration"] = _format_duration(
            initial_answer_duration
        )
        output_data["refined_answer_duration"] = _format_duration(
            refined_answer_duration
        )
        output_data["initial_sub_questions"] = "\n---\n".join(
            sub_question_dict_tokens[0].values()
        )
        output_data["refined_sub_questions"] = "\n---\n".join(
            sub_question_dict_tokens[1].values()
        )

        csv_output_data.append(
            [
                str(example_id),
                example_question,
                output_data["base_answer"],
                _format_duration(output_data["base_answer_duration"]),
                output_data["initial_sub_questions"],
                output_data["initial_answer"],
                output_data["initial_answer_duration"],
                output_data["refined_sub_questions"],
                output_data["refined_answer"],
                output_data["refined_answer_duration"],
                # Missing when the agent flow emitted no RefinedAnswerImprovement;
                # the direct lookup used to raise KeyError here.
                output_data.get("refined_answer_improves_on_initial_answer", ""),
            ]
        )
    except Exception as e:
        # Best-effort per example: record the failure and move on.
        # logger.exception also records the traceback.
        logger.exception(f"Error processing example {example_id}: {e}")
        failed_example_ids.append(example_id)
# Write one tab-separated row per successfully processed example.
# ``newline=""`` is what the csv module expects. The explicit utf-8 encoding
# avoids platform-default encodings failing on non-ASCII answer text.
with open(output_file, "w", newline="", encoding="utf-8") as csvfile:
    writer = csv.writer(csvfile, delimiter="\t")
    writer.writerow(
        [
            "example_id",
            "question",
            "base_answer",
            "base_answer_duration",
            "initial_sub_questions",
            "initial_answer",
            "initial_answer_duration",
            "refined_sub_questions",
            "refined_answer",
            "refined_answer_duration",
            "refined_answer_improves_on_initial_answer",
        ]
    )
    writer.writerows(csv_output_data)

# Failed examples have no row in the output file. Report them so a partial
# result is not mistaken for a complete one.
if failed_example_ids:
    logger.warning(f"Failed example ids: {failed_example_ids}")

print("DONE")