major Agent Search Updates (#3994)
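This diff reworks the agent answer-quality regression script. Instead of compiling the main graph once and invoking it per example, the test now drives the streaming entry points run_basic_graph and run_main_graph, buckets the streamed pieces by type and level (base answer, initial answer, refined answer, sub-questions), records a duration for each phase, collects failed example ids instead of aborting the run, and writes the results as tab-separated rows.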
@@ -1,18 +1,39 @@
-import datetime
+import csv
 import json
 import os
+from collections import defaultdict
+from datetime import datetime
+from datetime import timedelta
+from typing import Any

 import yaml

 from onyx.agents.agent_search.deep_search.main.graph_builder import (
-    main_graph_builder,
+    main_graph_builder as main_graph_builder_a,
 )
-from onyx.agents.agent_search.deep_search.main.states import MainInput
+from onyx.agents.agent_search.deep_search.main.states import (
+    MainInput as MainInput_a,
+)
+from onyx.agents.agent_search.run_graph import run_basic_graph
+from onyx.agents.agent_search.run_graph import run_main_graph
 from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
+from onyx.chat.models import AgentAnswerPiece
+from onyx.chat.models import OnyxAnswerPiece
+from onyx.chat.models import RefinedAnswerImprovement
+from onyx.chat.models import StreamStopInfo
+from onyx.chat.models import StreamType
+from onyx.chat.models import SubQuestionPiece
 from onyx.context.search.models import SearchRequest
 from onyx.db.engine import get_session_context_manager
 from onyx.llm.factory import get_default_llms
+from onyx.tools.force import ForceUseTool
+from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.utils.logger import setup_logger

 logger = setup_logger()


 cwd = os.getcwd()
@@ -35,95 +56,183 @@ input_file_object = open(
 )
 output_file = f"{OUTPUT_DIR}/agent_test_output.csv"

+csv_output_data: list[list[str]] = []
+
 test_data = json.load(input_file_object)
 example_data = test_data["examples"]
 example_ids = test_data["example_ids"]

+failed_example_ids: list[int] = []
+
 with get_session_context_manager() as db_session:
-    output_data = []
+    output_data: dict[str, Any] = {}

     primary_llm, fast_llm = get_default_llms()

     for example in example_data:
-        example_id = example["id"]
+        query_start_time: datetime = datetime.now()
+        example_id: int = int(example.get("id"))
+        example_question: str = example.get("question")
+        if not example_question or not example_id:
+            continue
         if len(example_ids) > 0 and example_id not in example_ids:
             continue

-        example_question = example["question"]
-        target_sub_questions = example.get("target_sub_questions", [])
-        num_target_sub_questions = len(target_sub_questions)
-        search_request = SearchRequest(query=example_question)
-
-        config, search_tool = get_test_config(
-            db_session=db_session,
-            primary_llm=primary_llm,
-            fast_llm=fast_llm,
-            search_request=search_request,
-        )
-
-        inputs = MainInput()
-
-        start_time = datetime.datetime.now()
-
-        question_result = compiled_graph.invoke(
-            input=inputs, config={"metadata": {"config": config}}
-        )
-        end_time = datetime.datetime.now()
-
-        duration = end_time - start_time
-        if num_target_sub_questions > 0:
-            chunk_expansion_ratio = (
-                question_result["initial_agent_stats"]
-                .get("agent_effectiveness", {})
-                .get("utilized_chunk_ratio", None)
-            )
-            support_effectiveness_ratio = (
-                question_result["initial_agent_stats"]
-                .get("agent_effectiveness", {})
-                .get("support_ratio", None)
-            )
-        else:
-            chunk_expansion_ratio = None
-            support_effectiveness_ratio = None
-
-        generated_sub_questions = question_result.get("generated_sub_questions", [])
-        num_generated_sub_questions = len(generated_sub_questions)
-        base_answer = question_result["initial_base_answer"].split("==")[-1]
-        agent_answer = question_result["initial_answer"].split("==")[-1]
-
-        output_point = {
-            "example_id": example_id,
-            "question": example_question,
-            "duration": duration,
-            "target_sub_questions": target_sub_questions,
-            "generated_sub_questions": generated_sub_questions,
-            "num_target_sub_questions": num_target_sub_questions,
-            "num_generated_sub_questions": num_generated_sub_questions,
-            "chunk_expansion_ratio": chunk_expansion_ratio,
-            "support_effectiveness_ratio": support_effectiveness_ratio,
-            "base_answer": base_answer,
-            "agent_answer": agent_answer,
-        }
-
-        output_data.append(output_point)
+        logger.info(f"{query_start_time} -- Processing example {example_id}")
+
+        try:
+            example_question = example["question"]
+            target_sub_questions = example.get("target_sub_questions", [])
+            num_target_sub_questions = len(target_sub_questions)
+            search_request = SearchRequest(query=example_question)
+
+            initial_answer_duration: timedelta | None = None
+            refined_answer_duration: timedelta | None = None
+            base_answer_duration: timedelta | None = None
+
+            logger.debug("\n\nTEST QUERY START\n\n")
+
+            graph = main_graph_builder_a()
+            compiled_graph = graph.compile()
+            query_end_time = datetime.now()
+
+            search_request = SearchRequest(
+                # query="what can you do with gitlab?",
+                # query="What are the guiding principles behind the development of cockroachDB",
+                # query="What are the temperatures in Munich, Hawaii, and New York?",
+                # query="When was Washington born?",
+                # query="What is Onyx?",
+                # query="What is the difference between astronomy and astrology?",
+                query=example_question,
+            )
+
+            answer_tokens: dict[str, list[str]] = defaultdict(list)
+
+            with get_session_context_manager() as db_session:
+                config = get_test_config(
+                    db_session, primary_llm, fast_llm, search_request
+                )
+                assert (
+                    config.persistence is not None
+                ), "set a chat session id to run this test"
+
+                # search_request.persona = get_persona_by_id(1, None, db_session)
+                # config.perform_initial_search_path_decision = False
+                config.behavior.perform_initial_search_decomposition = True
+                input = MainInput_a()
+
+                # Base Flow
+                base_flow_start_time: datetime = datetime.now()
+                for output in run_basic_graph(config):
+                    if isinstance(output, OnyxAnswerPiece):
+                        answer_tokens["base_answer"].append(output.answer_piece or "")
+
+                output_data["base_answer"] = "".join(answer_tokens["base_answer"])
+                output_data["base_answer_duration"] = (
+                    datetime.now() - base_flow_start_time
+                )
+
+                # Agent Flow
+                agent_flow_start_time: datetime = datetime.now()
+                config = get_test_config(
+                    db_session,
+                    primary_llm,
+                    fast_llm,
+                    search_request,
+                    use_agentic_search=True,
+                )
+
+                config.tooling.force_use_tool = ForceUseTool(
+                    force_use=True, tool_name=SearchTool._NAME
+                )
+
+                tool_responses: list = []
+
+                sub_question_dict_tokens: dict[int, dict[int, str]] = defaultdict(
+                    lambda: defaultdict(str)
+                )
+
+                for output in run_main_graph(config):
+                    if isinstance(output, AgentAnswerPiece):
+                        if output.level == 0 and output.level_question_num == 0:
+                            answer_tokens["initial"].append(output.answer_piece)
+                        elif output.level == 1 and output.level_question_num == 0:
+                            answer_tokens["refined"].append(output.answer_piece)
+                    elif isinstance(output, SubQuestionPiece):
+                        if (
+                            output.level is not None
+                            and output.level_question_num is not None
+                        ):
+                            sub_question_dict_tokens[output.level][
+                                output.level_question_num
+                            ] += output.sub_question
+                    elif isinstance(output, StreamStopInfo):
+                        if (
+                            output.stream_type == StreamType.MAIN_ANSWER
+                            and output.level == 0
+                        ):
+                            initial_answer_duration = (
+                                datetime.now() - agent_flow_start_time
+                            )
+                    elif isinstance(output, RefinedAnswerImprovement):
+                        output_data["refined_answer_improves_on_initial_answer"] = str(
+                            output.refined_answer_improvement
+                        )
+
+                refined_answer_duration = datetime.now() - agent_flow_start_time
+
+                output_data["example_id"] = example_id
+                output_data["question"] = example_question
+                output_data["initial_answer"] = "".join(answer_tokens["initial"])
+                output_data["refined_answer"] = "".join(answer_tokens["refined"])
+                output_data["initial_answer_duration"] = initial_answer_duration or ""
+                output_data["refined_answer_duration"] = refined_answer_duration
+
+                output_data["initial_sub_questions"] = "\n---\n".join(
+                    [x for x in sub_question_dict_tokens[0].values()]
+                )
+                output_data["refined_sub_questions"] = "\n---\n".join(
+                    [x for x in sub_question_dict_tokens[1].values()]
+                )
+
+                csv_output_data.append(
+                    [
+                        str(example_id),
+                        example_question,
+                        output_data["base_answer"],
+                        output_data["base_answer_duration"],
+                        output_data["initial_sub_questions"],
+                        output_data["initial_answer"],
+                        output_data["initial_answer_duration"],
+                        output_data["refined_sub_questions"],
+                        output_data["refined_answer"],
+                        output_data["refined_answer_duration"],
+                        output_data["refined_answer_improves_on_initial_answer"],
+                    ]
+                )
+        except Exception as e:
+            logger.error(f"Error processing example {example_id}: {e}")
+            failed_example_ids.append(example_id)
+            continue


 with open(output_file, "w", newline="") as csvfile:
-    fieldnames = [
-        "example_id",
-        "question",
-        "duration",
-        "target_sub_questions",
-        "generated_sub_questions",
-        "num_target_sub_questions",
-        "num_generated_sub_questions",
-        "chunk_expansion_ratio",
-        "support_effectiveness_ratio",
-        "base_answer",
-        "agent_answer",
-    ]
-
-    writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter="\t")
-    writer.writeheader()
-    writer.writerows(output_data)
+    writer = csv.writer(csvfile, delimiter="\t")
+    writer.writerow(
+        [
+            "example_id",
+            "question",
+            "base_answer",
+            "base_answer_duration",
+            "initial_sub_questions",
+            "initial_answer",
+            "initial_answer_duration",
+            "refined_sub_questions",
+            "refined_answer",
+            "refined_answer_duration",
+            "refined_answer_improves_on_initial_answer",
+        ]
+    )
+    writer.writerows(csv_output_data)

 print("DONE")
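For reference, below is a minimal sketch of the test-data shape the script expects, inferred from the fields it reads (test_data["examples"], test_data["example_ids"], and each example's id, question, and optional target_sub_questions). The concrete ids, questions, and sub-questions are illustrative assumptions, not part of the diff:

# Hypothetical example of the JSON test data, shown as the equivalent Python
# dict; values are made up, only the key structure is taken from the script.
test_data = {
    # When non-empty, the run is restricted to the listed example ids.
    "example_ids": [1, 2],
    "examples": [
        {
            "id": 1,
            "question": "What is the difference between astronomy and astrology?",
            "target_sub_questions": [
                "What is astronomy?",
                "What is astrology?",
            ],
        },
        {
            "id": 2,
            "question": "When was Washington born?",
            # target_sub_questions may be omitted; the script defaults to [].
        },
    ],
}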