reworked config to have logical structure

This commit is contained in:
Evan Lohn 2025-01-31 15:37:47 -08:00
parent 8342168658
commit 118e8afbef
33 changed files with 296 additions and 426 deletions

View File

@ -109,7 +109,7 @@ if __name__ == "__main__":
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
agent_search_config, search_tool = get_test_config(
graph_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = AnswerQuestionInput(
@ -119,7 +119,7 @@ if __name__ == "__main__":
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": agent_search_config}},
config={"configurable": {"config": graph_config}},
# debug=True,
# subgraphs=True,
):

View File

@ -11,7 +11,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
QACheckUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
from onyx.agents.agent_search.shared_graph_utils.utils import (
@ -47,8 +47,8 @@ def check_sub_answer(
)
]
agent_searchch_config = cast(AgentSearchConfig, config["metadata"]["config"])
fast_llm = agent_searchch_config.fast_llm
graph_config = cast(GraphConfig, config["metadata"]["config"])
fast_llm = graph_config.tooling.fast_llm
response = list(
fast_llm.stream(
prompt=msg,

View File

@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
QAGenerationUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_sub_question_answer_prompt,
)
@ -42,13 +42,13 @@ def generate_sub_answer(
) -> QAGenerationUpdate:
node_start_time = datetime.now()
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = state.question
state.verified_reranked_documents
level, question_nr = parse_question_id(state.question_id)
context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]
persona_contextualized_prompt = get_persona_agent_prompt_expressions(
agent_search_config.search_request.persona
graph_config.inputs.search_request.persona
).contextualized_prompt
if len(context_docs) == 0:
@ -64,10 +64,10 @@ def generate_sub_answer(
writer,
)
else:
fast_llm = agent_search_config.fast_llm
fast_llm = graph_config.tooling.fast_llm
msg = build_sub_question_answer_prompt(
question=question,
original_question=agent_search_config.search_request.query,
original_question=graph_config.inputs.search_request.query,
docs=context_docs,
persona_specification=persona_contextualized_prompt,
config=fast_llm.config,

View File

@ -19,7 +19,7 @@ from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
get_prompt_enrichment_components,
)
@ -63,9 +63,9 @@ def generate_initial_answer(
) -> InitialAnswerUpdate:
node_start_time = datetime.now()
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_search_config.search_request.query
prompt_enrichment_components = get_prompt_enrichment_components(agent_search_config)
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
sub_questions_cited_documents = state.cited_documents
orig_question_retrieval_documents = state.orig_question_retrieval_documents
@ -93,8 +93,9 @@ def generate_initial_answer(
# Use the query info from the base document retrieval
query_info = get_query_info(state.orig_question_query_retrieval_results)
if agent_search_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
relevance_list = relevance_from_docs(relevant_docs)
for tool_response in yield_search_responses(
@ -103,7 +104,7 @@ def generate_initial_answer(
final_context_sections=relevant_docs,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=agent_search_config.search_tool,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",
@ -167,7 +168,7 @@ def generate_initial_answer(
sub_question_answer_str = ""
base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS
model = agent_search_config.fast_llm
model = graph_config.tooling.fast_llm
doc_context = format_docs(relevant_docs)
doc_context = trim_prompt_piece(

View File

@ -18,7 +18,7 @@ from onyx.agents.agent_search.deep_search.main.operations import (
from onyx.agents.agent_search.deep_search.main.states import (
InitialQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
@ -44,25 +44,18 @@ def decompose_orig_question(
) -> InitialQuestionDecompositionUpdate:
node_start_time = datetime.now()
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_search_config.search_request.query
chat_session_id = agent_search_config.chat_session_id
primary_message_id = agent_search_config.message_id
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
perform_initial_search_decomposition = (
agent_search_config.perform_initial_search_decomposition
graph_config.behavior.perform_initial_search_decomposition
)
# Get the rewritten queries in a defined format
model = agent_search_config.fast_llm
model = graph_config.tooling.fast_llm
history = build_history_prompt(agent_search_config, question)
history = build_history_prompt(graph_config, question)
# Use the initial search results to inform the decomposition
sample_doc_str = state.sample_doc_str if hasattr(state, "sample_doc_str") else ""
if not chat_session_id or not primary_message_id:
raise ValueError(
"chat_session_id and message_id must be provided for agent search"
)
agent_start_time = datetime.now()
# Initial search to inform decomposition. Just get top 3 fits

View File

@ -6,7 +6,7 @@ from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.utils.logger import setup_logger
logger = setup_logger()
@ -16,9 +16,9 @@ def format_orig_question_search_input(
state: CoreState, config: RunnableConfig
) -> ExpandedRetrievalInput:
logger.debug("generate_raw_search_data")
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
graph_config = cast(GraphConfig, config["metadata"]["config"])
return ExpandedRetrievalInput(
question=agent_search_config.search_request.query,
question=graph_config.inputs.search_request.query,
base_search=True,
sub_question_id=None, # This graph is always and only used for the original question
log_messages=[],

View File

@ -16,7 +16,7 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RequireRefinementUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.utils.logger import setup_logger
@ -26,12 +26,12 @@ logger = setup_logger()
def route_initial_tool_choice(
state: MainState, config: RunnableConfig
) -> Literal["tool_call", "start_agent_search", "logging_node"]:
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
agent_config = cast(GraphConfig, config["metadata"]["config"])
if state.tool_choice is not None:
if (
agent_config.use_agentic_search
and agent_config.search_tool is not None
and state.tool_choice.tool.name == agent_config.search_tool.name
agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and state.tool_choice.tool.name == agent_config.tooling.search_tool.name
):
return "start_agent_search"
else:

View File

@ -221,17 +221,17 @@ if __name__ == "__main__":
with get_session_context_manager() as db_session:
search_request = SearchRequest(query="Who created Excel?")
agent_search_config, search_tool = get_test_config(
graph_config = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = MainInput(
base_question=agent_search_config.search_request.query, log_messages=[]
base_question=graph_config.inputs.search_request.query, log_messages=[]
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": agent_search_config}},
config={"configurable": {"config": graph_config}},
# stream_mode="debug",
# debug=True,
subgraphs=True,

View File

@ -9,7 +9,7 @@ from onyx.agents.agent_search.deep_search.main.states import (
InitialVRefinedAnswerComparisonUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
@ -23,8 +23,8 @@ def compare_answers(
) -> InitialVRefinedAnswerComparisonUpdate:
node_start_time = datetime.now()
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_search_config.search_request.query
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
initial_answer = state.initial_answer
refined_answer = state.refined_answer
@ -35,7 +35,7 @@ def compare_answers(
msg = [HumanMessage(content=compare_answers_prompt)]
# Get the rewritten queries in a defined format
model = agent_search_config.fast_llm
model = graph_config.tooling.fast_llm
# no need to stream this
resp = model.invoke(msg)

View File

@ -16,7 +16,7 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RefinedQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
@ -39,13 +39,13 @@ def create_refined_sub_questions(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> RefinedQuestionDecompositionUpdate:
""" """
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
graph_config = cast(GraphConfig, config["metadata"]["config"])
write_custom_event(
"start_refined_answer_creation",
ToolCallKickoff(
tool_name="agent_search_1",
tool_args={
"query": agent_search_config.search_request.query,
"query": graph_config.inputs.search_request.query,
"answer": state.initial_answer,
},
),
@ -56,9 +56,9 @@ def create_refined_sub_questions(
agent_refined_start_time = datetime.now()
question = agent_search_config.search_request.query
question = graph_config.inputs.search_request.query
base_answer = state.initial_answer
history = build_history_prompt(agent_search_config, question)
history = build_history_prompt(graph_config, question)
# get the entity term extraction dict and properly format it
entity_retlation_term_extractions = state.entity_relation_term_extractions
@ -90,7 +90,7 @@ def create_refined_sub_questions(
]
# Grader
model = agent_search_config.fast_llm
model = graph_config.tooling.fast_llm
streamed_tokens = dispatch_separated(
model.stream(msg), dispatch_subquestion(1, writer)

View File

@ -7,7 +7,7 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RequireRefinementUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
@ -18,7 +18,7 @@ def decide_refinement_need(
) -> RequireRefinementUpdate:
node_start_time = datetime.now()
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
graph_config = cast(GraphConfig, config["metadata"]["config"])
decision = True # TODO: just for current testing purposes
@ -31,7 +31,7 @@ def decide_refinement_need(
)
]
if agent_search_config.allow_refinement:
if graph_config.behavior.allow_refinement:
return RequireRefinementUpdate(
require_refined_answer_eval=decision,
log_messages=log_messages,

View File

@ -11,7 +11,7 @@ from onyx.agents.agent_search.deep_search.main.states import (
EntityTermExtractionUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
@ -33,8 +33,8 @@ def extract_entities_terms(
) -> EntityTermExtractionUpdate:
node_start_time = datetime.now()
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
if not agent_search_config.allow_refinement:
graph_config = cast(GraphConfig, config["metadata"]["config"])
if not graph_config.behavior.allow_refinement:
return EntityTermExtractionUpdate(
entity_relation_term_extractions=EntityRelationshipTermExtraction(
entities=[],
@ -52,21 +52,21 @@ def extract_entities_terms(
)
# first four lines duplicates from generate_initial_answer
question = agent_search_config.search_request.query
question = graph_config.inputs.search_request.query
initial_search_docs = state.exploratory_search_results[:15]
# start with the entity/term/extraction
doc_context = format_docs(initial_search_docs)
doc_context = trim_prompt_piece(
agent_search_config.fast_llm.config, doc_context, ENTITY_TERM_PROMPT + question
graph_config.tooling.fast_llm.config, doc_context, ENTITY_TERM_PROMPT + question
)
msg = [
HumanMessage(
content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
)
]
fast_llm = agent_search_config.fast_llm
fast_llm = graph_config.tooling.fast_llm
# Grader
llm_response = fast_llm.invoke(
prompt=msg,

View File

@ -16,7 +16,7 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RefinedAnswerUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
get_prompt_enrichment_components,
)
@ -61,9 +61,9 @@ def generate_refined_answer(
) -> RefinedAnswerUpdate:
node_start_time = datetime.now()
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_search_config.search_request.query
prompt_enrichment_components = get_prompt_enrichment_components(agent_search_config)
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
persona_contextualized_prompt = (
prompt_enrichment_components.persona_prompts.contextualized_prompt
@ -93,8 +93,9 @@ def generate_refined_answer(
)
query_info = get_query_info(state.orig_question_query_retrieval_results)
if agent_search_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
# stream refined answer docs
relevance_list = relevance_from_docs(relevant_docs)
for tool_response in yield_search_responses(
@ -103,7 +104,7 @@ def generate_refined_answer(
final_context_sections=relevant_docs,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=agent_search_config.search_tool,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",
@ -189,7 +190,7 @@ def generate_refined_answer(
else:
base_prompt = REVISED_RAG_PROMPT_NO_SUB_QUESTIONS
model = agent_search_config.fast_llm
model = graph_config.tooling.fast_llm
relevant_docs_str = format_docs(relevant_docs)
relevant_docs_str = trim_prompt_piece(
model.config,

View File

@ -10,7 +10,7 @@ from onyx.agents.agent_search.deep_search.main.models import AgentTimings
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import MainOutput
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
@ -59,21 +59,23 @@ def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutpu
)
persona_id = None
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
if agent_search_config.search_request.persona:
persona_id = agent_search_config.search_request.persona.id
graph_config = cast(GraphConfig, config["metadata"]["config"])
if graph_config.inputs.search_request.persona:
persona_id = graph_config.inputs.search_request.persona.id
user_id = None
if agent_search_config.search_tool is not None:
user = agent_search_config.search_tool.user
if user:
user_id = user.id
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
user = graph_config.tooling.search_tool.user
if user:
user_id = user.id
# log the agent metrics
if agent_search_config.db_session is not None:
if graph_config.persistence:
if agent_base_duration is not None:
log_agent_metrics(
db_session=agent_search_config.db_session,
db_session=graph_config.persistence.db_session,
user_id=user_id,
persona_id=persona_id,
agent_type=agent_type,
@ -81,19 +83,18 @@ def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutpu
agent_metrics=combined_agent_metrics,
)
if agent_search_config.use_agentic_persistence:
# Persist the sub-answer in the database
db_session = agent_search_config.db_session
chat_session_id = agent_search_config.chat_session_id
primary_message_id = agent_search_config.message_id
sub_question_answer_results = state.sub_question_results
# Persist the sub-answer in the database
db_session = graph_config.persistence.db_session
chat_session_id = graph_config.persistence.chat_session_id
primary_message_id = graph_config.persistence.message_id
sub_question_answer_results = state.sub_question_results
log_agent_sub_question_results(
db_session=db_session,
chat_session_id=chat_session_id,
primary_message_id=primary_message_id,
sub_question_answer_results=sub_question_answer_results,
)
log_agent_sub_question_results(
db_session=db_session,
chat_session_id=chat_session_id,
primary_message_id=primary_message_id,
sub_question_answer_results=sub_question_answer_results,
)
main_output = MainOutput(
log_messages=[

View File

@ -7,7 +7,7 @@ from onyx.agents.agent_search.deep_search.main.states import (
ExploratorySearchUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
@ -24,24 +24,14 @@ def start_agent_search(
) -> ExploratorySearchUpdate:
node_start_time = datetime.now()
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_search_config.search_request.query
chat_session_id = agent_search_config.chat_session_id
primary_message_id = agent_search_config.message_id
agent_search_config.fast_llm
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
history = build_history_prompt(agent_search_config, question)
if chat_session_id is None or primary_message_id is None:
raise ValueError(
"chat_session_id and message_id must be provided for agent search"
)
history = build_history_prompt(graph_config, question)
# Initial search to inform decomposition. Just get top 3 fits
search_tool = agent_search_config.search_tool
if search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
search_tool = graph_config.tooling.search_tool
assert search_tool, "search_tool must be provided for agentic search"
retrieved_docs: list[InferenceSection] = retrieve_search_docs(search_tool, question)
exploratory_search_results = retrieved_docs[:AGENT_EXPLORATORY_SEARCH_RESULTS]

View File

@ -10,15 +10,15 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
RetrievalInput,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
def parallel_retrieval_edge(
state: ExpandedRetrievalState, config: RunnableConfig
) -> list[Send | Hashable]:
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = (
state.question if state.question else agent_search_config.search_request.query
state.question if state.question else graph_config.inputs.search_request.query
)
query_expansions = state.expanded_queries + [question]

View File

@ -129,7 +129,7 @@ if __name__ == "__main__":
)
with get_session_context_manager() as db_session:
agent_search_config, search_tool = get_test_config(
graph_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = ExpandedRetrievalInput(
@ -140,7 +140,7 @@ if __name__ == "__main__":
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": agent_search_config}},
config={"configurable": {"config": graph_config}},
# debug=True,
subgraphs=True,
):

View File

@ -15,7 +15,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
QueryExpansionUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import (
REWRITE_PROMPT_MULTI_ORIGINAL,
)
@ -34,21 +34,17 @@ def expand_queries(
# Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
# When we are running this node on the original question, no question is explictly passed in.
# Instead, we use the original question from the search request.
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
graph_config = cast(GraphConfig, config["metadata"]["config"])
node_start_time = datetime.now()
question = state.question
llm = agent_search_config.fast_llm
chat_session_id = agent_search_config.chat_session_id
llm = graph_config.tooling.fast_llm
sub_question_id = state.sub_question_id
if sub_question_id is None:
level, question_nr = 0, 0
else:
level, question_nr = parse_question_id(sub_question_id)
if chat_session_id is None:
raise ValueError("chat_session_id must be provided for agent search")
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),

View File

@ -16,7 +16,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
@ -33,7 +33,7 @@ def format_results(
level, question_nr = parse_question_id(state.sub_question_id or "0_0")
query_info = get_query_info(state.query_retrieval_results)
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
graph_config = cast(GraphConfig, config["metadata"]["config"])
# main question docs will be sent later after aggregation and deduping with sub-question docs
reranked_documents = state.reranked_documents
@ -44,8 +44,9 @@ def format_results(
# the top 3 for that one. We may want to revisit this.
reranked_documents = state.query_retrieval_results[-1].search_results[:3]
if agent_search_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
relevance_list = relevance_from_docs(reranked_documents)
for tool_response in yield_search_responses(
@ -54,7 +55,7 @@ def format_results(
final_context_sections=reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=agent_search_config.search_tool,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",

View File

@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agents.agent_search.shared_graph_utils.utils import (
@ -36,12 +36,13 @@ def rerank_documents(
# Rerank post retrieval and verification. First, create a search query
# then create the list of reranked sections
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = (
state.question if state.question else agent_search_config.search_request.query
state.question if state.question else graph_config.inputs.search_request.query
)
if agent_search_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
with get_session_context_manager() as db_session:
# we ignore some of the user specified fields since this search is
# internal to agentic search, but we still want to pass through
@ -49,13 +50,13 @@ def rerank_documents(
# (to not make an unnecessary db call).
search_request = SearchRequest(
query=question,
persona=agent_search_config.search_request.persona,
rerank_settings=agent_search_config.search_request.rerank_settings,
persona=graph_config.inputs.search_request.persona,
rerank_settings=graph_config.inputs.search_request.rerank_settings,
)
_search_query = retrieval_preprocessing(
search_request=search_request,
user=agent_search_config.search_tool.user, # bit of a hack
llm=agent_search_config.fast_llm,
user=graph_config.tooling.search_tool.user, # bit of a hack
llm=graph_config.tooling.fast_llm,
db_session=db_session,
)

View File

@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
RetrievalInput,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
from onyx.agents.agent_search.shared_graph_utils.utils import (
@ -45,8 +45,8 @@ def retrieve_documents(
"""
node_start_time = datetime.now()
query_to_retrieve = state.query_to_retrieve
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
search_tool = agent_search_config.search_tool
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
retrieved_docs: list[InferenceSection] = []
if not query_to_retrieve.strip():

View File

@ -9,7 +9,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
DocVerificationUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
@ -34,8 +34,8 @@ def verify_documents(
retrieved_document_to_verify = state.retrieved_document_to_verify
document_content = retrieved_document_to_verify.combined_content
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
fast_llm = agent_search_config.fast_llm
graph_config = cast(GraphConfig, config["metadata"]["config"])
fast_llm = graph_config.tooling.fast_llm
document_content = trim_prompt_piece(
fast_llm.config, document_content, VERIFIER_PROMPT + question

View File

@ -14,85 +14,15 @@ from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
class AgentSearchConfig(BaseModel):
"""
Configuration for the Agent Search feature.
"""
# The search request that was used to generate the Agent Search
search_request: SearchRequest
primary_llm: LLM
fast_llm: LLM
# Whether to force use of a tool, or to
# force tool args IF the tool is used
force_use_tool: ForceUseTool
# contains message history for the current chat session
# has the following (at most one is non-None)
# message_history: list[PreviousMessage] | None = None
# single_message_history: str | None = None
prompt_builder: AnswerPromptBuilder
search_tool: SearchTool | None = None
use_agentic_search: bool = False
# For persisting agent search data
chat_session_id: UUID | None = None
# The message ID of the user message that triggered the Pro Search
message_id: int | None = None
# Whether to persist data for Agentic Search
use_agentic_persistence: bool = True
# The database session for Agentic Search
db_session: Session | None = None
# Whether to perform initial search to inform decomposition
# perform_initial_search_path_decision: bool = True
# Whether to perform initial search to inform decomposition
perform_initial_search_decomposition: bool = True
# Whether to allow creation of refinement questions (and entity extraction, etc.)
allow_refinement: bool = True
# Tools available for use
tools: list[Tool] | None = None
using_tool_calling_llm: bool = False
files: list[InMemoryChatFile] | None = None
structured_response_format: dict | None = None
skip_gen_ai_answer_generation: bool = False
@model_validator(mode="after")
def validate_db_session(self) -> "AgentSearchConfig":
if self.use_agentic_persistence and self.db_session is None:
raise ValueError(
"db_session must be provided for pro search when using persistence"
)
return self
@model_validator(mode="after")
def validate_search_tool(self) -> "AgentSearchConfig":
if self.use_agentic_search and self.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
return self
class Config:
arbitrary_types_allowed = True
class GraphInputs(BaseModel):
"""Input data required for the graph execution"""
search_request: SearchRequest
# contains message history for the current chat session
# has the following (at most one is non-None)
# TODO: unify this into a single message history
# message_history: list[PreviousMessage] | None = None
# single_message_history: str | None = None
prompt_builder: AnswerPromptBuilder
files: list[InMemoryChatFile] | None = None
structured_response_format: dict | None = None
@ -107,7 +37,9 @@ class GraphTooling(BaseModel):
primary_llm: LLM
fast_llm: LLM
search_tool: SearchTool | None = None
tools: list[Tool] | None = None
tools: list[Tool]
# Whether to force use of a tool, or to
# force tool args IF the tool is used
force_use_tool: ForceUseTool
using_tool_calling_llm: bool = False
@ -118,41 +50,41 @@ class GraphTooling(BaseModel):
class GraphPersistence(BaseModel):
"""Configuration for data persistence"""
chat_session_id: UUID | None = None
message_id: int | None = None
use_agentic_persistence: bool = True
db_session: Session | None = None
chat_session_id: UUID
# The message ID of the to-be-created first agent message
# in response to the user message that triggered the Pro Search
message_id: int
# The database session the user and initial agent
# message were flushed to; only needed for agentic search
db_session: Session
class Config:
arbitrary_types_allowed = True
@model_validator(mode="after")
def validate_db_session(self) -> "GraphPersistence":
if self.use_agentic_persistence and self.db_session is None:
raise ValueError(
"db_session must be provided for pro search when using persistence"
)
return self
class SearchBehaviorConfig(BaseModel):
class GraphSearchConfig(BaseModel):
"""Configuration controlling search behavior"""
use_agentic_search: bool = False
# Whether to perform initial search to inform decomposition
perform_initial_search_decomposition: bool = True
# Whether to allow creation of refinement questions (and entity extraction, etc.)
allow_refinement: bool = True
skip_gen_ai_answer_generation: bool = False
class GraphConfig(BaseModel):
"""
Main configuration class that combines all config components for Langgraph execution
Main container for data needed for Langgraph execution
"""
inputs: GraphInputs
tooling: GraphTooling
persistence: GraphPersistence
behavior: SearchBehaviorConfig
behavior: GraphSearchConfig
# Only needed for agentic search
persistence: GraphPersistence | None = None
@model_validator(mode="after")
def validate_search_tool(self) -> "GraphConfig":

View File

@ -7,7 +7,7 @@ from langgraph.types import StreamWriter
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.chat.models import LlmDoc
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_DOC_CONTENT_ID,
@ -23,14 +23,14 @@ logger = setup_logger()
def basic_use_tool_response(
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BasicOutput:
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
structured_response_format = agent_config.structured_response_format
llm = agent_config.primary_llm
agent_config = cast(GraphConfig, config["metadata"]["config"])
structured_response_format = agent_config.inputs.structured_response_format
llm = agent_config.tooling.primary_llm
tool_choice = state.tool_choice
if tool_choice is None:
raise ValueError("Tool choice is None")
tool = tool_choice.tool
prompt_builder = agent_config.prompt_builder
prompt_builder = agent_config.inputs.prompt_builder
if state.tool_call_output is None:
raise ValueError("Tool call output is None")
tool_call_output = state.tool_call_output
@ -41,7 +41,7 @@ def basic_use_tool_response(
prompt_builder=prompt_builder,
tool_call_summary=tool_call_summary,
tool_responses=tool_call_responses,
using_tool_calling_llm=agent_config.using_tool_calling_llm,
using_tool_calling_llm=agent_config.tooling.using_tool_calling_llm,
)
final_search_results = []
@ -58,7 +58,7 @@ def basic_use_tool_response(
initial_search_results = cast(list[LlmDoc], initial_search_results)
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.skip_gen_ai_answer_generation:
if not agent_config.behavior.skip_gen_ai_answer_generation:
stream = llm.stream(
prompt=new_prompt_builder.build(),
structured_response_format=structured_response_format,

View File

@ -6,7 +6,7 @@ from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoice
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
@ -36,16 +36,18 @@ def llm_tool_choice(
"""
should_stream_answer = state.should_stream_answer
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
using_tool_calling_llm = agent_config.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.prompt_builder
agent_config = cast(GraphConfig, config["metadata"]["config"])
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
llm = agent_config.primary_llm
skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation
llm = agent_config.tooling.primary_llm
skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation
structured_response_format = agent_config.structured_response_format
tools = [tool for tool in (agent_config.tools or []) if tool.name in state.tools]
force_use_tool = agent_config.force_use_tool
structured_response_format = agent_config.inputs.structured_response_format
tools = [
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
]
force_use_tool = agent_config.tooling.force_use_tool
tool, tool_args = None, None
if force_use_tool.force_use and force_use_tool.args is not None:
@ -103,7 +105,8 @@ def llm_tool_choice(
tool_message = process_llm_stream(
stream,
should_stream_answer and not agent_config.skip_gen_ai_answer_generation,
should_stream_answer
and not agent_config.behavior.skip_gen_ai_answer_generation,
writer,
)

View File

@ -3,15 +3,15 @@ from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
def prepare_tool_input(state: Any, config: RunnableConfig) -> ToolChoiceInput:
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
agent_config = cast(GraphConfig, config["metadata"]["config"])
return ToolChoiceInput(
# NOTE: this node is used at the top level of the agent, so we always stream
should_stream_answer=True,
prompt_snapshot=None, # uses default prompt builder
tools=[tool.name for tool in (agent_config.tools or [])],
tools=[tool.name for tool in (agent_config.tooling.tools or [])],
)

View File

@ -5,7 +5,7 @@ from langchain_core.messages.tool import ToolCall
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolCallOutput
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
@ -31,7 +31,7 @@ def tool_call(
) -> ToolCallUpdate:
"""Calls the tool specified in the state and updates the state with the result"""
cast(AgentSearchConfig, config["metadata"]["config"])
cast(GraphConfig, config["metadata"]["config"])
tool_choice = state.tool_choice
if tool_choice is None:

View File

@ -1,5 +1,3 @@
import asyncio
from collections.abc import AsyncIterable
from collections.abc import Iterable
from datetime import datetime
from typing import cast
@ -16,7 +14,7 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
from onyx.agents.agent_search.deep_search.main.states import (
MainInput as MainInput_a,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
@ -39,14 +37,6 @@ logger = setup_logger()
_COMPILED_GRAPH: CompiledStateGraph | None = None
def _set_combined_token_value(
combined_token: str, parsed_object: AgentAnswerPiece
) -> AgentAnswerPiece:
parsed_object.answer_piece = combined_token
return parsed_object
def _parse_agent_event(
event: StreamEvent,
) -> AnswerPacket | None:
@ -84,63 +74,12 @@ def _parse_agent_event(
return None
# https://stackoverflow.com/questions/60226557/how-to-forcefully-close-an-async-generator
# https://stackoverflow.com/questions/40897428/please-explain-task-was-destroyed-but-it-is-pending-after-cancelling-tasks
task_references: set[asyncio.Task[StreamEvent]] = set()
def _manage_async_event_streaming(
compiled_graph: CompiledStateGraph,
config: AgentSearchConfig | None,
graph_input: MainInput_a | BasicInput,
) -> Iterable[StreamEvent]:
async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
message_id = config.message_id if config else None
async for event in compiled_graph.astream_events(
input=graph_input,
config={"metadata": {"config": config, "thread_id": str(message_id)}},
# debug=True,
# indicating v2 here deserves further scrutiny
version="v2",
):
yield event
# This might be able to be simplified
def _yield_async_to_sync() -> Iterable[StreamEvent]:
loop = asyncio.new_event_loop()
try:
# Get the async generator
async_gen = _run_async_event_stream()
# Convert to AsyncIterator
async_iter = async_gen.__aiter__()
while True:
try:
# Create a coroutine by calling anext with the async iterator
next_coro = anext(async_iter)
task = asyncio.ensure_future(next_coro, loop=loop)
task_references.add(task)
# Run the coroutine to get the next event
event = loop.run_until_complete(task)
yield event
except (StopAsyncIteration, GeneratorExit):
break
finally:
try:
for task in task_references.pop():
task.cancel()
except StopAsyncIteration:
pass
loop.close()
return _yield_async_to_sync()
def manage_sync_streaming(
compiled_graph: CompiledStateGraph,
config: AgentSearchConfig,
config: GraphConfig,
graph_input: BasicInput | MainInput_a,
) -> Iterable[StreamEvent]:
message_id = config.message_id if config else None
message_id = config.persistence.message_id if config.persistence else None
for event in compiled_graph.stream(
stream_mode="custom",
input=graph_input,
@ -151,11 +90,13 @@ def manage_sync_streaming(
def run_graph(
compiled_graph: CompiledStateGraph,
config: AgentSearchConfig,
config: GraphConfig,
input: BasicInput | MainInput_a,
) -> AnswerStream:
config.perform_initial_search_decomposition = INITIAL_SEARCH_DECOMPOSITION_ENABLED
config.allow_refinement = ALLOW_REFINEMENT
config.behavior.perform_initial_search_decomposition = (
INITIAL_SEARCH_DECOMPOSITION_ENABLED
)
config.behavior.allow_refinement = ALLOW_REFINEMENT
for event in manage_sync_streaming(
compiled_graph=compiled_graph, config=config, graph_input=input
@ -177,21 +118,24 @@ def load_compiled_graph() -> CompiledStateGraph:
def run_main_graph(
config: AgentSearchConfig,
config: GraphConfig,
) -> AnswerStream:
compiled_graph = load_compiled_graph()
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
input = MainInput_a(
base_question=config.inputs.search_request.query, log_messages=[]
)
# Agent search is not a Tool per se, but this is helpful for the frontend
yield ToolCallKickoff(
tool_name="agent_search_0",
tool_args={"query": config.search_request.query},
tool_args={"query": config.inputs.search_request.query},
)
yield from run_graph(compiled_graph, config, input)
def run_basic_graph(
config: AgentSearchConfig,
config: GraphConfig,
) -> AnswerStream:
graph = basic_graph_builder()
compiled_graph = graph.compile()
@ -222,15 +166,16 @@ if __name__ == "__main__":
# Joachim custom persona
with get_session_context_manager() as db_session:
config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
config = get_test_config(db_session, primary_llm, fast_llm, search_request)
assert (
config.persistence is not None
), "set a chat session id to run this test"
# search_request.persona = get_persona_by_id(1, None, db_session)
config.use_agentic_persistence = True
# config.perform_initial_search_path_decision = False
config.perform_initial_search_decomposition = True
config.behavior.perform_initial_search_decomposition = True
input = MainInput_a(
base_question=config.search_request.query, log_messages=[]
base_question=config.inputs.search_request.query, log_messages=[]
)
# with open("output.txt", "w") as f:

View File

@ -4,7 +4,7 @@ from langchain.schema import SystemMessage
from langchain_core.messages.tool import ToolMessage
from onyx.agents.agent_search.models import AgentPromptEnrichmentComponents
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import (
@ -80,11 +80,11 @@ def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -
)
def build_history_prompt(config: AgentSearchConfig, question: str) -> str:
prompt_builder = config.prompt_builder
model = config.fast_llm
def build_history_prompt(config: GraphConfig, question: str) -> str:
prompt_builder = config.inputs.prompt_builder
model = config.tooling.fast_llm
persona_base = get_persona_agent_prompt_expressions(
config.search_request.persona
config.inputs.search_request.persona
).base_prompt
if prompt_builder is None:
@ -118,13 +118,13 @@ def build_history_prompt(config: AgentSearchConfig, question: str) -> str:
def get_prompt_enrichment_components(
config: AgentSearchConfig,
config: GraphConfig,
) -> AgentPromptEnrichmentComponents:
persona_prompts = get_persona_agent_prompt_expressions(
config.search_request.persona
config.inputs.search_request.persona
)
history = build_history_prompt(config, config.search_request.query)
history = build_history_prompt(config, config.inputs.search_request.query)
date_str = get_today_prompt()

View File

@ -1,5 +1,6 @@
import ast
import json
import os
import re
from collections.abc import Callable
from collections.abc import Iterator
@ -17,7 +18,11 @@ from langchain_core.messages import HumanMessage
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.models import GraphInputs
from onyx.agents.agent_search.models import GraphPersistence
from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
@ -60,6 +65,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
BaseMessage_Content = str | list[str | dict[str, Any]]
@ -166,7 +172,7 @@ def get_test_config(
fast_llm: LLM,
search_request: SearchRequest,
use_agentic_search: bool = True,
) -> tuple[AgentSearchConfig, SearchTool]:
) -> GraphConfig:
persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
document_pruning_config = DocumentPruningConfig(
max_chunks=int(
@ -221,12 +227,8 @@ def get_test_config(
bypass_acl=search_tool_config.bypass_acl,
)
config = AgentSearchConfig(
graph_inputs = GraphInputs(
search_request=search_request,
primary_llm=primary_llm,
fast_llm=fast_llm,
search_tool=search_tool,
force_use_tool=ForceUseTool(force_use=False, tool_name=""),
prompt_builder=AnswerPromptBuilder(
user_message=HumanMessage(content=search_request.query),
message_history=[],
@ -234,17 +236,42 @@ def get_test_config(
raw_user_query=search_request.query,
raw_user_uploaded_files=[],
),
# chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), # Joachim
# chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"), # Evan
message_id=1,
use_agentic_persistence=True,
db_session=db_session,
tools=[search_tool],
use_agentic_search=use_agentic_search,
structured_response_format=answer_style_config.structured_response_format,
)
return config, search_tool
using_tool_calling_llm = explicit_tool_calling_supported(
primary_llm.config.model_provider, primary_llm.config.model_name
)
graph_tooling = GraphTooling(
primary_llm=primary_llm,
fast_llm=fast_llm,
search_tool=search_tool,
tools=[search_tool],
force_use_tool=ForceUseTool(force_use=False, tool_name=""),
using_tool_calling_llm=using_tool_calling_llm,
)
graph_persistence = None
if chat_session_id := os.environ.get("ONYX_AS_CHAT_SESSION_ID"):
graph_persistence = GraphPersistence(
db_session=db_session,
chat_session_id=UUID(chat_session_id),
message_id=1,
)
search_behavior_config = GraphSearchConfig(
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=False,
allow_refinement=True,
)
graph_config = GraphConfig(
inputs=graph_inputs,
tooling=graph_tooling,
persistence=graph_persistence,
behavior=search_behavior_config,
)
return graph_config
def get_persona_agent_prompt_expressions(persona: Persona | None) -> PersonaExpressions:

View File

@ -4,7 +4,11 @@ from uuid import UUID
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.models import GraphInputs
from onyx.agents.agent_search.models import GraphPersistence
from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.models import AgentAnswerPiece
@ -16,12 +20,10 @@ from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.configs.constants import BASIC_KEY
from onyx.context.search.models import SearchRequest
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
@ -57,35 +59,9 @@ class Answer:
use_agentic_persistence: bool = True,
) -> None:
self.is_connected: Callable[[], bool] | None = is_connected
self.latest_query_files = latest_query_files or []
self.tools = tools or []
self.force_use_tool = force_use_tool
# used for QA flow where we only want to send a single message
self.answer_style_config = answer_style_config
self.llm = llm
self.fast_llm = fast_llm
self.llm_tokenizer = get_tokenizer(
provider_type=llm.config.model_provider,
model_name=llm.config.model_name,
)
self._streamed_output: list[str] | None = None
self._processed_stream: (list[AnswerPacket] | None) = None
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
self._is_cancelled = False
self.using_tool_calling_llm = (
explicit_tool_calling_supported(
self.llm.config.model_provider, self.llm.config.model_name
)
and not skip_explicit_tool_calling
)
search_tools = [tool for tool in (tools or []) if isinstance(tool, SearchTool)]
search_tool: SearchTool | None = None
@ -95,43 +71,46 @@ class Answer:
elif len(search_tools) == 1:
search_tool = search_tools[0]
using_tool_calling_llm = explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
using_tool_calling_llm = (
explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
)
and not skip_explicit_tool_calling
)
self.agent_search_config = AgentSearchConfig(
self.graph_inputs = GraphInputs(
search_request=search_request,
prompt_builder=prompt_builder,
files=latest_query_files,
structured_response_format=answer_style_config.structured_response_format,
)
self.graph_tooling = GraphTooling(
primary_llm=llm,
fast_llm=fast_llm,
search_tool=search_tool,
tools=tools or [],
force_use_tool=force_use_tool,
use_agentic_search=use_agentic_search,
chat_session_id=chat_session_id,
message_id=current_agent_message_id,
use_agentic_persistence=use_agentic_persistence,
allow_refinement=True,
db_session=db_session,
prompt_builder=prompt_builder,
tools=tools,
using_tool_calling_llm=using_tool_calling_llm,
files=latest_query_files,
structured_response_format=answer_style_config.structured_response_format,
)
self.graph_persistence = None
if use_agentic_persistence:
assert db_session, "db_session must be provided for agentic persistence"
self.graph_persistence = GraphPersistence(
db_session=db_session,
chat_session_id=chat_session_id,
message_id=current_agent_message_id,
)
self.search_behavior_config = GraphSearchConfig(
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
allow_refinement=True,
)
self.db_session = db_session
def _get_tools_list(self) -> list[Tool]:
if not self.force_use_tool.force_use:
return self.tools
tool = get_tool_by_name(self.tools, self.force_use_tool.tool_name)
args_str = (
f" with args='{self.force_use_tool.args}'"
if self.force_use_tool.args
else ""
self.graph_config = GraphConfig(
inputs=self.graph_inputs,
tooling=self.graph_tooling,
persistence=self.graph_persistence,
behavior=self.search_behavior_config,
)
logger.info(f"Forcefully using tool='{tool.name}'{args_str}")
return [tool]
@property
def processed_streamed_output(self) -> AnswerStream:
@ -141,11 +120,11 @@ class Answer:
run_langgraph = (
run_main_graph
if self.agent_search_config.use_agentic_search
if self.graph_config.behavior.use_agentic_search
else run_basic_graph
)
stream = run_langgraph(
self.agent_search_config,
self.graph_config,
)
processed_stream = []

View File

@ -5,10 +5,10 @@ import os
import yaml
from onyx.agents.agent_search.deep_search.main__graph.graph_builder import (
from onyx.agents.agent_search.deep_search.main.graph_builder import (
main_graph_builder,
)
from onyx.agents.agent_search.deep_search.main__graph.states import MainInput
from onyx.agents.agent_search.deep_search.main.states import MainInput
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager

View File

@ -70,13 +70,13 @@ def answer_instance(
def test_basic_answer(answer_instance: Answer) -> None:
mock_llm = cast(Mock, answer_instance.agent_search_config.primary_llm)
mock_llm = cast(Mock, answer_instance.graph_config.tooling.primary_llm)
mock_llm.stream.return_value = [
AIMessageChunk(content="This is a "),
AIMessageChunk(content="mock answer."),
]
answer_instance.agent_search_config.fast_llm = mock_llm
answer_instance.agent_search_config.primary_llm = mock_llm
answer_instance.graph_config.tooling.fast_llm = mock_llm
answer_instance.graph_config.tooling.primary_llm = mock_llm
output = list(answer_instance.processed_streamed_output)
assert len(output) == 2
@ -128,11 +128,11 @@ def test_answer_with_search_call(
force_use_tool: ForceUseTool,
expected_tool_args: dict,
) -> None:
answer_instance.agent_search_config.tools = [mock_search_tool]
answer_instance.agent_search_config.force_use_tool = force_use_tool
answer_instance.graph_config.tooling.tools = [mock_search_tool]
answer_instance.graph_config.tooling.force_use_tool = force_use_tool
# Set up the LLM mock to return search results and then an answer
mock_llm = cast(Mock, answer_instance.agent_search_config.primary_llm)
mock_llm = cast(Mock, answer_instance.graph_config.tooling.primary_llm)
stream_side_effect: list[list[BaseMessage]] = []
@ -253,10 +253,10 @@ def test_answer_with_search_no_tool_calling(
mock_contexts: OnyxContexts,
mock_search_tool: MagicMock,
) -> None:
answer_instance.agent_search_config.tools = [mock_search_tool]
answer_instance.graph_config.tooling.tools = [mock_search_tool]
# Set up the LLM mock to return an answer
mock_llm = cast(Mock, answer_instance.agent_search_config.primary_llm)
mock_llm = cast(Mock, answer_instance.graph_config.tooling.primary_llm)
mock_llm.stream.return_value = [
AIMessageChunk(content="Based on the search results, "),
AIMessageChunk(content="the answer is abc[1]. "),
@ -264,7 +264,7 @@ def test_answer_with_search_no_tool_calling(
]
# Force non-tool calling behavior
answer_instance.agent_search_config.using_tool_calling_llm = False
answer_instance.graph_config.tooling.using_tool_calling_llm = False
# Process the output
output = list(answer_instance.processed_streamed_output)
@ -319,7 +319,7 @@ def test_answer_with_search_no_tool_calling(
# Verify that get_args_for_non_tool_calling_llm was called on the mock_search_tool
mock_search_tool.get_args_for_non_tool_calling_llm.assert_called_once_with(
QUERY, [], answer_instance.llm
QUERY, [], answer_instance.graph_config.tooling.primary_llm
)
# Verify that the search tool's run method was called
@ -329,8 +329,8 @@ def test_answer_with_search_no_tool_calling(
def test_is_cancelled(answer_instance: Answer) -> None:
# Set up the LLM mock to return multiple chunks
mock_llm = Mock()
answer_instance.agent_search_config.primary_llm = mock_llm
answer_instance.agent_search_config.fast_llm = mock_llm
answer_instance.graph_config.tooling.primary_llm = mock_llm
answer_instance.graph_config.tooling.fast_llm = mock_llm
mock_llm.stream.return_value = [
AIMessageChunk(content="This is the "),
AIMessageChunk(content="first part."),