Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-07-19 09:33:53 +02:00)

Commit: reworked config to have logical structure
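At a high level, the commit replaces the flat AgentSearchConfig model with a GraphConfig container whose sub-models group related settings. A rough, abridged sketch of the new shape (field lists shortened here; the full definitions are in the models.py hunk further down):

class GraphConfig(BaseModel):
    # grouped sub-configs instead of one flat model
    inputs: GraphInputs            # search_request, prompt_builder, files, ...
    tooling: GraphTooling          # primary_llm, fast_llm, search_tool, tools, ...
    behavior: GraphSearchConfig    # use_agentic_search, allow_refinement, ...
    # only needed for agentic search
    persistence: GraphPersistence | None = None  # chat_session_id, message_id, db_session

Call sites that previously read agent_search_config.fast_llm or agent_search_config.search_request now reach through the sub-configs, e.g. graph_config.tooling.fast_llm and graph_config.inputs.search_request.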
@@ -109,7 +109,7 @@ if __name__ == "__main__":
         query="what can you do with onyx or danswer?",
     )
     with get_session_context_manager() as db_session:
-        agent_search_config, search_tool = get_test_config(
+        graph_config, search_tool = get_test_config(
             db_session, primary_llm, fast_llm, search_request
         )
         inputs = AnswerQuestionInput(

@@ -119,7 +119,7 @@ if __name__ == "__main__":
         )
         for thing in compiled_graph.stream(
             input=inputs,
-            config={"configurable": {"config": agent_search_config}},
+            config={"configurable": {"config": graph_config}},
             # debug=True,
             # subgraphs=True,
         ):
@@ -11,7 +11,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
 from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
     QACheckUpdate,
 )
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
 from onyx.agents.agent_search.shared_graph_utils.utils import (

@@ -47,8 +47,8 @@ def check_sub_answer(
         )
     ]

-    agent_searchch_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    fast_llm = agent_searchch_config.fast_llm
+    graph_config = cast(GraphConfig, config["metadata"]["config"])
+    fast_llm = graph_config.tooling.fast_llm
     response = list(
         fast_llm.stream(
             prompt=msg,
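The same mechanical change repeats in every node below: the config is still pulled out of the LangGraph RunnableConfig metadata, but attribute access now goes through the grouped sub-configs. A minimal sketch of the pattern, with a hypothetical node name and the body elided:

from typing import cast

from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.models import GraphConfig


def example_node(state, config: RunnableConfig):  # hypothetical node, for illustration only
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    fast_llm = graph_config.tooling.fast_llm  # was: agent_search_config.fast_llm
    question = graph_config.inputs.search_request.query  # was: agent_search_config.search_request.query
    ...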
@@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
 from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
     QAGenerationUpdate,
 )
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     build_sub_question_answer_prompt,
 )

@@ -42,13 +42,13 @@ def generate_sub_answer(
 ) -> QAGenerationUpdate:
     node_start_time = datetime.now()

-    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
+    graph_config = cast(GraphConfig, config["metadata"]["config"])
     question = state.question
     state.verified_reranked_documents
     level, question_nr = parse_question_id(state.question_id)
     context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]
     persona_contextualized_prompt = get_persona_agent_prompt_expressions(
-        agent_search_config.search_request.persona
+        graph_config.inputs.search_request.persona
     ).contextualized_prompt

     if len(context_docs) == 0:

@@ -64,10 +64,10 @@ def generate_sub_answer(
             writer,
         )
     else:
-        fast_llm = agent_search_config.fast_llm
+        fast_llm = graph_config.tooling.fast_llm
         msg = build_sub_question_answer_prompt(
             question=question,
-            original_question=agent_search_config.search_request.query,
+            original_question=graph_config.inputs.search_request.query,
             docs=context_docs,
             persona_specification=persona_contextualized_prompt,
             config=fast_llm.config,
@@ -19,7 +19,7 @@ from onyx.agents.agent_search.deep_search.main.operations import logger
 from onyx.agents.agent_search.deep_search.main.states import (
     InitialAnswerUpdate,
 )
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     get_prompt_enrichment_components,
 )

@@ -63,9 +63,9 @@ def generate_initial_answer(
 ) -> InitialAnswerUpdate:
     node_start_time = datetime.now()

-    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    question = agent_search_config.search_request.query
-    prompt_enrichment_components = get_prompt_enrichment_components(agent_search_config)
+    graph_config = cast(GraphConfig, config["metadata"]["config"])
+    question = graph_config.inputs.search_request.query
+    prompt_enrichment_components = get_prompt_enrichment_components(graph_config)

     sub_questions_cited_documents = state.cited_documents
     orig_question_retrieval_documents = state.orig_question_retrieval_documents

@@ -93,8 +93,9 @@ def generate_initial_answer(
     # Use the query info from the base document retrieval
     query_info = get_query_info(state.orig_question_query_retrieval_results)

-    if agent_search_config.search_tool is None:
-        raise ValueError("search_tool must be provided for agentic search")
+    assert (
+        graph_config.tooling.search_tool
+    ), "search_tool must be provided for agentic search"

     relevance_list = relevance_from_docs(relevant_docs)
     for tool_response in yield_search_responses(

@@ -103,7 +104,7 @@ def generate_initial_answer(
         final_context_sections=relevant_docs,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,
-        search_tool=agent_search_config.search_tool,
+        search_tool=graph_config.tooling.search_tool,
     ):
         write_custom_event(
             "tool_response",

@@ -167,7 +168,7 @@ def generate_initial_answer(
         sub_question_answer_str = ""
         base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS

-    model = agent_search_config.fast_llm
+    model = graph_config.tooling.fast_llm

     doc_context = format_docs(relevant_docs)
     doc_context = trim_prompt_piece(
@@ -18,7 +18,7 @@ from onyx.agents.agent_search.deep_search.main.operations import (
 from onyx.agents.agent_search.deep_search.main.states import (
     InitialQuestionDecompositionUpdate,
 )
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     build_history_prompt,
 )

@@ -44,25 +44,18 @@ def decompose_orig_question(
 ) -> InitialQuestionDecompositionUpdate:
     node_start_time = datetime.now()

-    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    question = agent_search_config.search_request.query
-    chat_session_id = agent_search_config.chat_session_id
-    primary_message_id = agent_search_config.message_id
+    graph_config = cast(GraphConfig, config["metadata"]["config"])
+    question = graph_config.inputs.search_request.query
     perform_initial_search_decomposition = (
-        agent_search_config.perform_initial_search_decomposition
+        graph_config.behavior.perform_initial_search_decomposition
     )
     # Get the rewritten queries in a defined format
-    model = agent_search_config.fast_llm
+    model = graph_config.tooling.fast_llm

-    history = build_history_prompt(agent_search_config, question)
+    history = build_history_prompt(graph_config, question)

     # Use the initial search results to inform the decomposition
     sample_doc_str = state.sample_doc_str if hasattr(state, "sample_doc_str") else ""

-    if not chat_session_id or not primary_message_id:
-        raise ValueError(
-            "chat_session_id and message_id must be provided for agent search"
-        )
     agent_start_time = datetime.now()

     # Initial search to inform decomposition. Just get top 3 fits
@@ -6,7 +6,7 @@ from onyx.agents.agent_search.core_state import CoreState
 from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
     ExpandedRetrievalInput,
 )
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.utils.logger import setup_logger

 logger = setup_logger()

@@ -16,9 +16,9 @@ def format_orig_question_search_input(
     state: CoreState, config: RunnableConfig
 ) -> ExpandedRetrievalInput:
     logger.debug("generate_raw_search_data")
-    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
+    graph_config = cast(GraphConfig, config["metadata"]["config"])
     return ExpandedRetrievalInput(
-        question=agent_search_config.search_request.query,
+        question=graph_config.inputs.search_request.query,
         base_search=True,
         sub_question_id=None,  # This graph is always and only used for the original question
         log_messages=[],
@@ -16,7 +16,7 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
 from onyx.agents.agent_search.deep_search.main.states import (
     RequireRefinementUpdate,
 )
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
 from onyx.utils.logger import setup_logger


@@ -26,12 +26,12 @@ logger = setup_logger()
 def route_initial_tool_choice(
     state: MainState, config: RunnableConfig
 ) -> Literal["tool_call", "start_agent_search", "logging_node"]:
-    agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
+    agent_config = cast(GraphConfig, config["metadata"]["config"])
     if state.tool_choice is not None:
         if (
-            agent_config.use_agentic_search
-            and agent_config.search_tool is not None
-            and state.tool_choice.tool.name == agent_config.search_tool.name
+            agent_config.behavior.use_agentic_search
+            and agent_config.tooling.search_tool is not None
+            and state.tool_choice.tool.name == agent_config.tooling.search_tool.name
         ):
             return "start_agent_search"
         else:
@@ -221,17 +221,17 @@ if __name__ == "__main__":

     with get_session_context_manager() as db_session:
         search_request = SearchRequest(query="Who created Excel?")
-        agent_search_config, search_tool = get_test_config(
+        graph_config = get_test_config(
             db_session, primary_llm, fast_llm, search_request
         )

         inputs = MainInput(
-            base_question=agent_search_config.search_request.query, log_messages=[]
+            base_question=graph_config.inputs.search_request.query, log_messages=[]
         )

         for thing in compiled_graph.stream(
             input=inputs,
-            config={"configurable": {"config": agent_search_config}},
+            config={"configurable": {"config": graph_config}},
             # stream_mode="debug",
             # debug=True,
             subgraphs=True,
@@ -9,7 +9,7 @@ from onyx.agents.agent_search.deep_search.main.states import (
     InitialVRefinedAnswerComparisonUpdate,
 )
 from onyx.agents.agent_search.deep_search.main.states import MainState
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,

@@ -23,8 +23,8 @@ def compare_answers(
 ) -> InitialVRefinedAnswerComparisonUpdate:
     node_start_time = datetime.now()

-    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    question = agent_search_config.search_request.query
+    graph_config = cast(GraphConfig, config["metadata"]["config"])
+    question = graph_config.inputs.search_request.query
     initial_answer = state.initial_answer
     refined_answer = state.refined_answer


@@ -35,7 +35,7 @@ def compare_answers(
     msg = [HumanMessage(content=compare_answers_prompt)]

     # Get the rewritten queries in a defined format
-    model = agent_search_config.fast_llm
+    model = graph_config.tooling.fast_llm

     # no need to stream this
     resp = model.invoke(msg)
@@ -16,7 +16,7 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
 from onyx.agents.agent_search.deep_search.main.states import (
     RefinedQuestionDecompositionUpdate,
 )
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     build_history_prompt,
 )

@@ -39,13 +39,13 @@ def create_refined_sub_questions(
     state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
 ) -> RefinedQuestionDecompositionUpdate:
     """ """
-    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
+    graph_config = cast(GraphConfig, config["metadata"]["config"])
     write_custom_event(
         "start_refined_answer_creation",
         ToolCallKickoff(
             tool_name="agent_search_1",
             tool_args={
-                "query": agent_search_config.search_request.query,
+                "query": graph_config.inputs.search_request.query,
                 "answer": state.initial_answer,
             },
         ),

@@ -56,9 +56,9 @@ def create_refined_sub_questions(

     agent_refined_start_time = datetime.now()

-    question = agent_search_config.search_request.query
+    question = graph_config.inputs.search_request.query
     base_answer = state.initial_answer
-    history = build_history_prompt(agent_search_config, question)
+    history = build_history_prompt(graph_config, question)
     # get the entity term extraction dict and properly format it
     entity_retlation_term_extractions = state.entity_relation_term_extractions


@@ -90,7 +90,7 @@ def create_refined_sub_questions(
     ]

     # Grader
-    model = agent_search_config.fast_llm
+    model = graph_config.tooling.fast_llm

     streamed_tokens = dispatch_separated(
         model.stream(msg), dispatch_subquestion(1, writer)
@@ -7,7 +7,7 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
 from onyx.agents.agent_search.deep_search.main.states import (
     RequireRefinementUpdate,
 )
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )

@@ -18,7 +18,7 @@ def decide_refinement_need(
 ) -> RequireRefinementUpdate:
     node_start_time = datetime.now()

-    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
+    graph_config = cast(GraphConfig, config["metadata"]["config"])

     decision = True  # TODO: just for current testing purposes


@@ -31,7 +31,7 @@ def decide_refinement_need(
         )
     ]

-    if agent_search_config.allow_refinement:
+    if graph_config.behavior.allow_refinement:
         return RequireRefinementUpdate(
             require_refined_answer_eval=decision,
             log_messages=log_messages,
@@ -11,7 +11,7 @@ from onyx.agents.agent_search.deep_search.main.states import (
     EntityTermExtractionUpdate,
 )
 from onyx.agents.agent_search.deep_search.main.states import MainState
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     trim_prompt_piece,
 )

@@ -33,8 +33,8 @@ def extract_entities_terms(
 ) -> EntityTermExtractionUpdate:
     node_start_time = datetime.now()

-    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    if not agent_search_config.allow_refinement:
+    graph_config = cast(GraphConfig, config["metadata"]["config"])
+    if not graph_config.behavior.allow_refinement:
         return EntityTermExtractionUpdate(
             entity_relation_term_extractions=EntityRelationshipTermExtraction(
                 entities=[],

@@ -52,21 +52,21 @@ def extract_entities_terms(
         )

     # first four lines duplicates from generate_initial_answer
-    question = agent_search_config.search_request.query
+    question = graph_config.inputs.search_request.query
     initial_search_docs = state.exploratory_search_results[:15]

     # start with the entity/term/extraction
     doc_context = format_docs(initial_search_docs)

     doc_context = trim_prompt_piece(
-        agent_search_config.fast_llm.config, doc_context, ENTITY_TERM_PROMPT + question
+        graph_config.tooling.fast_llm.config, doc_context, ENTITY_TERM_PROMPT + question
     )
     msg = [
         HumanMessage(
             content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
         )
     ]
-    fast_llm = agent_search_config.fast_llm
+    fast_llm = graph_config.tooling.fast_llm
     # Grader
     llm_response = fast_llm.invoke(
         prompt=msg,
@@ -16,7 +16,7 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
 from onyx.agents.agent_search.deep_search.main.states import (
     RefinedAnswerUpdate,
 )
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     get_prompt_enrichment_components,
 )

@@ -61,9 +61,9 @@ def generate_refined_answer(
 ) -> RefinedAnswerUpdate:
     node_start_time = datetime.now()

-    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    question = agent_search_config.search_request.query
-    prompt_enrichment_components = get_prompt_enrichment_components(agent_search_config)
+    graph_config = cast(GraphConfig, config["metadata"]["config"])
+    question = graph_config.inputs.search_request.query
+    prompt_enrichment_components = get_prompt_enrichment_components(graph_config)

     persona_contextualized_prompt = (
         prompt_enrichment_components.persona_prompts.contextualized_prompt

@@ -93,8 +93,9 @@ def generate_refined_answer(
     )

     query_info = get_query_info(state.orig_question_query_retrieval_results)
-    if agent_search_config.search_tool is None:
-        raise ValueError("search_tool must be provided for agentic search")
+    assert (
+        graph_config.tooling.search_tool
+    ), "search_tool must be provided for agentic search"
     # stream refined answer docs
     relevance_list = relevance_from_docs(relevant_docs)
     for tool_response in yield_search_responses(

@@ -103,7 +104,7 @@ def generate_refined_answer(
         final_context_sections=relevant_docs,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,
-        search_tool=agent_search_config.search_tool,
+        search_tool=graph_config.tooling.search_tool,
     ):
         write_custom_event(
             "tool_response",

@@ -189,7 +190,7 @@ def generate_refined_answer(
     else:
         base_prompt = REVISED_RAG_PROMPT_NO_SUB_QUESTIONS

-    model = agent_search_config.fast_llm
+    model = graph_config.tooling.fast_llm
     relevant_docs_str = format_docs(relevant_docs)
     relevant_docs_str = trim_prompt_piece(
         model.config,
@@ -10,7 +10,7 @@ from onyx.agents.agent_search.deep_search.main.models import AgentTimings
 from onyx.agents.agent_search.deep_search.main.operations import logger
 from onyx.agents.agent_search.deep_search.main.states import MainOutput
 from onyx.agents.agent_search.deep_search.main.states import MainState
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,

@@ -59,21 +59,23 @@ def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutput:
     )

     persona_id = None
-    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    if agent_search_config.search_request.persona:
-        persona_id = agent_search_config.search_request.persona.id
+    graph_config = cast(GraphConfig, config["metadata"]["config"])
+    if graph_config.inputs.search_request.persona:
+        persona_id = graph_config.inputs.search_request.persona.id

     user_id = None
-    if agent_search_config.search_tool is not None:
-        user = agent_search_config.search_tool.user
+    assert (
+        graph_config.tooling.search_tool
+    ), "search_tool must be provided for agentic search"
+    user = graph_config.tooling.search_tool.user
     if user:
         user_id = user.id

     # log the agent metrics
-    if agent_search_config.db_session is not None:
+    if graph_config.persistence:
         if agent_base_duration is not None:
             log_agent_metrics(
-                db_session=agent_search_config.db_session,
+                db_session=graph_config.persistence.db_session,
                 user_id=user_id,
                 persona_id=persona_id,
                 agent_type=agent_type,

@@ -81,11 +83,10 @@ def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutput:
                 agent_metrics=combined_agent_metrics,
             )

-    if agent_search_config.use_agentic_persistence:
     # Persist the sub-answer in the database
-    db_session = agent_search_config.db_session
-    chat_session_id = agent_search_config.chat_session_id
-    primary_message_id = agent_search_config.message_id
+    db_session = graph_config.persistence.db_session
+    chat_session_id = graph_config.persistence.chat_session_id
+    primary_message_id = graph_config.persistence.message_id
     sub_question_answer_results = state.sub_question_results

     log_agent_sub_question_results(
@@ -7,7 +7,7 @@ from onyx.agents.agent_search.deep_search.main.states import (
     ExploratorySearchUpdate,
 )
 from onyx.agents.agent_search.deep_search.main.states import MainState
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     build_history_prompt,
 )

@@ -24,24 +24,14 @@ def start_agent_search(
 ) -> ExploratorySearchUpdate:
     node_start_time = datetime.now()

-    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    question = agent_search_config.search_request.query
-    chat_session_id = agent_search_config.chat_session_id
-    primary_message_id = agent_search_config.message_id
-    agent_search_config.fast_llm
+    graph_config = cast(GraphConfig, config["metadata"]["config"])
+    question = graph_config.inputs.search_request.query

-    history = build_history_prompt(agent_search_config, question)
+    history = build_history_prompt(graph_config, question)

-    if chat_session_id is None or primary_message_id is None:
-        raise ValueError(
-            "chat_session_id and message_id must be provided for agent search"
-        )

     # Initial search to inform decomposition. Just get top 3 fits
-    search_tool = agent_search_config.search_tool
-    if search_tool is None:
-        raise ValueError("search_tool must be provided for agentic search")
+    search_tool = graph_config.tooling.search_tool
+    assert search_tool, "search_tool must be provided for agentic search"
     retrieved_docs: list[InferenceSection] = retrieve_search_docs(search_tool, question)

     exploratory_search_results = retrieved_docs[:AGENT_EXPLORATORY_SEARCH_RESULTS]
@@ -10,15 +10,15 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import
 from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
     RetrievalInput,
 )
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig


 def parallel_retrieval_edge(
     state: ExpandedRetrievalState, config: RunnableConfig
 ) -> list[Send | Hashable]:
-    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
+    graph_config = cast(GraphConfig, config["metadata"]["config"])
     question = (
-        state.question if state.question else agent_search_config.search_request.query
+        state.question if state.question else graph_config.inputs.search_request.query
     )

     query_expansions = state.expanded_queries + [question]
@@ -129,7 +129,7 @@ if __name__ == "__main__":
     )

     with get_session_context_manager() as db_session:
-        agent_search_config, search_tool = get_test_config(
+        graph_config, search_tool = get_test_config(
             db_session, primary_llm, fast_llm, search_request
         )
         inputs = ExpandedRetrievalInput(

@@ -140,7 +140,7 @@ if __name__ == "__main__":
         )
         for thing in compiled_graph.stream(
             input=inputs,
-            config={"configurable": {"config": agent_search_config}},
+            config={"configurable": {"config": graph_config}},
             # debug=True,
             subgraphs=True,
         ):
@@ -15,7 +15,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import
 from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
     QueryExpansionUpdate,
 )
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
     REWRITE_PROMPT_MULTI_ORIGINAL,
 )

@@ -34,21 +34,17 @@ def expand_queries(
     # Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
     # When we are running this node on the original question, no question is explictly passed in.
     # Instead, we use the original question from the search request.
-    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
+    graph_config = cast(GraphConfig, config["metadata"]["config"])
     node_start_time = datetime.now()
     question = state.question

-    llm = agent_search_config.fast_llm
-    chat_session_id = agent_search_config.chat_session_id
+    llm = graph_config.tooling.fast_llm
     sub_question_id = state.sub_question_id
     if sub_question_id is None:
         level, question_nr = 0, 0
     else:
         level, question_nr = parse_question_id(sub_question_id)

-    if chat_session_id is None:
-        raise ValueError("chat_session_id must be provided for agent search")

     msg = [
         HumanMessage(
             content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
@@ -16,7 +16,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import
 from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
     ExpandedRetrievalUpdate,
 )
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkStats
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
 from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs

@@ -33,7 +33,7 @@ def format_results(
     level, question_nr = parse_question_id(state.sub_question_id or "0_0")
     query_info = get_query_info(state.query_retrieval_results)

-    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
+    graph_config = cast(GraphConfig, config["metadata"]["config"])
     # main question docs will be sent later after aggregation and deduping with sub-question docs

     reranked_documents = state.reranked_documents

@@ -44,8 +44,9 @@ def format_results(
         # the top 3 for that one. We may want to revisit this.
         reranked_documents = state.query_retrieval_results[-1].search_results[:3]

-    if agent_search_config.search_tool is None:
-        raise ValueError("search_tool must be provided for agentic search")
+    assert (
+        graph_config.tooling.search_tool
+    ), "search_tool must be provided for agentic search"

     relevance_list = relevance_from_docs(reranked_documents)
     for tool_response in yield_search_responses(

@@ -54,7 +55,7 @@ def format_results(
         final_context_sections=reranked_documents,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,
-        search_tool=agent_search_config.search_tool,
+        search_tool=graph_config.tooling.search_tool,
     ):
         write_custom_event(
             "tool_response",
@@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import
 from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
     ExpandedRetrievalState,
 )
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
 from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
 from onyx.agents.agent_search.shared_graph_utils.utils import (

@@ -36,12 +36,13 @@ def rerank_documents(
     # Rerank post retrieval and verification. First, create a search query
     # then create the list of reranked sections

-    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
+    graph_config = cast(GraphConfig, config["metadata"]["config"])
     question = (
-        state.question if state.question else agent_search_config.search_request.query
+        state.question if state.question else graph_config.inputs.search_request.query
     )
-    if agent_search_config.search_tool is None:
-        raise ValueError("search_tool must be provided for agentic search")
+    assert (
+        graph_config.tooling.search_tool
+    ), "search_tool must be provided for agentic search"
     with get_session_context_manager() as db_session:
         # we ignore some of the user specified fields since this search is
         # internal to agentic search, but we still want to pass through

@@ -49,13 +50,13 @@ def rerank_documents(
         # (to not make an unnecessary db call).
         search_request = SearchRequest(
             query=question,
-            persona=agent_search_config.search_request.persona,
-            rerank_settings=agent_search_config.search_request.rerank_settings,
+            persona=graph_config.inputs.search_request.persona,
+            rerank_settings=graph_config.inputs.search_request.rerank_settings,
         )
         _search_query = retrieval_preprocessing(
             search_request=search_request,
-            user=agent_search_config.search_tool.user,  # bit of a hack
-            llm=agent_search_config.fast_llm,
+            user=graph_config.tooling.search_tool.user,  # bit of a hack
+            llm=graph_config.tooling.fast_llm,
             db_session=db_session,
         )

@@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import
 from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
     RetrievalInput,
 )
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
 from onyx.agents.agent_search.shared_graph_utils.models import QueryResult
 from onyx.agents.agent_search.shared_graph_utils.utils import (

@@ -45,8 +45,8 @@ def retrieve_documents(
     """
     node_start_time = datetime.now()
     query_to_retrieve = state.query_to_retrieve
-    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    search_tool = agent_search_config.search_tool
+    graph_config = cast(GraphConfig, config["metadata"]["config"])
+    search_tool = graph_config.tooling.search_tool

     retrieved_docs: list[InferenceSection] = []
     if not query_to_retrieve.strip():
@@ -9,7 +9,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import
 from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
     DocVerificationUpdate,
 )
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     trim_prompt_piece,
 )

@@ -34,8 +34,8 @@ def verify_documents(
     retrieved_document_to_verify = state.retrieved_document_to_verify
     document_content = retrieved_document_to_verify.combined_content

-    agent_search_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    fast_llm = agent_search_config.fast_llm
+    graph_config = cast(GraphConfig, config["metadata"]["config"])
+    fast_llm = graph_config.tooling.fast_llm

     document_content = trim_prompt_piece(
         fast_llm.config, document_content, VERIFIER_PROMPT + question
@@ -14,85 +14,15 @@ from onyx.tools.tool import Tool
 from onyx.tools.tool_implementations.search.search_tool import SearchTool


-class AgentSearchConfig(BaseModel):
-    """
-    Configuration for the Agent Search feature.
-    """
-
-    # The search request that was used to generate the Agent Search
-    search_request: SearchRequest
-
-    primary_llm: LLM
-    fast_llm: LLM
-
-    # Whether to force use of a tool, or to
-    # force tool args IF the tool is used
-    force_use_tool: ForceUseTool
-
-    # contains message history for the current chat session
-    # has the following (at most one is non-None)
-    # message_history: list[PreviousMessage] | None = None
-    # single_message_history: str | None = None
-    prompt_builder: AnswerPromptBuilder
-
-    search_tool: SearchTool | None = None
-
-    use_agentic_search: bool = False
-
-    # For persisting agent search data
-    chat_session_id: UUID | None = None
-
-    # The message ID of the user message that triggered the Pro Search
-    message_id: int | None = None
-
-    # Whether to persist data for Agentic Search
-    use_agentic_persistence: bool = True
-
-    # The database session for Agentic Search
-    db_session: Session | None = None
-
-    # Whether to perform initial search to inform decomposition
-    # perform_initial_search_path_decision: bool = True
-
-    # Whether to perform initial search to inform decomposition
-    perform_initial_search_decomposition: bool = True
-
-    # Whether to allow creation of refinement questions (and entity extraction, etc.)
-    allow_refinement: bool = True
-
-    # Tools available for use
-    tools: list[Tool] | None = None
-
-    using_tool_calling_llm: bool = False
-
-    files: list[InMemoryChatFile] | None = None
-
-    structured_response_format: dict | None = None
-
-    skip_gen_ai_answer_generation: bool = False
-
-    @model_validator(mode="after")
-    def validate_db_session(self) -> "AgentSearchConfig":
-        if self.use_agentic_persistence and self.db_session is None:
-            raise ValueError(
-                "db_session must be provided for pro search when using persistence"
-            )
-        return self
-
-    @model_validator(mode="after")
-    def validate_search_tool(self) -> "AgentSearchConfig":
-        if self.use_agentic_search and self.search_tool is None:
-            raise ValueError("search_tool must be provided for agentic search")
-        return self
-
-    class Config:
-        arbitrary_types_allowed = True
-
-
 class GraphInputs(BaseModel):
     """Input data required for the graph execution"""

     search_request: SearchRequest
+    # contains message history for the current chat session
+    # has the following (at most one is non-None)
+    # TODO: unify this into a single message history
+    # message_history: list[PreviousMessage] | None = None
+    # single_message_history: str | None = None
     prompt_builder: AnswerPromptBuilder
     files: list[InMemoryChatFile] | None = None
     structured_response_format: dict | None = None
@@ -107,7 +37,9 @@ class GraphTooling(BaseModel):
     primary_llm: LLM
     fast_llm: LLM
     search_tool: SearchTool | None = None
-    tools: list[Tool] | None = None
+    tools: list[Tool]
+    # Whether to force use of a tool, or to
+    # force tool args IF the tool is used
     force_use_tool: ForceUseTool
     using_tool_calling_llm: bool = False

@@ -118,41 +50,41 @@ class GraphTooling(BaseModel):
 class GraphPersistence(BaseModel):
     """Configuration for data persistence"""

-    chat_session_id: UUID | None = None
-    message_id: int | None = None
-    use_agentic_persistence: bool = True
-    db_session: Session | None = None
+    chat_session_id: UUID
+    # The message ID of the to-be-created first agent message
+    # in response to the user message that triggered the Pro Search
+    message_id: int

+    # The database session the user and initial agent
+    # message were flushed to; only needed for agentic search
+    db_session: Session

     class Config:
         arbitrary_types_allowed = True

-    @model_validator(mode="after")
-    def validate_db_session(self) -> "GraphPersistence":
-        if self.use_agentic_persistence and self.db_session is None:
-            raise ValueError(
-                "db_session must be provided for pro search when using persistence"
-            )
-        return self

-class SearchBehaviorConfig(BaseModel):
+class GraphSearchConfig(BaseModel):
     """Configuration controlling search behavior"""

     use_agentic_search: bool = False
+    # Whether to perform initial search to inform decomposition
     perform_initial_search_decomposition: bool = True

+    # Whether to allow creation of refinement questions (and entity extraction, etc.)
     allow_refinement: bool = True
     skip_gen_ai_answer_generation: bool = False


 class GraphConfig(BaseModel):
     """
-    Main configuration class that combines all config components for Langgraph execution
+    Main container for data needed for Langgraph execution
     """

     inputs: GraphInputs
     tooling: GraphTooling
-    persistence: GraphPersistence
-    behavior: SearchBehaviorConfig
+    behavior: GraphSearchConfig
+    # Only needed for agentic search
+    persistence: GraphPersistence | None = None

     @model_validator(mode="after")
     def validate_search_tool(self) -> "GraphConfig":
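To make the regrouped models concrete, here is a hedged sketch of how a GraphConfig could be assembled and passed to a compiled graph, mirroring the __main__ blocks earlier in this diff; the right-hand-side variables (search_request, prompt_builder, primary_llm, fast_llm, search_tool, force_use_tool, chat_session_id, message_id, db_session, inputs) are placeholders for illustration, not values taken from this commit:

graph_config = GraphConfig(
    inputs=GraphInputs(
        search_request=search_request,
        prompt_builder=prompt_builder,
    ),
    tooling=GraphTooling(
        primary_llm=primary_llm,
        fast_llm=fast_llm,
        search_tool=search_tool,
        tools=[search_tool],
        force_use_tool=force_use_tool,
    ),
    behavior=GraphSearchConfig(use_agentic_search=True),
    # persistence is optional; provide it only when agent results should be stored
    persistence=GraphPersistence(
        chat_session_id=chat_session_id,
        message_id=message_id,
        db_session=db_session,
    ),
)

for event in compiled_graph.stream(
    input=inputs,
    config={"configurable": {"config": graph_config}},
):
    ...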
@@ -7,7 +7,7 @@ from langgraph.types import StreamWriter
 from onyx.agents.agent_search.basic.states import BasicOutput
 from onyx.agents.agent_search.basic.states import BasicState
 from onyx.agents.agent_search.basic.utils import process_llm_stream
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.chat.models import LlmDoc
 from onyx.tools.tool_implementations.search.search_tool import (
     SEARCH_DOC_CONTENT_ID,

@@ -23,14 +23,14 @@ logger = setup_logger()
 def basic_use_tool_response(
     state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
 ) -> BasicOutput:
-    agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    structured_response_format = agent_config.structured_response_format
-    llm = agent_config.primary_llm
+    agent_config = cast(GraphConfig, config["metadata"]["config"])
+    structured_response_format = agent_config.inputs.structured_response_format
+    llm = agent_config.tooling.primary_llm
     tool_choice = state.tool_choice
     if tool_choice is None:
         raise ValueError("Tool choice is None")
     tool = tool_choice.tool
-    prompt_builder = agent_config.prompt_builder
+    prompt_builder = agent_config.inputs.prompt_builder
     if state.tool_call_output is None:
         raise ValueError("Tool call output is None")
     tool_call_output = state.tool_call_output

@@ -41,7 +41,7 @@ def basic_use_tool_response(
         prompt_builder=prompt_builder,
         tool_call_summary=tool_call_summary,
         tool_responses=tool_call_responses,
-        using_tool_calling_llm=agent_config.using_tool_calling_llm,
+        using_tool_calling_llm=agent_config.tooling.using_tool_calling_llm,
     )

     final_search_results = []

@@ -58,7 +58,7 @@ def basic_use_tool_response(
     initial_search_results = cast(list[LlmDoc], initial_search_results)

     new_tool_call_chunk = AIMessageChunk(content="")
-    if not agent_config.skip_gen_ai_answer_generation:
+    if not agent_config.behavior.skip_gen_ai_answer_generation:
         stream = llm.stream(
             prompt=new_prompt_builder.build(),
             structured_response_format=structured_response_format,
@@ -6,7 +6,7 @@ from langchain_core.runnables.config import RunnableConfig
 from langgraph.types import StreamWriter
 
 from onyx.agents.agent_search.basic.utils import process_llm_stream
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.orchestration.states import ToolChoice
 from onyx.agents.agent_search.orchestration.states import ToolChoiceState
 from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
@@ -36,16 +36,18 @@ def llm_tool_choice(
     """
     should_stream_answer = state.should_stream_answer
 
-    agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    using_tool_calling_llm = agent_config.using_tool_calling_llm
-    prompt_builder = state.prompt_snapshot or agent_config.prompt_builder
+    agent_config = cast(GraphConfig, config["metadata"]["config"])
+    using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
+    prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
 
-    llm = agent_config.primary_llm
-    skip_gen_ai_answer_generation = agent_config.skip_gen_ai_answer_generation
+    llm = agent_config.tooling.primary_llm
+    skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation
 
-    structured_response_format = agent_config.structured_response_format
-    tools = [tool for tool in (agent_config.tools or []) if tool.name in state.tools]
-    force_use_tool = agent_config.force_use_tool
+    structured_response_format = agent_config.inputs.structured_response_format
+    tools = [
+        tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
+    ]
+    force_use_tool = agent_config.tooling.force_use_tool
 
     tool, tool_args = None, None
     if force_use_tool.force_use and force_use_tool.args is not None:
@@ -103,7 +105,8 @@ def llm_tool_choice(
 
     tool_message = process_llm_stream(
         stream,
-        should_stream_answer and not agent_config.skip_gen_ai_answer_generation,
+        should_stream_answer
+        and not agent_config.behavior.skip_gen_ai_answer_generation,
         writer,
     )
 
@@ -3,15 +3,15 @@ from typing import cast
 
 from langchain_core.runnables.config import RunnableConfig
 
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
 
 
 def prepare_tool_input(state: Any, config: RunnableConfig) -> ToolChoiceInput:
-    agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
+    agent_config = cast(GraphConfig, config["metadata"]["config"])
     return ToolChoiceInput(
         # NOTE: this node is used at the top level of the agent, so we always stream
         should_stream_answer=True,
         prompt_snapshot=None,  # uses default prompt builder
-        tools=[tool.name for tool in (agent_config.tools or [])],
+        tools=[tool.name for tool in (agent_config.tooling.tools or [])],
     )
@@ -5,7 +5,7 @@ from langchain_core.messages.tool import ToolCall
 from langchain_core.runnables.config import RunnableConfig
 from langgraph.types import StreamWriter
 
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.orchestration.states import ToolCallOutput
 from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
 from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
@@ -31,7 +31,7 @@ def tool_call(
 ) -> ToolCallUpdate:
     """Calls the tool specified in the state and updates the state with the result"""
 
-    cast(AgentSearchConfig, config["metadata"]["config"])
+    cast(GraphConfig, config["metadata"]["config"])
 
     tool_choice = state.tool_choice
     if tool_choice is None:
@@ -1,5 +1,3 @@
-import asyncio
-from collections.abc import AsyncIterable
 from collections.abc import Iterable
 from datetime import datetime
 from typing import cast
@@ -16,7 +14,7 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
 from onyx.agents.agent_search.deep_search.main.states import (
     MainInput as MainInput_a,
 )
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import AnswerPacket
@@ -39,14 +37,6 @@ logger = setup_logger()
 _COMPILED_GRAPH: CompiledStateGraph | None = None
 
 
-def _set_combined_token_value(
-    combined_token: str, parsed_object: AgentAnswerPiece
-) -> AgentAnswerPiece:
-    parsed_object.answer_piece = combined_token
-
-    return parsed_object
-
-
 def _parse_agent_event(
     event: StreamEvent,
 ) -> AnswerPacket | None:
@@ -84,63 +74,12 @@ def _parse_agent_event(
     return None
 
 
-# https://stackoverflow.com/questions/60226557/how-to-forcefully-close-an-async-generator
-# https://stackoverflow.com/questions/40897428/please-explain-task-was-destroyed-but-it-is-pending-after-cancelling-tasks
-task_references: set[asyncio.Task[StreamEvent]] = set()
-
-
-def _manage_async_event_streaming(
-    compiled_graph: CompiledStateGraph,
-    config: AgentSearchConfig | None,
-    graph_input: MainInput_a | BasicInput,
-) -> Iterable[StreamEvent]:
-    async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
-        message_id = config.message_id if config else None
-        async for event in compiled_graph.astream_events(
-            input=graph_input,
-            config={"metadata": {"config": config, "thread_id": str(message_id)}},
-            # debug=True,
-            # indicating v2 here deserves further scrutiny
-            version="v2",
-        ):
-            yield event
-
-    # This might be able to be simplified
-    def _yield_async_to_sync() -> Iterable[StreamEvent]:
-        loop = asyncio.new_event_loop()
-        try:
-            # Get the async generator
-            async_gen = _run_async_event_stream()
-            # Convert to AsyncIterator
-            async_iter = async_gen.__aiter__()
-            while True:
-                try:
-                    # Create a coroutine by calling anext with the async iterator
-                    next_coro = anext(async_iter)
-                    task = asyncio.ensure_future(next_coro, loop=loop)
-                    task_references.add(task)
-                    # Run the coroutine to get the next event
-                    event = loop.run_until_complete(task)
-                    yield event
-                except (StopAsyncIteration, GeneratorExit):
-                    break
-        finally:
-            try:
-                for task in task_references.pop():
-                    task.cancel()
-            except StopAsyncIteration:
-                pass
-            loop.close()
-
-    return _yield_async_to_sync()
-
-
 def manage_sync_streaming(
     compiled_graph: CompiledStateGraph,
-    config: AgentSearchConfig,
+    config: GraphConfig,
     graph_input: BasicInput | MainInput_a,
 ) -> Iterable[StreamEvent]:
-    message_id = config.message_id if config else None
+    message_id = config.persistence.message_id if config.persistence else None
     for event in compiled_graph.stream(
         stream_mode="custom",
         input=graph_input,
@@ -151,11 +90,13 @@ def manage_sync_streaming(
 
 def run_graph(
     compiled_graph: CompiledStateGraph,
-    config: AgentSearchConfig,
+    config: GraphConfig,
     input: BasicInput | MainInput_a,
 ) -> AnswerStream:
-    config.perform_initial_search_decomposition = INITIAL_SEARCH_DECOMPOSITION_ENABLED
-    config.allow_refinement = ALLOW_REFINEMENT
+    config.behavior.perform_initial_search_decomposition = (
+        INITIAL_SEARCH_DECOMPOSITION_ENABLED
+    )
+    config.behavior.allow_refinement = ALLOW_REFINEMENT
 
     for event in manage_sync_streaming(
         compiled_graph=compiled_graph, config=config, graph_input=input
@@ -177,21 +118,24 @@ def load_compiled_graph() -> CompiledStateGraph:
 
 
 def run_main_graph(
-    config: AgentSearchConfig,
+    config: GraphConfig,
 ) -> AnswerStream:
     compiled_graph = load_compiled_graph()
-    input = MainInput_a(base_question=config.search_request.query, log_messages=[])
+    input = MainInput_a(
+        base_question=config.inputs.search_request.query, log_messages=[]
+    )
 
     # Agent search is not a Tool per se, but this is helpful for the frontend
     yield ToolCallKickoff(
         tool_name="agent_search_0",
-        tool_args={"query": config.search_request.query},
+        tool_args={"query": config.inputs.search_request.query},
     )
     yield from run_graph(compiled_graph, config, input)
 
 
 def run_basic_graph(
-    config: AgentSearchConfig,
+    config: GraphConfig,
 ) -> AnswerStream:
     graph = basic_graph_builder()
     compiled_graph = graph.compile()
@@ -222,15 +166,16 @@ if __name__ == "__main__":
     # Joachim custom persona
 
     with get_session_context_manager() as db_session:
-        config, search_tool = get_test_config(
-            db_session, primary_llm, fast_llm, search_request
-        )
+        config = get_test_config(db_session, primary_llm, fast_llm, search_request)
+        assert (
+            config.persistence is not None
+        ), "set a chat session id to run this test"
 
         # search_request.persona = get_persona_by_id(1, None, db_session)
-        config.use_agentic_persistence = True
         # config.perform_initial_search_path_decision = False
-        config.perform_initial_search_decomposition = True
+        config.behavior.perform_initial_search_decomposition = True
         input = MainInput_a(
-            base_question=config.search_request.query, log_messages=[]
+            base_question=config.inputs.search_request.query, log_messages=[]
         )
 
         # with open("output.txt", "w") as f:
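
Reviewer sketch, not part of the patch: run_graph() now flips the flags on config.behavior before streaming, and all per-request data rides on the single GraphConfig. Assuming a GraphConfig built elsewhere (for example by get_test_config() further down) is available as graph_config, driving the compiled graph looks roughly like:

    compiled_graph = load_compiled_graph()
    graph_input = MainInput_a(
        base_question=graph_config.inputs.search_request.query, log_messages=[]
    )
    for packet in run_graph(compiled_graph, graph_config, graph_input):
        print(packet)
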
@@ -4,7 +4,7 @@ from langchain.schema import SystemMessage
 from langchain_core.messages.tool import ToolMessage
 
 from onyx.agents.agent_search.models import AgentPromptEnrichmentComponents
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
 from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -80,11 +80,11 @@ def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -
 )
 
 
-def build_history_prompt(config: AgentSearchConfig, question: str) -> str:
-    prompt_builder = config.prompt_builder
-    model = config.fast_llm
+def build_history_prompt(config: GraphConfig, question: str) -> str:
+    prompt_builder = config.inputs.prompt_builder
+    model = config.tooling.fast_llm
     persona_base = get_persona_agent_prompt_expressions(
-        config.search_request.persona
+        config.inputs.search_request.persona
     ).base_prompt
 
     if prompt_builder is None:
@@ -118,13 +118,13 @@ def build_history_prompt(config: AgentSearchConfig, question: str) -> str:
 
 
 def get_prompt_enrichment_components(
-    config: AgentSearchConfig,
+    config: GraphConfig,
 ) -> AgentPromptEnrichmentComponents:
     persona_prompts = get_persona_agent_prompt_expressions(
-        config.search_request.persona
+        config.inputs.search_request.persona
     )
 
-    history = build_history_prompt(config, config.search_request.query)
+    history = build_history_prompt(config, config.inputs.search_request.query)
 
     date_str = get_today_prompt()
 
@@ -1,5 +1,6 @@
 import ast
 import json
+import os
 import re
 from collections.abc import Callable
 from collections.abc import Iterator
@@ -17,7 +18,11 @@ from langchain_core.messages import HumanMessage
 from langgraph.types import StreamWriter
 from sqlalchemy.orm import Session
 
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
+from onyx.agents.agent_search.models import GraphInputs
+from onyx.agents.agent_search.models import GraphPersistence
+from onyx.agents.agent_search.models import GraphSearchConfig
+from onyx.agents.agent_search.models import GraphTooling
 from onyx.agents.agent_search.shared_graph_utils.models import (
     EntityRelationshipTermExtraction,
 )
@@ -60,6 +65,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
 )
 from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
+from onyx.tools.utils import explicit_tool_calling_supported
 
 BaseMessage_Content = str | list[str | dict[str, Any]]
 
@@ -166,7 +172,7 @@ def get_test_config(
     fast_llm: LLM,
     search_request: SearchRequest,
     use_agentic_search: bool = True,
-) -> tuple[AgentSearchConfig, SearchTool]:
+) -> GraphConfig:
     persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
     document_pruning_config = DocumentPruningConfig(
         max_chunks=int(
@@ -221,12 +227,8 @@
         bypass_acl=search_tool_config.bypass_acl,
     )
 
-    config = AgentSearchConfig(
+    graph_inputs = GraphInputs(
         search_request=search_request,
-        primary_llm=primary_llm,
-        fast_llm=fast_llm,
-        search_tool=search_tool,
-        force_use_tool=ForceUseTool(force_use=False, tool_name=""),
         prompt_builder=AnswerPromptBuilder(
             user_message=HumanMessage(content=search_request.query),
             message_history=[],
@@ -234,17 +236,42 @@
             raw_user_query=search_request.query,
             raw_user_uploaded_files=[],
         ),
-        # chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
-        chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"),  # Joachim
-        # chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"),  # Evan
-        message_id=1,
-        use_agentic_persistence=True,
-        db_session=db_session,
-        tools=[search_tool],
-        use_agentic_search=use_agentic_search,
+        structured_response_format=answer_style_config.structured_response_format,
     )
 
-    return config, search_tool
+    using_tool_calling_llm = explicit_tool_calling_supported(
+        primary_llm.config.model_provider, primary_llm.config.model_name
+    )
+    graph_tooling = GraphTooling(
+        primary_llm=primary_llm,
+        fast_llm=fast_llm,
+        search_tool=search_tool,
+        tools=[search_tool],
+        force_use_tool=ForceUseTool(force_use=False, tool_name=""),
+        using_tool_calling_llm=using_tool_calling_llm,
+    )
+
+    graph_persistence = None
+    if chat_session_id := os.environ.get("ONYX_AS_CHAT_SESSION_ID"):
+        graph_persistence = GraphPersistence(
+            db_session=db_session,
+            chat_session_id=UUID(chat_session_id),
+            message_id=1,
+        )
+
+    search_behavior_config = GraphSearchConfig(
+        use_agentic_search=use_agentic_search,
+        skip_gen_ai_answer_generation=False,
+        allow_refinement=True,
+    )
+    graph_config = GraphConfig(
+        inputs=graph_inputs,
+        tooling=graph_tooling,
+        persistence=graph_persistence,
+        behavior=search_behavior_config,
+    )
+
+    return graph_config
 
 
 def get_persona_agent_prompt_expressions(persona: Persona | None) -> PersonaExpressions:
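
Reviewer sketch, not part of the patch: persistence is now optional on GraphConfig, so get_test_config() only builds a GraphPersistence when ONYX_AS_CHAT_SESSION_ID is set, and callers guard the attribute (see the message_id handling in run_graph.py above). Assuming primary_llm, fast_llm, and search_request are set up as in the __main__ block earlier, and using a placeholder session id:

    import os

    os.environ["ONYX_AS_CHAT_SESSION_ID"] = "00000000-0000-0000-0000-000000000000"  # placeholder

    with get_session_context_manager() as db_session:
        graph_config = get_test_config(db_session, primary_llm, fast_llm, search_request)
        # persistence is only populated when the env var above is set
        assert graph_config.persistence is not None
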
@@ -4,7 +4,11 @@ from uuid import UUID
 from sqlalchemy.orm import Session
 
-from onyx.agents.agent_search.models import AgentSearchConfig
+from onyx.agents.agent_search.models import GraphConfig
+from onyx.agents.agent_search.models import GraphInputs
+from onyx.agents.agent_search.models import GraphPersistence
+from onyx.agents.agent_search.models import GraphSearchConfig
+from onyx.agents.agent_search.models import GraphTooling
 from onyx.agents.agent_search.run_graph import run_basic_graph
 from onyx.agents.agent_search.run_graph import run_main_graph
 from onyx.chat.models import AgentAnswerPiece
@@ -16,12 +20,10 @@ from onyx.chat.models import OnyxAnswerPiece
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
-from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
 from onyx.configs.constants import BASIC_KEY
 from onyx.context.search.models import SearchRequest
 from onyx.file_store.utils import InMemoryChatFile
 from onyx.llm.interfaces import LLM
-from onyx.natural_language_processing.utils import get_tokenizer
 from onyx.tools.force import ForceUseTool
 from onyx.tools.tool import Tool
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
@@ -57,35 +59,9 @@ class Answer:
         use_agentic_persistence: bool = True,
     ) -> None:
         self.is_connected: Callable[[], bool] | None = is_connected
 
-        self.latest_query_files = latest_query_files or []
-
-        self.tools = tools or []
-        self.force_use_tool = force_use_tool
-        # used for QA flow where we only want to send a single message
-
-        self.answer_style_config = answer_style_config
-
-        self.llm = llm
-        self.fast_llm = fast_llm
-        self.llm_tokenizer = get_tokenizer(
-            provider_type=llm.config.model_provider,
-            model_name=llm.config.model_name,
-        )
-
-        self._streamed_output: list[str] | None = None
         self._processed_stream: (list[AnswerPacket] | None) = None
 
-        self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
         self._is_cancelled = False
 
-        self.using_tool_calling_llm = (
-            explicit_tool_calling_supported(
-                self.llm.config.model_provider, self.llm.config.model_name
-            )
-            and not skip_explicit_tool_calling
-        )
-
         search_tools = [tool for tool in (tools or []) if isinstance(tool, SearchTool)]
         search_tool: SearchTool | None = None
 
@@ -95,43 +71,46 @@ class Answer:
         elif len(search_tools) == 1:
             search_tool = search_tools[0]
 
-        using_tool_calling_llm = explicit_tool_calling_supported(
-            llm.config.model_provider, llm.config.model_name
-        )
-        self.agent_search_config = AgentSearchConfig(
-            search_request=search_request,
-            primary_llm=llm,
-            fast_llm=fast_llm,
-            search_tool=search_tool,
-            force_use_tool=force_use_tool,
-            use_agentic_search=use_agentic_search,
-            chat_session_id=chat_session_id,
-            message_id=current_agent_message_id,
-            use_agentic_persistence=use_agentic_persistence,
-            allow_refinement=True,
-            db_session=db_session,
-            prompt_builder=prompt_builder,
-            tools=tools,
-            using_tool_calling_llm=using_tool_calling_llm,
-            files=latest_query_files,
-            structured_response_format=answer_style_config.structured_response_format,
-            skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
-        )
-        self.db_session = db_session
-
-    def _get_tools_list(self) -> list[Tool]:
-        if not self.force_use_tool.force_use:
-            return self.tools
-
-        tool = get_tool_by_name(self.tools, self.force_use_tool.tool_name)
-
-        args_str = (
-            f" with args='{self.force_use_tool.args}'"
-            if self.force_use_tool.args
-            else ""
-        )
-        logger.info(f"Forcefully using tool='{tool.name}'{args_str}")
-        return [tool]
+        using_tool_calling_llm = (
+            explicit_tool_calling_supported(
+                llm.config.model_provider, llm.config.model_name
+            )
+            and not skip_explicit_tool_calling
+        )
+
+        self.graph_inputs = GraphInputs(
+            search_request=search_request,
+            prompt_builder=prompt_builder,
+            files=latest_query_files,
+            structured_response_format=answer_style_config.structured_response_format,
+        )
+        self.graph_tooling = GraphTooling(
+            primary_llm=llm,
+            fast_llm=fast_llm,
+            search_tool=search_tool,
+            tools=tools or [],
+            force_use_tool=force_use_tool,
+            using_tool_calling_llm=using_tool_calling_llm,
+        )
+        self.graph_persistence = None
+        if use_agentic_persistence:
+            assert db_session, "db_session must be provided for agentic persistence"
+            self.graph_persistence = GraphPersistence(
+                db_session=db_session,
+                chat_session_id=chat_session_id,
+                message_id=current_agent_message_id,
+            )
+        self.search_behavior_config = GraphSearchConfig(
+            use_agentic_search=use_agentic_search,
+            skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
+            allow_refinement=True,
+        )
+        self.graph_config = GraphConfig(
+            inputs=self.graph_inputs,
+            tooling=self.graph_tooling,
+            persistence=self.graph_persistence,
+            behavior=self.search_behavior_config,
+        )
 
     @property
     def processed_streamed_output(self) -> AnswerStream:
@@ -141,11 +120,11 @@ class Answer:
 
         run_langgraph = (
             run_main_graph
-            if self.agent_search_config.use_agentic_search
+            if self.graph_config.behavior.use_agentic_search
            else run_basic_graph
         )
         stream = run_langgraph(
-            self.agent_search_config,
+            self.graph_config,
        )
 
         processed_stream = []
@@ -5,10 +5,10 @@ import os
 
 import yaml
 
-from onyx.agents.agent_search.deep_search.main__graph.graph_builder import (
+from onyx.agents.agent_search.deep_search.main.graph_builder import (
     main_graph_builder,
 )
-from onyx.agents.agent_search.deep_search.main__graph.states import MainInput
+from onyx.agents.agent_search.deep_search.main.states import MainInput
 from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
 from onyx.context.search.models import SearchRequest
 from onyx.db.engine import get_session_context_manager
@@ -70,13 +70,13 @@ def answer_instance(
 
 
 def test_basic_answer(answer_instance: Answer) -> None:
-    mock_llm = cast(Mock, answer_instance.agent_search_config.primary_llm)
+    mock_llm = cast(Mock, answer_instance.graph_config.tooling.primary_llm)
     mock_llm.stream.return_value = [
         AIMessageChunk(content="This is a "),
         AIMessageChunk(content="mock answer."),
     ]
-    answer_instance.agent_search_config.fast_llm = mock_llm
-    answer_instance.agent_search_config.primary_llm = mock_llm
+    answer_instance.graph_config.tooling.fast_llm = mock_llm
+    answer_instance.graph_config.tooling.primary_llm = mock_llm
 
     output = list(answer_instance.processed_streamed_output)
     assert len(output) == 2
@@ -128,11 +128,11 @@ def test_answer_with_search_call(
     force_use_tool: ForceUseTool,
     expected_tool_args: dict,
 ) -> None:
-    answer_instance.agent_search_config.tools = [mock_search_tool]
-    answer_instance.agent_search_config.force_use_tool = force_use_tool
+    answer_instance.graph_config.tooling.tools = [mock_search_tool]
+    answer_instance.graph_config.tooling.force_use_tool = force_use_tool
 
     # Set up the LLM mock to return search results and then an answer
-    mock_llm = cast(Mock, answer_instance.agent_search_config.primary_llm)
+    mock_llm = cast(Mock, answer_instance.graph_config.tooling.primary_llm)
 
     stream_side_effect: list[list[BaseMessage]] = []
 
@@ -253,10 +253,10 @@ def test_answer_with_search_no_tool_calling(
     mock_contexts: OnyxContexts,
     mock_search_tool: MagicMock,
 ) -> None:
-    answer_instance.agent_search_config.tools = [mock_search_tool]
+    answer_instance.graph_config.tooling.tools = [mock_search_tool]
 
     # Set up the LLM mock to return an answer
-    mock_llm = cast(Mock, answer_instance.agent_search_config.primary_llm)
+    mock_llm = cast(Mock, answer_instance.graph_config.tooling.primary_llm)
     mock_llm.stream.return_value = [
         AIMessageChunk(content="Based on the search results, "),
         AIMessageChunk(content="the answer is abc[1]. "),
@@ -264,7 +264,7 @@ def test_answer_with_search_no_tool_calling(
     ]
 
     # Force non-tool calling behavior
-    answer_instance.agent_search_config.using_tool_calling_llm = False
+    answer_instance.graph_config.tooling.using_tool_calling_llm = False
 
     # Process the output
     output = list(answer_instance.processed_streamed_output)
@@ -319,7 +319,7 @@ def test_answer_with_search_no_tool_calling(
 
     # Verify that get_args_for_non_tool_calling_llm was called on the mock_search_tool
     mock_search_tool.get_args_for_non_tool_calling_llm.assert_called_once_with(
-        QUERY, [], answer_instance.llm
+        QUERY, [], answer_instance.graph_config.tooling.primary_llm
     )
 
     # Verify that the search tool's run method was called
@@ -329,8 +329,8 @@
 def test_is_cancelled(answer_instance: Answer) -> None:
     # Set up the LLM mock to return multiple chunks
     mock_llm = Mock()
-    answer_instance.agent_search_config.primary_llm = mock_llm
-    answer_instance.agent_search_config.fast_llm = mock_llm
+    answer_instance.graph_config.tooling.primary_llm = mock_llm
+    answer_instance.graph_config.tooling.fast_llm = mock_llm
     mock_llm.stream.return_value = [
         AIMessageChunk(content="This is the "),
         AIMessageChunk(content="first part."),