mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-08 20:08:36 +02:00
reworked config to have logical structure
This commit is contained in:
parent
8342168658
commit
118e8afbef
@ -109,7 +109,7 @@ if __name__ == "__main__":
|
||||
query="what can you do with onyx or danswer?",
|
||||
)
|
||||
with get_session_context_manager() as db_session:
|
||||
agent_search_config, search_tool = get_test_config(
|
||||
graph_config, search_tool = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
inputs = AnswerQuestionInput(
|
||||
@ -119,7 +119,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
config={"configurable": {"config": agent_search_config}},
|
||||
config={"configurable": {"config": graph_config}},
|
||||
# debug=True,
|
||||
# subgraphs=True,
|
||||
):
|
||||
|
@ -11,7 +11,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
QACheckUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
@ -47,8 +47,8 @@ def check_sub_answer(
|
||||
)
|
||||
]
|
||||
|
||||
agent_searchch_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
fast_llm = agent_searchch_config.fast_llm
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
fast_llm = graph_config.tooling.fast_llm
|
||||
response = list(
|
||||
fast_llm.stream(
|
||||
prompt=msg,
|
||||
|
@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
QAGenerationUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_sub_question_answer_prompt,
|
||||
)
|
||||
@ -42,13 +42,13 @@ def generate_sub_answer(
|
||||
) -> QAGenerationUpdate:
|
||||
node_start_time = datetime.now()
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = state.question
|
||||
state.verified_reranked_documents
|
||||
level, question_nr = parse_question_id(state.question_id)
|
||||
context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]
|
||||
persona_contextualized_prompt = get_persona_agent_prompt_expressions(
|
||||
agent_search_config.search_request.persona
|
||||
graph_config.inputs.search_request.persona
|
||||
).contextualized_prompt
|
||||
|
||||
if len(context_docs) == 0:
|
||||
@ -64,10 +64,10 @@ def generate_sub_answer(
|
||||
writer,
|
||||
)
|
||||
else:
|
||||
fast_llm = agent_search_config.fast_llm
|
||||
fast_llm = graph_config.tooling.fast_llm
|
||||
msg = build_sub_question_answer_prompt(
|
||||
question=question,
|
||||
original_question=agent_search_config.search_request.query,
|
||||
original_question=graph_config.inputs.search_request.query,
|
||||
docs=context_docs,
|
||||
persona_specification=persona_contextualized_prompt,
|
||||
config=fast_llm.config,
|
||||
|
@ -19,7 +19,7 @@ from onyx.agents.agent_search.deep_search.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
get_prompt_enrichment_components,
|
||||
)
|
||||
@ -63,9 +63,9 @@ def generate_initial_answer(
|
||||
) -> InitialAnswerUpdate:
|
||||
node_start_time = datetime.now()
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_search_config.search_request.query
|
||||
prompt_enrichment_components = get_prompt_enrichment_components(agent_search_config)
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.search_request.query
|
||||
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
|
||||
|
||||
sub_questions_cited_documents = state.cited_documents
|
||||
orig_question_retrieval_documents = state.orig_question_retrieval_documents
|
||||
@ -93,8 +93,9 @@ def generate_initial_answer(
|
||||
# Use the query info from the base document retrieval
|
||||
query_info = get_query_info(state.orig_question_query_retrieval_results)
|
||||
|
||||
if agent_search_config.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
assert (
|
||||
graph_config.tooling.search_tool
|
||||
), "search_tool must be provided for agentic search"
|
||||
|
||||
relevance_list = relevance_from_docs(relevant_docs)
|
||||
for tool_response in yield_search_responses(
|
||||
@ -103,7 +104,7 @@ def generate_initial_answer(
|
||||
final_context_sections=relevant_docs,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=agent_search_config.search_tool,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
):
|
||||
write_custom_event(
|
||||
"tool_response",
|
||||
@ -167,7 +168,7 @@ def generate_initial_answer(
|
||||
sub_question_answer_str = ""
|
||||
base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS
|
||||
|
||||
model = agent_search_config.fast_llm
|
||||
model = graph_config.tooling.fast_llm
|
||||
|
||||
doc_context = format_docs(relevant_docs)
|
||||
doc_context = trim_prompt_piece(
|
||||
|
@ -18,7 +18,7 @@ from onyx.agents.agent_search.deep_search.main.operations import (
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialQuestionDecompositionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
@ -44,25 +44,18 @@ def decompose_orig_question(
|
||||
) -> InitialQuestionDecompositionUpdate:
|
||||
node_start_time = datetime.now()
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_search_config.search_request.query
|
||||
chat_session_id = agent_search_config.chat_session_id
|
||||
primary_message_id = agent_search_config.message_id
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.search_request.query
|
||||
perform_initial_search_decomposition = (
|
||||
agent_search_config.perform_initial_search_decomposition
|
||||
graph_config.behavior.perform_initial_search_decomposition
|
||||
)
|
||||
# Get the rewritten queries in a defined format
|
||||
model = agent_search_config.fast_llm
|
||||
model = graph_config.tooling.fast_llm
|
||||
|
||||
history = build_history_prompt(agent_search_config, question)
|
||||
history = build_history_prompt(graph_config, question)
|
||||
|
||||
# Use the initial search results to inform the decomposition
|
||||
sample_doc_str = state.sample_doc_str if hasattr(state, "sample_doc_str") else ""
|
||||
|
||||
if not chat_session_id or not primary_message_id:
|
||||
raise ValueError(
|
||||
"chat_session_id and message_id must be provided for agent search"
|
||||
)
|
||||
agent_start_time = datetime.now()
|
||||
|
||||
# Initial search to inform decomposition. Just get top 3 fits
|
||||
|
@ -6,7 +6,7 @@ from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
ExpandedRetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@ -16,9 +16,9 @@ def format_orig_question_search_input(
|
||||
state: CoreState, config: RunnableConfig
|
||||
) -> ExpandedRetrievalInput:
|
||||
logger.debug("generate_raw_search_data")
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
return ExpandedRetrievalInput(
|
||||
question=agent_search_config.search_request.query,
|
||||
question=graph_config.inputs.search_request.query,
|
||||
base_search=True,
|
||||
sub_question_id=None, # This graph is always and only used for the original question
|
||||
log_messages=[],
|
||||
|
@ -16,7 +16,7 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
RequireRefinementUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@ -26,12 +26,12 @@ logger = setup_logger()
|
||||
def route_initial_tool_choice(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> Literal["tool_call", "start_agent_search", "logging_node"]:
|
||||
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
agent_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
if state.tool_choice is not None:
|
||||
if (
|
||||
agent_config.use_agentic_search
|
||||
and agent_config.search_tool is not None
|
||||
and state.tool_choice.tool.name == agent_config.search_tool.name
|
||||
agent_config.behavior.use_agentic_search
|
||||
and agent_config.tooling.search_tool is not None
|
||||
and state.tool_choice.tool.name == agent_config.tooling.search_tool.name
|
||||
):
|
||||
return "start_agent_search"
|
||||
else:
|
||||
|
@ -221,17 +221,17 @@ if __name__ == "__main__":
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
search_request = SearchRequest(query="Who created Excel?")
|
||||
agent_search_config, search_tool = get_test_config(
|
||||
graph_config = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
|
||||
inputs = MainInput(
|
||||
base_question=agent_search_config.search_request.query, log_messages=[]
|
||||
base_question=graph_config.inputs.search_request.query, log_messages=[]
|
||||
)
|
||||
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
config={"configurable": {"config": agent_search_config}},
|
||||
config={"configurable": {"config": graph_config}},
|
||||
# stream_mode="debug",
|
||||
# debug=True,
|
||||
subgraphs=True,
|
||||
|
@ -9,7 +9,7 @@ from onyx.agents.agent_search.deep_search.main.states import (
|
||||
InitialVRefinedAnswerComparisonUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
@ -23,8 +23,8 @@ def compare_answers(
|
||||
) -> InitialVRefinedAnswerComparisonUpdate:
|
||||
node_start_time = datetime.now()
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_search_config.search_request.query
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.search_request.query
|
||||
initial_answer = state.initial_answer
|
||||
refined_answer = state.refined_answer
|
||||
|
||||
@ -35,7 +35,7 @@ def compare_answers(
|
||||
msg = [HumanMessage(content=compare_answers_prompt)]
|
||||
|
||||
# Get the rewritten queries in a defined format
|
||||
model = agent_search_config.fast_llm
|
||||
model = graph_config.tooling.fast_llm
|
||||
|
||||
# no need to stream this
|
||||
resp = model.invoke(msg)
|
||||
|
@ -16,7 +16,7 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
RefinedQuestionDecompositionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
@ -39,13 +39,13 @@ def create_refined_sub_questions(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> RefinedQuestionDecompositionUpdate:
|
||||
""" """
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
write_custom_event(
|
||||
"start_refined_answer_creation",
|
||||
ToolCallKickoff(
|
||||
tool_name="agent_search_1",
|
||||
tool_args={
|
||||
"query": agent_search_config.search_request.query,
|
||||
"query": graph_config.inputs.search_request.query,
|
||||
"answer": state.initial_answer,
|
||||
},
|
||||
),
|
||||
@ -56,9 +56,9 @@ def create_refined_sub_questions(
|
||||
|
||||
agent_refined_start_time = datetime.now()
|
||||
|
||||
question = agent_search_config.search_request.query
|
||||
question = graph_config.inputs.search_request.query
|
||||
base_answer = state.initial_answer
|
||||
history = build_history_prompt(agent_search_config, question)
|
||||
history = build_history_prompt(graph_config, question)
|
||||
# get the entity term extraction dict and properly format it
|
||||
entity_retlation_term_extractions = state.entity_relation_term_extractions
|
||||
|
||||
@ -90,7 +90,7 @@ def create_refined_sub_questions(
|
||||
]
|
||||
|
||||
# Grader
|
||||
model = agent_search_config.fast_llm
|
||||
model = graph_config.tooling.fast_llm
|
||||
|
||||
streamed_tokens = dispatch_separated(
|
||||
model.stream(msg), dispatch_subquestion(1, writer)
|
||||
|
@ -7,7 +7,7 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
RequireRefinementUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
@ -18,7 +18,7 @@ def decide_refinement_need(
|
||||
) -> RequireRefinementUpdate:
|
||||
node_start_time = datetime.now()
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
decision = True # TODO: just for current testing purposes
|
||||
|
||||
@ -31,7 +31,7 @@ def decide_refinement_need(
|
||||
)
|
||||
]
|
||||
|
||||
if agent_search_config.allow_refinement:
|
||||
if graph_config.behavior.allow_refinement:
|
||||
return RequireRefinementUpdate(
|
||||
require_refined_answer_eval=decision,
|
||||
log_messages=log_messages,
|
||||
|
@ -11,7 +11,7 @@ from onyx.agents.agent_search.deep_search.main.states import (
|
||||
EntityTermExtractionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
@ -33,8 +33,8 @@ def extract_entities_terms(
|
||||
) -> EntityTermExtractionUpdate:
|
||||
node_start_time = datetime.now()
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
if not agent_search_config.allow_refinement:
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
if not graph_config.behavior.allow_refinement:
|
||||
return EntityTermExtractionUpdate(
|
||||
entity_relation_term_extractions=EntityRelationshipTermExtraction(
|
||||
entities=[],
|
||||
@ -52,21 +52,21 @@ def extract_entities_terms(
|
||||
)
|
||||
|
||||
# first four lines duplicates from generate_initial_answer
|
||||
question = agent_search_config.search_request.query
|
||||
question = graph_config.inputs.search_request.query
|
||||
initial_search_docs = state.exploratory_search_results[:15]
|
||||
|
||||
# start with the entity/term/extraction
|
||||
doc_context = format_docs(initial_search_docs)
|
||||
|
||||
doc_context = trim_prompt_piece(
|
||||
agent_search_config.fast_llm.config, doc_context, ENTITY_TERM_PROMPT + question
|
||||
graph_config.tooling.fast_llm.config, doc_context, ENTITY_TERM_PROMPT + question
|
||||
)
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
|
||||
)
|
||||
]
|
||||
fast_llm = agent_search_config.fast_llm
|
||||
fast_llm = graph_config.tooling.fast_llm
|
||||
# Grader
|
||||
llm_response = fast_llm.invoke(
|
||||
prompt=msg,
|
||||
|
@ -16,7 +16,7 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
RefinedAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
get_prompt_enrichment_components,
|
||||
)
|
||||
@ -61,9 +61,9 @@ def generate_refined_answer(
|
||||
) -> RefinedAnswerUpdate:
|
||||
node_start_time = datetime.now()
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_search_config.search_request.query
|
||||
prompt_enrichment_components = get_prompt_enrichment_components(agent_search_config)
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.search_request.query
|
||||
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
|
||||
|
||||
persona_contextualized_prompt = (
|
||||
prompt_enrichment_components.persona_prompts.contextualized_prompt
|
||||
@ -93,8 +93,9 @@ def generate_refined_answer(
|
||||
)
|
||||
|
||||
query_info = get_query_info(state.orig_question_query_retrieval_results)
|
||||
if agent_search_config.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
assert (
|
||||
graph_config.tooling.search_tool
|
||||
), "search_tool must be provided for agentic search"
|
||||
# stream refined answer docs
|
||||
relevance_list = relevance_from_docs(relevant_docs)
|
||||
for tool_response in yield_search_responses(
|
||||
@ -103,7 +104,7 @@ def generate_refined_answer(
|
||||
final_context_sections=relevant_docs,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=agent_search_config.search_tool,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
):
|
||||
write_custom_event(
|
||||
"tool_response",
|
||||
@ -189,7 +190,7 @@ def generate_refined_answer(
|
||||
else:
|
||||
base_prompt = REVISED_RAG_PROMPT_NO_SUB_QUESTIONS
|
||||
|
||||
model = agent_search_config.fast_llm
|
||||
model = graph_config.tooling.fast_llm
|
||||
relevant_docs_str = format_docs(relevant_docs)
|
||||
relevant_docs_str = trim_prompt_piece(
|
||||
model.config,
|
||||
|
@ -10,7 +10,7 @@ from onyx.agents.agent_search.deep_search.main.models import AgentTimings
|
||||
from onyx.agents.agent_search.deep_search.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainOutput
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
@ -59,21 +59,23 @@ def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutpu
|
||||
)
|
||||
|
||||
persona_id = None
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
if agent_search_config.search_request.persona:
|
||||
persona_id = agent_search_config.search_request.persona.id
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
if graph_config.inputs.search_request.persona:
|
||||
persona_id = graph_config.inputs.search_request.persona.id
|
||||
|
||||
user_id = None
|
||||
if agent_search_config.search_tool is not None:
|
||||
user = agent_search_config.search_tool.user
|
||||
if user:
|
||||
user_id = user.id
|
||||
assert (
|
||||
graph_config.tooling.search_tool
|
||||
), "search_tool must be provided for agentic search"
|
||||
user = graph_config.tooling.search_tool.user
|
||||
if user:
|
||||
user_id = user.id
|
||||
|
||||
# log the agent metrics
|
||||
if agent_search_config.db_session is not None:
|
||||
if graph_config.persistence:
|
||||
if agent_base_duration is not None:
|
||||
log_agent_metrics(
|
||||
db_session=agent_search_config.db_session,
|
||||
db_session=graph_config.persistence.db_session,
|
||||
user_id=user_id,
|
||||
persona_id=persona_id,
|
||||
agent_type=agent_type,
|
||||
@ -81,19 +83,18 @@ def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutpu
|
||||
agent_metrics=combined_agent_metrics,
|
||||
)
|
||||
|
||||
if agent_search_config.use_agentic_persistence:
|
||||
# Persist the sub-answer in the database
|
||||
db_session = agent_search_config.db_session
|
||||
chat_session_id = agent_search_config.chat_session_id
|
||||
primary_message_id = agent_search_config.message_id
|
||||
sub_question_answer_results = state.sub_question_results
|
||||
# Persist the sub-answer in the database
|
||||
db_session = graph_config.persistence.db_session
|
||||
chat_session_id = graph_config.persistence.chat_session_id
|
||||
primary_message_id = graph_config.persistence.message_id
|
||||
sub_question_answer_results = state.sub_question_results
|
||||
|
||||
log_agent_sub_question_results(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
primary_message_id=primary_message_id,
|
||||
sub_question_answer_results=sub_question_answer_results,
|
||||
)
|
||||
log_agent_sub_question_results(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
primary_message_id=primary_message_id,
|
||||
sub_question_answer_results=sub_question_answer_results,
|
||||
)
|
||||
|
||||
main_output = MainOutput(
|
||||
log_messages=[
|
||||
|
@ -7,7 +7,7 @@ from onyx.agents.agent_search.deep_search.main.states import (
|
||||
ExploratorySearchUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
@ -24,24 +24,14 @@ def start_agent_search(
|
||||
) -> ExploratorySearchUpdate:
|
||||
node_start_time = datetime.now()
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_search_config.search_request.query
|
||||
chat_session_id = agent_search_config.chat_session_id
|
||||
primary_message_id = agent_search_config.message_id
|
||||
agent_search_config.fast_llm
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.search_request.query
|
||||
|
||||
history = build_history_prompt(agent_search_config, question)
|
||||
|
||||
if chat_session_id is None or primary_message_id is None:
|
||||
raise ValueError(
|
||||
"chat_session_id and message_id must be provided for agent search"
|
||||
)
|
||||
history = build_history_prompt(graph_config, question)
|
||||
|
||||
# Initial search to inform decomposition. Just get top 3 fits
|
||||
|
||||
search_tool = agent_search_config.search_tool
|
||||
if search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
assert search_tool, "search_tool must be provided for agentic search"
|
||||
retrieved_docs: list[InferenceSection] = retrieve_search_docs(search_tool, question)
|
||||
|
||||
exploratory_search_results = retrieved_docs[:AGENT_EXPLORATORY_SEARCH_RESULTS]
|
||||
|
@ -10,15 +10,15 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
RetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
|
||||
|
||||
def parallel_retrieval_edge(
|
||||
state: ExpandedRetrievalState, config: RunnableConfig
|
||||
) -> list[Send | Hashable]:
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = (
|
||||
state.question if state.question else agent_search_config.search_request.query
|
||||
state.question if state.question else graph_config.inputs.search_request.query
|
||||
)
|
||||
|
||||
query_expansions = state.expanded_queries + [question]
|
||||
|
@ -129,7 +129,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
agent_search_config, search_tool = get_test_config(
|
||||
graph_config, search_tool = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
inputs = ExpandedRetrievalInput(
|
||||
@ -140,7 +140,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
config={"configurable": {"config": agent_search_config}},
|
||||
config={"configurable": {"config": graph_config}},
|
||||
# debug=True,
|
||||
subgraphs=True,
|
||||
):
|
||||
|
@ -15,7 +15,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
QueryExpansionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import (
|
||||
REWRITE_PROMPT_MULTI_ORIGINAL,
|
||||
)
|
||||
@ -34,21 +34,17 @@ def expand_queries(
|
||||
# Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
|
||||
# When we are running this node on the original question, no question is explictly passed in.
|
||||
# Instead, we use the original question from the search request.
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
node_start_time = datetime.now()
|
||||
question = state.question
|
||||
|
||||
llm = agent_search_config.fast_llm
|
||||
chat_session_id = agent_search_config.chat_session_id
|
||||
llm = graph_config.tooling.fast_llm
|
||||
sub_question_id = state.sub_question_id
|
||||
if sub_question_id is None:
|
||||
level, question_nr = 0, 0
|
||||
else:
|
||||
level, question_nr = parse_question_id(sub_question_id)
|
||||
|
||||
if chat_session_id is None:
|
||||
raise ValueError("chat_session_id must be provided for agent search")
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
|
||||
|
@ -16,7 +16,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
ExpandedRetrievalUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
|
||||
@ -33,7 +33,7 @@ def format_results(
|
||||
level, question_nr = parse_question_id(state.sub_question_id or "0_0")
|
||||
query_info = get_query_info(state.query_retrieval_results)
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
# main question docs will be sent later after aggregation and deduping with sub-question docs
|
||||
|
||||
reranked_documents = state.reranked_documents
|
||||
@ -44,8 +44,9 @@ def format_results(
|
||||
# the top 3 for that one. We may want to revisit this.
|
||||
reranked_documents = state.query_retrieval_results[-1].search_results[:3]
|
||||
|
||||
if agent_search_config.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
assert (
|
||||
graph_config.tooling.search_tool
|
||||
), "search_tool must be provided for agentic search"
|
||||
|
||||
relevance_list = relevance_from_docs(reranked_documents)
|
||||
for tool_response in yield_search_responses(
|
||||
@ -54,7 +55,7 @@ def format_results(
|
||||
final_context_sections=reranked_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=agent_search_config.search_tool,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
):
|
||||
write_custom_event(
|
||||
"tool_response",
|
||||
|
@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
ExpandedRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
@ -36,12 +36,13 @@ def rerank_documents(
|
||||
# Rerank post retrieval and verification. First, create a search query
|
||||
# then create the list of reranked sections
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = (
|
||||
state.question if state.question else agent_search_config.search_request.query
|
||||
state.question if state.question else graph_config.inputs.search_request.query
|
||||
)
|
||||
if agent_search_config.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
assert (
|
||||
graph_config.tooling.search_tool
|
||||
), "search_tool must be provided for agentic search"
|
||||
with get_session_context_manager() as db_session:
|
||||
# we ignore some of the user specified fields since this search is
|
||||
# internal to agentic search, but we still want to pass through
|
||||
@ -49,13 +50,13 @@ def rerank_documents(
|
||||
# (to not make an unnecessary db call).
|
||||
search_request = SearchRequest(
|
||||
query=question,
|
||||
persona=agent_search_config.search_request.persona,
|
||||
rerank_settings=agent_search_config.search_request.rerank_settings,
|
||||
persona=graph_config.inputs.search_request.persona,
|
||||
rerank_settings=graph_config.inputs.search_request.rerank_settings,
|
||||
)
|
||||
_search_query = retrieval_preprocessing(
|
||||
search_request=search_request,
|
||||
user=agent_search_config.search_tool.user, # bit of a hack
|
||||
llm=agent_search_config.fast_llm,
|
||||
user=graph_config.tooling.search_tool.user, # bit of a hack
|
||||
llm=graph_config.tooling.fast_llm,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
RetrievalInput,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
@ -45,8 +45,8 @@ def retrieve_documents(
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
query_to_retrieve = state.query_to_retrieve
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
search_tool = agent_search_config.search_tool
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
if not query_to_retrieve.strip():
|
||||
|
@ -9,7 +9,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
|
||||
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
|
||||
DocVerificationUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
@ -34,8 +34,8 @@ def verify_documents(
|
||||
retrieved_document_to_verify = state.retrieved_document_to_verify
|
||||
document_content = retrieved_document_to_verify.combined_content
|
||||
|
||||
agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
fast_llm = agent_search_config.fast_llm
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
fast_llm = graph_config.tooling.fast_llm
|
||||
|
||||
document_content = trim_prompt_piece(
|
||||
fast_llm.config, document_content, VERIFIER_PROMPT + question
|
||||
|
@ -14,85 +14,15 @@ from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
|
||||
class AgentSearchConfig(BaseModel):
|
||||
"""
|
||||
Configuration for the Agent Search feature.
|
||||
"""
|
||||
|
||||
# The search request that was used to generate the Agent Search
|
||||
search_request: SearchRequest
|
||||
|
||||
primary_llm: LLM
|
||||
fast_llm: LLM
|
||||
|
||||
# Whether to force use of a tool, or to
|
||||
# force tool args IF the tool is used
|
||||
force_use_tool: ForceUseTool
|
||||
|
||||
# contains message history for the current chat session
|
||||
# has the following (at most one is non-None)
|
||||
# message_history: list[PreviousMessage] | None = None
|
||||
# single_message_history: str | None = None
|
||||
prompt_builder: AnswerPromptBuilder
|
||||
|
||||
search_tool: SearchTool | None = None
|
||||
|
||||
use_agentic_search: bool = False
|
||||
|
||||
# For persisting agent search data
|
||||
chat_session_id: UUID | None = None
|
||||
|
||||
# The message ID of the user message that triggered the Pro Search
|
||||
message_id: int | None = None
|
||||
|
||||
# Whether to persist data for Agentic Search
|
||||
use_agentic_persistence: bool = True
|
||||
|
||||
# The database session for Agentic Search
|
||||
db_session: Session | None = None
|
||||
|
||||
# Whether to perform initial search to inform decomposition
|
||||
# perform_initial_search_path_decision: bool = True
|
||||
|
||||
# Whether to perform initial search to inform decomposition
|
||||
perform_initial_search_decomposition: bool = True
|
||||
|
||||
# Whether to allow creation of refinement questions (and entity extraction, etc.)
|
||||
allow_refinement: bool = True
|
||||
|
||||
# Tools available for use
|
||||
tools: list[Tool] | None = None
|
||||
|
||||
using_tool_calling_llm: bool = False
|
||||
|
||||
files: list[InMemoryChatFile] | None = None
|
||||
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_db_session(self) -> "AgentSearchConfig":
|
||||
if self.use_agentic_persistence and self.db_session is None:
|
||||
raise ValueError(
|
||||
"db_session must be provided for pro search when using persistence"
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_search_tool(self) -> "AgentSearchConfig":
|
||||
if self.use_agentic_search and self.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
return self
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class GraphInputs(BaseModel):
|
||||
"""Input data required for the graph execution"""
|
||||
|
||||
search_request: SearchRequest
|
||||
# contains message history for the current chat session
|
||||
# has the following (at most one is non-None)
|
||||
# TODO: unify this into a single message history
|
||||
# message_history: list[PreviousMessage] | None = None
|
||||
# single_message_history: str | None = None
|
||||
prompt_builder: AnswerPromptBuilder
|
||||
files: list[InMemoryChatFile] | None = None
|
||||
structured_response_format: dict | None = None
|
||||
@ -107,7 +37,9 @@ class GraphTooling(BaseModel):
|
||||
primary_llm: LLM
|
||||
fast_llm: LLM
|
||||
search_tool: SearchTool | None = None
|
||||
tools: list[Tool] | None = None
|
||||
tools: list[Tool]
|
||||
# Whether to force use of a tool, or to
|
||||
# force tool args IF the tool is used
|
||||
force_use_tool: ForceUseTool
|
||||
using_tool_calling_llm: bool = False
|
||||
|
||||
@ -118,41 +50,41 @@ class GraphTooling(BaseModel):
|
||||
class GraphPersistence(BaseModel):
|
||||
"""Configuration for data persistence"""
|
||||
|
||||
chat_session_id: UUID | None = None
|
||||
message_id: int | None = None
|
||||
use_agentic_persistence: bool = True
|
||||
db_session: Session | None = None
|
||||
chat_session_id: UUID
|
||||
# The message ID of the to-be-created first agent message
|
||||
# in response to the user message that triggered the Pro Search
|
||||
message_id: int
|
||||
|
||||
# The database session the user and initial agent
|
||||
# message were flushed to; only needed for agentic search
|
||||
db_session: Session
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_db_session(self) -> "GraphPersistence":
|
||||
if self.use_agentic_persistence and self.db_session is None:
|
||||
raise ValueError(
|
||||
"db_session must be provided for pro search when using persistence"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class SearchBehaviorConfig(BaseModel):
|
||||
class GraphSearchConfig(BaseModel):
|
||||
"""Configuration controlling search behavior"""
|
||||
|
||||
use_agentic_search: bool = False
|
||||
# Whether to perform initial search to inform decomposition
|
||||
perform_initial_search_decomposition: bool = True
|
||||
|
||||
# Whether to allow creation of refinement questions (and entity extraction, etc.)
|
||||
allow_refinement: bool = True
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
|
||||
class GraphConfig(BaseModel):
|
||||
"""
|
||||
Main configuration class that combines all config components for Langgraph execution
|
||||
Main container for data needed for Langgraph execution
|
||||
"""
|
||||
|
||||
inputs: GraphInputs
|
||||
tooling: GraphTooling
|
||||
persistence: GraphPersistence
|
||||
behavior: SearchBehaviorConfig
|
||||
behavior: GraphSearchConfig
|
||||
# Only needed for agentic search
|
||||
persistence: GraphPersistence | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_search_tool(self) -> "GraphConfig":
|
||||
|
@ -7,7 +7,7 @@ from langgraph.types import StreamWriter
|
||||
from onyx.agents.agent_search.basic.states import BasicOutput
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_DOC_CONTENT_ID,
|
||||
@ -23,14 +23,14 @@ logger = setup_logger()
|
||||
def basic_use_tool_response(
|
||||
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> BasicOutput:
|
||||
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
structured_response_format = agent_config.structured_response_format
|
||||
llm = agent_config.primary_llm
|
||||
agent_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
structured_response_format = agent_config.inputs.structured_response_format
|
||||
llm = agent_config.tooling.primary_llm
|
||||
tool_choice = state.tool_choice
|
||||
if tool_choice is None:
|
||||
raise ValueError("Tool choice is None")
|
||||
tool = tool_choice.tool
|
||||
prompt_builder = agent_config.prompt_builder
|
||||
prompt_builder = agent_config.inputs.prompt_builder
|
||||
if state.tool_call_output is None:
|
||||
raise ValueError("Tool call output is None")
|
||||
tool_call_output = state.tool_call_output
|
||||
@ -41,7 +41,7 @@ def basic_use_tool_response(
|
||||
prompt_builder=prompt_builder,
|
||||
tool_call_summary=tool_call_summary,
|
||||
tool_responses=tool_call_responses,
|
||||
using_tool_calling_llm=agent_config.using_tool_calling_llm,
|
||||
using_tool_calling_llm=agent_config.tooling.using_tool_calling_llm,
|
||||
)
|
||||
|
||||
final_search_results = []
|
||||
@ -58,7 +58,7 @@ def basic_use_tool_response(
|
||||
initial_search_results = cast(list[LlmDoc], initial_search_results)
|
||||
|
||||
new_tool_call_chunk = AIMessageChunk(content="")
|
||||
if not agent_config.skip_gen_ai_answer_generation:
|
||||
if not agent_config.behavior.skip_gen_ai_answer_generation:
|
||||
stream = llm.stream(
|
||||
prompt=new_prompt_builder.build(),
|
||||
structured_response_format=structured_response_format,
|
||||
|
@ -6,7 +6,7 @@ from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoice
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
@ -36,16 +36,18 @@ def llm_tool_choice(
|
||||
"""
|
||||
should_stream_answer = state.should_stream_answer
|
||||
|
||||
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
using_tool_calling_llm = agent_config.using_tool_calling_llm
|
||||
prompt_builder = state.prompt_snapshot or agent_config.prompt_builder
|
||||
agent_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
|
||||
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
|
||||
|
||||
llm = agent_config.primary_llm
|
||||
skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation
|
||||
llm = agent_config.tooling.primary_llm
|
||||
skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation
|
||||
|
||||
structured_response_format = agent_config.structured_response_format
|
||||
tools = [tool for tool in (agent_config.tools or []) if tool.name in state.tools]
|
||||
force_use_tool = agent_config.force_use_tool
|
||||
structured_response_format = agent_config.inputs.structured_response_format
|
||||
tools = [
|
||||
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
|
||||
]
|
||||
force_use_tool = agent_config.tooling.force_use_tool
|
||||
|
||||
tool, tool_args = None, None
|
||||
if force_use_tool.force_use and force_use_tool.args is not None:
|
||||
@ -103,7 +105,8 @@ def llm_tool_choice(
|
||||
|
||||
tool_message = process_llm_stream(
|
||||
stream,
|
||||
should_stream_answer and not agent_config.skip_gen_ai_answer_generation,
|
||||
should_stream_answer
|
||||
and not agent_config.behavior.skip_gen_ai_answer_generation,
|
||||
writer,
|
||||
)
|
||||
|
||||
|
@ -3,15 +3,15 @@ from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
|
||||
|
||||
def prepare_tool_input(state: Any, config: RunnableConfig) -> ToolChoiceInput:
|
||||
agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
agent_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
return ToolChoiceInput(
|
||||
# NOTE: this node is used at the top level of the agent, so we always stream
|
||||
should_stream_answer=True,
|
||||
prompt_snapshot=None, # uses default prompt builder
|
||||
tools=[tool.name for tool in (agent_config.tools or [])],
|
||||
tools=[tool.name for tool in (agent_config.tooling.tools or [])],
|
||||
)
|
||||
|
@ -5,7 +5,7 @@ from langchain_core.messages.tool import ToolCall
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallOutput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
@ -31,7 +31,7 @@ def tool_call(
|
||||
) -> ToolCallUpdate:
|
||||
"""Calls the tool specified in the state and updates the state with the result"""
|
||||
|
||||
cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
tool_choice = state.tool_choice
|
||||
if tool_choice is None:
|
||||
|
@ -1,5 +1,3 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
@ -16,7 +14,7 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
MainInput as MainInput_a,
|
||||
)
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AnswerPacket
|
||||
@ -39,14 +37,6 @@ logger = setup_logger()
|
||||
_COMPILED_GRAPH: CompiledStateGraph | None = None
|
||||
|
||||
|
||||
def _set_combined_token_value(
|
||||
combined_token: str, parsed_object: AgentAnswerPiece
|
||||
) -> AgentAnswerPiece:
|
||||
parsed_object.answer_piece = combined_token
|
||||
|
||||
return parsed_object
|
||||
|
||||
|
||||
def _parse_agent_event(
|
||||
event: StreamEvent,
|
||||
) -> AnswerPacket | None:
|
||||
@ -84,63 +74,12 @@ def _parse_agent_event(
|
||||
return None
|
||||
|
||||
|
||||
# https://stackoverflow.com/questions/60226557/how-to-forcefully-close-an-async-generator
|
||||
# https://stackoverflow.com/questions/40897428/please-explain-task-was-destroyed-but-it-is-pending-after-cancelling-tasks
|
||||
task_references: set[asyncio.Task[StreamEvent]] = set()
|
||||
|
||||
|
||||
def _manage_async_event_streaming(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: AgentSearchConfig | None,
|
||||
graph_input: MainInput_a | BasicInput,
|
||||
) -> Iterable[StreamEvent]:
|
||||
async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
|
||||
message_id = config.message_id if config else None
|
||||
async for event in compiled_graph.astream_events(
|
||||
input=graph_input,
|
||||
config={"metadata": {"config": config, "thread_id": str(message_id)}},
|
||||
# debug=True,
|
||||
# indicating v2 here deserves further scrutiny
|
||||
version="v2",
|
||||
):
|
||||
yield event
|
||||
|
||||
# This might be able to be simplified
|
||||
def _yield_async_to_sync() -> Iterable[StreamEvent]:
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
# Get the async generator
|
||||
async_gen = _run_async_event_stream()
|
||||
# Convert to AsyncIterator
|
||||
async_iter = async_gen.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
# Create a coroutine by calling anext with the async iterator
|
||||
next_coro = anext(async_iter)
|
||||
task = asyncio.ensure_future(next_coro, loop=loop)
|
||||
task_references.add(task)
|
||||
# Run the coroutine to get the next event
|
||||
event = loop.run_until_complete(task)
|
||||
yield event
|
||||
except (StopAsyncIteration, GeneratorExit):
|
||||
break
|
||||
finally:
|
||||
try:
|
||||
for task in task_references.pop():
|
||||
task.cancel()
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
loop.close()
|
||||
|
||||
return _yield_async_to_sync()
|
||||
|
||||
|
||||
def manage_sync_streaming(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: AgentSearchConfig,
|
||||
config: GraphConfig,
|
||||
graph_input: BasicInput | MainInput_a,
|
||||
) -> Iterable[StreamEvent]:
|
||||
message_id = config.message_id if config else None
|
||||
message_id = config.persistence.message_id if config.persistence else None
|
||||
for event in compiled_graph.stream(
|
||||
stream_mode="custom",
|
||||
input=graph_input,
|
||||
@ -151,11 +90,13 @@ def manage_sync_streaming(
|
||||
|
||||
def run_graph(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: AgentSearchConfig,
|
||||
config: GraphConfig,
|
||||
input: BasicInput | MainInput_a,
|
||||
) -> AnswerStream:
|
||||
config.perform_initial_search_decomposition = INITIAL_SEARCH_DECOMPOSITION_ENABLED
|
||||
config.allow_refinement = ALLOW_REFINEMENT
|
||||
config.behavior.perform_initial_search_decomposition = (
|
||||
INITIAL_SEARCH_DECOMPOSITION_ENABLED
|
||||
)
|
||||
config.behavior.allow_refinement = ALLOW_REFINEMENT
|
||||
|
||||
for event in manage_sync_streaming(
|
||||
compiled_graph=compiled_graph, config=config, graph_input=input
|
||||
@ -177,21 +118,24 @@ def load_compiled_graph() -> CompiledStateGraph:
|
||||
|
||||
|
||||
def run_main_graph(
|
||||
config: AgentSearchConfig,
|
||||
config: GraphConfig,
|
||||
) -> AnswerStream:
|
||||
compiled_graph = load_compiled_graph()
|
||||
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
|
||||
|
||||
input = MainInput_a(
|
||||
base_question=config.inputs.search_request.query, log_messages=[]
|
||||
)
|
||||
|
||||
# Agent search is not a Tool per se, but this is helpful for the frontend
|
||||
yield ToolCallKickoff(
|
||||
tool_name="agent_search_0",
|
||||
tool_args={"query": config.search_request.query},
|
||||
tool_args={"query": config.inputs.search_request.query},
|
||||
)
|
||||
yield from run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
def run_basic_graph(
|
||||
config: AgentSearchConfig,
|
||||
config: GraphConfig,
|
||||
) -> AnswerStream:
|
||||
graph = basic_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
@ -222,15 +166,16 @@ if __name__ == "__main__":
|
||||
# Joachim custom persona
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
config, search_tool = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
config = get_test_config(db_session, primary_llm, fast_llm, search_request)
|
||||
assert (
|
||||
config.persistence is not None
|
||||
), "set a chat session id to run this test"
|
||||
|
||||
# search_request.persona = get_persona_by_id(1, None, db_session)
|
||||
config.use_agentic_persistence = True
|
||||
# config.perform_initial_search_path_decision = False
|
||||
config.perform_initial_search_decomposition = True
|
||||
config.behavior.perform_initial_search_decomposition = True
|
||||
input = MainInput_a(
|
||||
base_question=config.search_request.query, log_messages=[]
|
||||
base_question=config.inputs.search_request.query, log_messages=[]
|
||||
)
|
||||
|
||||
# with open("output.txt", "w") as f:
|
||||
|
@ -4,7 +4,7 @@ from langchain.schema import SystemMessage
|
||||
from langchain_core.messages.tool import ToolMessage
|
||||
|
||||
from onyx.agents.agent_search.models import AgentPromptEnrichmentComponents
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
@ -80,11 +80,11 @@ def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -
|
||||
)
|
||||
|
||||
|
||||
def build_history_prompt(config: AgentSearchConfig, question: str) -> str:
|
||||
prompt_builder = config.prompt_builder
|
||||
model = config.fast_llm
|
||||
def build_history_prompt(config: GraphConfig, question: str) -> str:
|
||||
prompt_builder = config.inputs.prompt_builder
|
||||
model = config.tooling.fast_llm
|
||||
persona_base = get_persona_agent_prompt_expressions(
|
||||
config.search_request.persona
|
||||
config.inputs.search_request.persona
|
||||
).base_prompt
|
||||
|
||||
if prompt_builder is None:
|
||||
@ -118,13 +118,13 @@ def build_history_prompt(config: AgentSearchConfig, question: str) -> str:
|
||||
|
||||
|
||||
def get_prompt_enrichment_components(
|
||||
config: AgentSearchConfig,
|
||||
config: GraphConfig,
|
||||
) -> AgentPromptEnrichmentComponents:
|
||||
persona_prompts = get_persona_agent_prompt_expressions(
|
||||
config.search_request.persona
|
||||
config.inputs.search_request.persona
|
||||
)
|
||||
|
||||
history = build_history_prompt(config, config.search_request.query)
|
||||
history = build_history_prompt(config, config.inputs.search_request.query)
|
||||
|
||||
date_str = get_today_prompt()
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
import ast
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
@ -17,7 +18,11 @@ from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import StreamWriter
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.models import GraphInputs
|
||||
from onyx.agents.agent_search.models import GraphPersistence
|
||||
from onyx.agents.agent_search.models import GraphSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphTooling
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
EntityRelationshipTermExtraction,
|
||||
)
|
||||
@ -60,6 +65,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
|
||||
BaseMessage_Content = str | list[str | dict[str, Any]]
|
||||
|
||||
@ -166,7 +172,7 @@ def get_test_config(
|
||||
fast_llm: LLM,
|
||||
search_request: SearchRequest,
|
||||
use_agentic_search: bool = True,
|
||||
) -> tuple[AgentSearchConfig, SearchTool]:
|
||||
) -> GraphConfig:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
|
||||
document_pruning_config = DocumentPruningConfig(
|
||||
max_chunks=int(
|
||||
@ -221,12 +227,8 @@ def get_test_config(
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
)
|
||||
|
||||
config = AgentSearchConfig(
|
||||
graph_inputs = GraphInputs(
|
||||
search_request=search_request,
|
||||
primary_llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool=search_tool,
|
||||
force_use_tool=ForceUseTool(force_use=False, tool_name=""),
|
||||
prompt_builder=AnswerPromptBuilder(
|
||||
user_message=HumanMessage(content=search_request.query),
|
||||
message_history=[],
|
||||
@ -234,17 +236,42 @@ def get_test_config(
|
||||
raw_user_query=search_request.query,
|
||||
raw_user_uploaded_files=[],
|
||||
),
|
||||
# chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
|
||||
chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), # Joachim
|
||||
# chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"), # Evan
|
||||
message_id=1,
|
||||
use_agentic_persistence=True,
|
||||
db_session=db_session,
|
||||
tools=[search_tool],
|
||||
use_agentic_search=use_agentic_search,
|
||||
structured_response_format=answer_style_config.structured_response_format,
|
||||
)
|
||||
|
||||
return config, search_tool
|
||||
using_tool_calling_llm = explicit_tool_calling_supported(
|
||||
primary_llm.config.model_provider, primary_llm.config.model_name
|
||||
)
|
||||
graph_tooling = GraphTooling(
|
||||
primary_llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool=search_tool,
|
||||
tools=[search_tool],
|
||||
force_use_tool=ForceUseTool(force_use=False, tool_name=""),
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
)
|
||||
|
||||
graph_persistence = None
|
||||
if chat_session_id := os.environ.get("ONYX_AS_CHAT_SESSION_ID"):
|
||||
graph_persistence = GraphPersistence(
|
||||
db_session=db_session,
|
||||
chat_session_id=UUID(chat_session_id),
|
||||
message_id=1,
|
||||
)
|
||||
|
||||
search_behavior_config = GraphSearchConfig(
|
||||
use_agentic_search=use_agentic_search,
|
||||
skip_gen_ai_answer_generation=False,
|
||||
allow_refinement=True,
|
||||
)
|
||||
graph_config = GraphConfig(
|
||||
inputs=graph_inputs,
|
||||
tooling=graph_tooling,
|
||||
persistence=graph_persistence,
|
||||
behavior=search_behavior_config,
|
||||
)
|
||||
|
||||
return graph_config
|
||||
|
||||
|
||||
def get_persona_agent_prompt_expressions(persona: Persona | None) -> PersonaExpressions:
|
||||
|
@ -4,7 +4,11 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.models import GraphInputs
|
||||
from onyx.agents.agent_search.models import GraphPersistence
|
||||
from onyx.agents.agent_search.models import GraphSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphTooling
|
||||
from onyx.agents.agent_search.run_graph import run_basic_graph
|
||||
from onyx.agents.agent_search.run_graph import run_main_graph
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
@ -16,12 +20,10 @@ from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
|
||||
from onyx.configs.constants import BASIC_KEY
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
@ -57,35 +59,9 @@ class Answer:
|
||||
use_agentic_persistence: bool = True,
|
||||
) -> None:
|
||||
self.is_connected: Callable[[], bool] | None = is_connected
|
||||
|
||||
self.latest_query_files = latest_query_files or []
|
||||
|
||||
self.tools = tools or []
|
||||
self.force_use_tool = force_use_tool
|
||||
# used for QA flow where we only want to send a single message
|
||||
|
||||
self.answer_style_config = answer_style_config
|
||||
|
||||
self.llm = llm
|
||||
self.fast_llm = fast_llm
|
||||
self.llm_tokenizer = get_tokenizer(
|
||||
provider_type=llm.config.model_provider,
|
||||
model_name=llm.config.model_name,
|
||||
)
|
||||
|
||||
self._streamed_output: list[str] | None = None
|
||||
self._processed_stream: (list[AnswerPacket] | None) = None
|
||||
|
||||
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
|
||||
self._is_cancelled = False
|
||||
|
||||
self.using_tool_calling_llm = (
|
||||
explicit_tool_calling_supported(
|
||||
self.llm.config.model_provider, self.llm.config.model_name
|
||||
)
|
||||
and not skip_explicit_tool_calling
|
||||
)
|
||||
|
||||
search_tools = [tool for tool in (tools or []) if isinstance(tool, SearchTool)]
|
||||
search_tool: SearchTool | None = None
|
||||
|
||||
@ -95,43 +71,46 @@ class Answer:
|
||||
elif len(search_tools) == 1:
|
||||
search_tool = search_tools[0]
|
||||
|
||||
using_tool_calling_llm = explicit_tool_calling_supported(
|
||||
llm.config.model_provider, llm.config.model_name
|
||||
using_tool_calling_llm = (
|
||||
explicit_tool_calling_supported(
|
||||
llm.config.model_provider, llm.config.model_name
|
||||
)
|
||||
and not skip_explicit_tool_calling
|
||||
)
|
||||
self.agent_search_config = AgentSearchConfig(
|
||||
|
||||
self.graph_inputs = GraphInputs(
|
||||
search_request=search_request,
|
||||
prompt_builder=prompt_builder,
|
||||
files=latest_query_files,
|
||||
structured_response_format=answer_style_config.structured_response_format,
|
||||
)
|
||||
self.graph_tooling = GraphTooling(
|
||||
primary_llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool=search_tool,
|
||||
tools=tools or [],
|
||||
force_use_tool=force_use_tool,
|
||||
use_agentic_search=use_agentic_search,
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=current_agent_message_id,
|
||||
use_agentic_persistence=use_agentic_persistence,
|
||||
allow_refinement=True,
|
||||
db_session=db_session,
|
||||
prompt_builder=prompt_builder,
|
||||
tools=tools,
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
files=latest_query_files,
|
||||
structured_response_format=answer_style_config.structured_response_format,
|
||||
)
|
||||
self.graph_persistence = None
|
||||
if use_agentic_persistence:
|
||||
assert db_session, "db_session must be provided for agentic persistence"
|
||||
self.graph_persistence = GraphPersistence(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=current_agent_message_id,
|
||||
)
|
||||
self.search_behavior_config = GraphSearchConfig(
|
||||
use_agentic_search=use_agentic_search,
|
||||
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
|
||||
allow_refinement=True,
|
||||
)
|
||||
self.db_session = db_session
|
||||
|
||||
def _get_tools_list(self) -> list[Tool]:
|
||||
if not self.force_use_tool.force_use:
|
||||
return self.tools
|
||||
|
||||
tool = get_tool_by_name(self.tools, self.force_use_tool.tool_name)
|
||||
|
||||
args_str = (
|
||||
f" with args='{self.force_use_tool.args}'"
|
||||
if self.force_use_tool.args
|
||||
else ""
|
||||
self.graph_config = GraphConfig(
|
||||
inputs=self.graph_inputs,
|
||||
tooling=self.graph_tooling,
|
||||
persistence=self.graph_persistence,
|
||||
behavior=self.search_behavior_config,
|
||||
)
|
||||
logger.info(f"Forcefully using tool='{tool.name}'{args_str}")
|
||||
return [tool]
|
||||
|
||||
@property
|
||||
def processed_streamed_output(self) -> AnswerStream:
|
||||
@ -141,11 +120,11 @@ class Answer:
|
||||
|
||||
run_langgraph = (
|
||||
run_main_graph
|
||||
if self.agent_search_config.use_agentic_search
|
||||
if self.graph_config.behavior.use_agentic_search
|
||||
else run_basic_graph
|
||||
)
|
||||
stream = run_langgraph(
|
||||
self.agent_search_config,
|
||||
self.graph_config,
|
||||
)
|
||||
|
||||
processed_stream = []
|
||||
|
@ -5,10 +5,10 @@ import os
|
||||
|
||||
import yaml
|
||||
|
||||
from onyx.agents.agent_search.deep_search.main__graph.graph_builder import (
|
||||
from onyx.agents.agent_search.deep_search.main.graph_builder import (
|
||||
main_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main__graph.states import MainInput
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
|
@ -70,13 +70,13 @@ def answer_instance(
|
||||
|
||||
|
||||
def test_basic_answer(answer_instance: Answer) -> None:
|
||||
mock_llm = cast(Mock, answer_instance.agent_search_config.primary_llm)
|
||||
mock_llm = cast(Mock, answer_instance.graph_config.tooling.primary_llm)
|
||||
mock_llm.stream.return_value = [
|
||||
AIMessageChunk(content="This is a "),
|
||||
AIMessageChunk(content="mock answer."),
|
||||
]
|
||||
answer_instance.agent_search_config.fast_llm = mock_llm
|
||||
answer_instance.agent_search_config.primary_llm = mock_llm
|
||||
answer_instance.graph_config.tooling.fast_llm = mock_llm
|
||||
answer_instance.graph_config.tooling.primary_llm = mock_llm
|
||||
|
||||
output = list(answer_instance.processed_streamed_output)
|
||||
assert len(output) == 2
|
||||
@ -128,11 +128,11 @@ def test_answer_with_search_call(
|
||||
force_use_tool: ForceUseTool,
|
||||
expected_tool_args: dict,
|
||||
) -> None:
|
||||
answer_instance.agent_search_config.tools = [mock_search_tool]
|
||||
answer_instance.agent_search_config.force_use_tool = force_use_tool
|
||||
answer_instance.graph_config.tooling.tools = [mock_search_tool]
|
||||
answer_instance.graph_config.tooling.force_use_tool = force_use_tool
|
||||
|
||||
# Set up the LLM mock to return search results and then an answer
|
||||
mock_llm = cast(Mock, answer_instance.agent_search_config.primary_llm)
|
||||
mock_llm = cast(Mock, answer_instance.graph_config.tooling.primary_llm)
|
||||
|
||||
stream_side_effect: list[list[BaseMessage]] = []
|
||||
|
||||
@ -253,10 +253,10 @@ def test_answer_with_search_no_tool_calling(
|
||||
mock_contexts: OnyxContexts,
|
||||
mock_search_tool: MagicMock,
|
||||
) -> None:
|
||||
answer_instance.agent_search_config.tools = [mock_search_tool]
|
||||
answer_instance.graph_config.tooling.tools = [mock_search_tool]
|
||||
|
||||
# Set up the LLM mock to return an answer
|
||||
mock_llm = cast(Mock, answer_instance.agent_search_config.primary_llm)
|
||||
mock_llm = cast(Mock, answer_instance.graph_config.tooling.primary_llm)
|
||||
mock_llm.stream.return_value = [
|
||||
AIMessageChunk(content="Based on the search results, "),
|
||||
AIMessageChunk(content="the answer is abc[1]. "),
|
||||
@ -264,7 +264,7 @@ def test_answer_with_search_no_tool_calling(
|
||||
]
|
||||
|
||||
# Force non-tool calling behavior
|
||||
answer_instance.agent_search_config.using_tool_calling_llm = False
|
||||
answer_instance.graph_config.tooling.using_tool_calling_llm = False
|
||||
|
||||
# Process the output
|
||||
output = list(answer_instance.processed_streamed_output)
|
||||
@ -319,7 +319,7 @@ def test_answer_with_search_no_tool_calling(
|
||||
|
||||
# Verify that get_args_for_non_tool_calling_llm was called on the mock_search_tool
|
||||
mock_search_tool.get_args_for_non_tool_calling_llm.assert_called_once_with(
|
||||
QUERY, [], answer_instance.llm
|
||||
QUERY, [], answer_instance.graph_config.tooling.primary_llm
|
||||
)
|
||||
|
||||
# Verify that the search tool's run method was called
|
||||
@ -329,8 +329,8 @@ def test_answer_with_search_no_tool_calling(
|
||||
def test_is_cancelled(answer_instance: Answer) -> None:
|
||||
# Set up the LLM mock to return multiple chunks
|
||||
mock_llm = Mock()
|
||||
answer_instance.agent_search_config.primary_llm = mock_llm
|
||||
answer_instance.agent_search_config.fast_llm = mock_llm
|
||||
answer_instance.graph_config.tooling.primary_llm = mock_llm
|
||||
answer_instance.graph_config.tooling.fast_llm = mock_llm
|
||||
mock_llm.stream.return_value = [
|
||||
AIMessageChunk(content="This is the "),
|
||||
AIMessageChunk(content="first part."),
|
||||
|
Loading…
x
Reference in New Issue
Block a user