Major Agent Search Updates (#3994)

joachim-danswer
2025-02-14 11:40:21 -08:00
committed by GitHub
parent ec78f78f3c
commit 6687d5d499
36 changed files with 2115 additions and 431 deletions


@@ -1,18 +1,39 @@
import csv
import json
import os
from collections import defaultdict
from datetime import datetime
from datetime import timedelta
from typing import Any

import yaml

from onyx.agents.agent_search.deep_search.main.graph_builder import (
    main_graph_builder as main_graph_builder_a,
)
from onyx.agents.agent_search.deep_search.main.states import (
    MainInput as MainInput_a,
)
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.tools.force import ForceUseTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
logger = setup_logger()

cwd = os.getcwd()
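# This regression script answers each test question twice, once with the
# basic answer flow and once with the agentic deep-search flow, and writes
# the streamed results side by side as tab-separated values.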
@@ -35,95 +56,183 @@ input_file_object = open(
)
output_file = f"{OUTPUT_DIR}/agent_test_output.csv"
csv_output_data: list[list[str]] = []

test_data = json.load(input_file_object)
example_data = test_data["examples"]
example_ids = test_data["example_ids"]

failed_example_ids: list[int] = []
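# Expected shape of the input file, inferred from the keys read below; the
# field values here are illustrative only:
#
# {
#     "example_ids": [1, 2],   # optional filter; an empty list runs all examples
#     "examples": [
#         {
#             "id": 1,
#             "question": "What is Onyx?",
#             "target_sub_questions": ["..."]
#         }
#     ]
# }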
with get_session_context_manager() as db_session:
    output_data: dict[str, Any] = {}

    primary_llm, fast_llm = get_default_llms()

    for example in example_data:
        query_start_time: datetime = datetime.now()
        # Guard against missing fields before casting, so a malformed
        # example is skipped instead of raising.
        example_id: int = int(example.get("id") or 0)
        example_question: str = example.get("question") or ""
        if not example_question or not example_id:
            continue
        if len(example_ids) > 0 and example_id not in example_ids:
            continue

        logger.info(f"{query_start_time} -- Processing example {example_id}")

        try:
            target_sub_questions = example.get("target_sub_questions", [])
            num_target_sub_questions = len(target_sub_questions)

            initial_answer_duration: timedelta | None = None
            refined_answer_duration: timedelta | None = None
            base_answer_duration: timedelta | None = None

            logger.debug("\n\nTEST QUERY START\n\n")

            graph = main_graph_builder_a()
            compiled_graph = graph.compile()

            search_request = SearchRequest(
                # query="what can you do with gitlab?",
                # query="What are the guiding principles behind the development of cockroachDB",
                # query="What are the temperatures in Munich, Hawaii, and New York?",
                # query="When was Washington born?",
                # query="What is Onyx?",
                # query="What is the difference between astronomy and astrology?",
                query=example_question,
            )
            answer_tokens: dict[str, list[str]] = defaultdict(list)
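            # Streamed tokens are bucketed by answer type: "base_answer" for
            # the basic flow, "initial" and "refined" for the agent flow.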
            with get_session_context_manager() as db_session:
                config = get_test_config(
                    db_session, primary_llm, fast_llm, search_request
                )
                assert (
                    config.persistence is not None
                ), "set a chat session id to run this test"

                # search_request.persona = get_persona_by_id(1, None, db_session)
                # config.perform_initial_search_path_decision = False
                config.behavior.perform_initial_search_decomposition = True

                input = MainInput_a()
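                # The same question now runs through two pipelines, the basic
                # one-shot graph below and then the agentic deep-search graph,
                # so their answers and timings can be compared side by side.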
                # Base Flow
                base_flow_start_time: datetime = datetime.now()
                for output in run_basic_graph(config):
                    if isinstance(output, OnyxAnswerPiece):
                        answer_tokens["base_answer"].append(output.answer_piece or "")

                output_data["base_answer"] = "".join(answer_tokens["base_answer"])
                output_data["base_answer_duration"] = (
                    datetime.now() - base_flow_start_time
                )
                # Agent Flow
                agent_flow_start_time: datetime = datetime.now()
                config = get_test_config(
                    db_session,
                    primary_llm,
                    fast_llm,
                    search_request,
                    use_agentic_search=True,
                )
                config.tooling.force_use_tool = ForceUseTool(
                    force_use=True, tool_name=SearchTool._NAME
                )
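                # Forcing the search tool ensures every example actually
                # performs a search rather than leaving the tool choice to
                # the LLM.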
                tool_responses: list = []

                sub_question_dict_tokens: dict[int, dict[int, str]] = defaultdict(
                    lambda: defaultdict(str)
                )

                # Reset per example so a run that never emits
                # RefinedAnswerImprovement cannot leak a stale value from a
                # previous example into this row.
                output_data["refined_answer_improves_on_initial_answer"] = ""

                for output in run_main_graph(config):
                    if isinstance(output, AgentAnswerPiece):
                        if output.level == 0 and output.level_question_num == 0:
                            answer_tokens["initial"].append(output.answer_piece)
                        elif output.level == 1 and output.level_question_num == 0:
                            answer_tokens["refined"].append(output.answer_piece)
                    elif isinstance(output, SubQuestionPiece):
                        if (
                            output.level is not None
                            and output.level_question_num is not None
                        ):
                            sub_question_dict_tokens[output.level][
                                output.level_question_num
                            ] += output.sub_question
                    elif isinstance(output, StreamStopInfo):
                        if (
                            output.stream_type == StreamType.MAIN_ANSWER
                            and output.level == 0
                        ):
                            initial_answer_duration = (
                                datetime.now() - agent_flow_start_time
                            )
                    elif isinstance(output, RefinedAnswerImprovement):
                        output_data["refined_answer_improves_on_initial_answer"] = str(
                            output.refined_answer_improvement
                        )

                refined_answer_duration = datetime.now() - agent_flow_start_time
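                # Stream recap: AgentAnswerPiece carries answer tokens
                # (level 0 = initial, level 1 = refined), SubQuestionPiece
                # carries generated sub-question tokens, StreamStopInfo marks
                # the end of the initial answer for timing, and
                # RefinedAnswerImprovement reports whether refinement helped.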
output_data["example_id"] = example_id
output_data["question"] = example_question
output_data["initial_answer"] = "".join(answer_tokens["initial"])
output_data["refined_answer"] = "".join(answer_tokens["refined"])
output_data["initial_answer_duration"] = initial_answer_duration or ""
output_data["refined_answer_duration"] = refined_answer_duration
output_data["initial_sub_questions"] = "\n---\n".join(
[x for x in sub_question_dict_tokens[0].values()]
)
output_data["refined_sub_questions"] = "\n---\n".join(
[x for x in sub_question_dict_tokens[1].values()]
)
csv_output_data.append(
[
str(example_id),
example_question,
output_data["base_answer"],
output_data["base_answer_duration"],
output_data["initial_sub_questions"],
output_data["initial_answer"],
output_data["initial_answer_duration"],
output_data["refined_sub_questions"],
output_data["refined_answer"],
output_data["refined_answer_duration"],
output_data["refined_answer_improves_on_initial_answer"],
]
)
        except Exception as e:
            logger.error(f"Error processing example {example_id}: {e}")
            failed_example_ids.append(example_id)
            continue
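# Any example that raises is logged and skipped, and its id is kept in
# failed_example_ids so a partial run still produces output.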
with open(output_file, "w", newline="") as csvfile:
    writer = csv.writer(csvfile, delimiter="\t")
    writer.writerow(
        [
            "example_id",
            "question",
            "base_answer",
            "base_answer_duration",
            "initial_sub_questions",
            "initial_answer",
            "initial_answer_duration",
            "refined_sub_questions",
            "refined_answer",
            "refined_answer_duration",
            "refined_answer_improves_on_initial_answer",
        ]
    )
    writer.writerows(csv_output_data)
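# A minimal sketch for loading the results back for analysis; this is an
# illustration, not part of the test itself, and assumes only the
# tab-delimited layout written above:
#
#     with open(output_file) as f:
#         rows = list(csv.DictReader(f, delimiter="\t"))
#     for row in rows:
#         print(row["example_id"], row["refined_answer_improves_on_initial_answer"])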
print("DONE")