allowed empty Search Tool for non-agentic search

This commit is contained in:
Evan Lohn
2025-01-25 16:06:11 -08:00
parent 139374966f
commit 7d494cd65e
10 changed files with 37 additions and 10 deletions

View File

@@ -33,6 +33,8 @@ def doc_reranking(
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = state.question if state.question else agent_a_config.search_request.query question = state.question if state.question else agent_a_config.search_request.query
if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
with get_session_context_manager() as db_session: with get_session_context_manager() as db_session:
_search_query = retrieval_preprocessing( _search_query = retrieval_preprocessing(
search_request=SearchRequest(query=question), search_request=SearchRequest(query=question),

View File

@@ -54,6 +54,8 @@ def doc_retrieval(state: RetrievalInput, config: RunnableConfig) -> DocRetrieval
) )
query_info = None query_info = None
if search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
# new db session to avoid concurrency issues # new db session to avoid concurrency issues
with get_session_context_manager() as db_session: with get_session_context_manager() as db_session:
for tool_response in search_tool.run( for tool_response in search_tool.run(

View File

@@ -45,6 +45,8 @@ def format_results(
# the top 3 for that one. We may want to revisit this. # the top 3 for that one. We may want to revisit this.
stream_documents = state.expanded_retrieval_results[-1].search_results[:3] stream_documents = state.expanded_retrieval_results[-1].search_results[:3]
if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
for tool_response in yield_search_responses( for tool_response in yield_search_responses(
query=state.question, query=state.question,
reranked_sections=state.retrieved_documents, # TODO: rename params. (sections pre-merging here.) reranked_sections=state.retrieved_documents, # TODO: rename params. (sections pre-merging here.)

View File

@@ -30,6 +30,7 @@ def route_initial_tool_choice(
if state.tool_choice is not None: if state.tool_choice is not None:
if ( if (
agent_config.use_agentic_search agent_config.use_agentic_search
and agent_config.search_tool is not None
and state.tool_choice.tool.name == agent_config.search_tool.name and state.tool_choice.tool.name == agent_config.search_tool.name
): ):
return "agent_search_start" return "agent_search_start"

View File

@@ -61,6 +61,7 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
persona_id = agent_a_config.search_request.persona.id persona_id = agent_a_config.search_request.persona.id
user_id = None user_id = None
if agent_a_config.search_tool is not None:
user = agent_a_config.search_tool.user user = agent_a_config.search_tool.user
if user: if user:
user_id = user.id user_id = user.id

View File

@@ -32,6 +32,8 @@ def agent_search_start(
# Initial search to inform decomposition. Just get top 3 fits # Initial search to inform decomposition. Just get top 3 fits
search_tool = agent_a_config.search_tool search_tool = agent_a_config.search_tool
if search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
retrieved_docs: list[InferenceSection] = retrieve_search_docs(search_tool, question) retrieved_docs: list[InferenceSection] = retrieve_search_docs(search_tool, question)
exploratory_search_results = retrieved_docs[:AGENT_EXPLORATORY_SEARCH_RESULTS] exploratory_search_results = retrieved_docs[:AGENT_EXPLORATORY_SEARCH_RESULTS]

View File

@@ -99,7 +99,8 @@ def generate_initial_answer(
else: else:
# Use the query info from the base document retrieval # Use the query info from the base document retrieval
query_info = get_query_info(state.original_question_retrieval_results) query_info = get_query_info(state.original_question_retrieval_results)
if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
for tool_response in yield_search_responses( for tool_response in yield_search_responses(
query=question, query=question,
reranked_sections=relevant_docs, reranked_sections=relevant_docs,

View File

@@ -71,6 +71,8 @@ def generate_refined_answer(
combined_documents = dedup_inference_sections(initial_documents, revised_documents) combined_documents = dedup_inference_sections(initial_documents, revised_documents)
query_info = get_query_info(state.original_question_retrieval_results) query_info = get_query_info(state.original_question_retrieval_results)
if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
# stream refined answer docs # stream refined answer docs
for tool_response in yield_search_responses( for tool_response in yield_search_responses(
query=question, query=question,

View File

@@ -25,7 +25,6 @@ class AgentSearchConfig:
primary_llm: LLM primary_llm: LLM
fast_llm: LLM fast_llm: LLM
search_tool: SearchTool
# Whether to force use of a tool, or to # Whether to force use of a tool, or to
# force tool args IF the tool is used # force tool args IF the tool is used
@@ -37,6 +36,8 @@ class AgentSearchConfig:
# single_message_history: str | None = None # single_message_history: str | None = None
prompt_builder: AnswerPromptBuilder prompt_builder: AnswerPromptBuilder
search_tool: SearchTool | None = None
use_agentic_search: bool = False use_agentic_search: bool = False
# For persisting agent search data # For persisting agent search data
@@ -79,6 +80,12 @@ class AgentSearchConfig:
) )
return self return self
@model_validator(mode="after")
def validate_search_tool(self) -> "AgentSearchConfig":
    """Reject a config that enables agentic search without a search tool.

    Registered via ``model_validator(mode="after")``.

    Returns:
        This same config instance, unchanged, when the check passes.

    Raises:
        ValueError: If ``use_agentic_search`` is set while ``search_tool``
            is ``None``.
    """
    # A missing search tool is only a problem when agentic search is on.
    if not self.use_agentic_search:
        return self

    if self.search_tool is None:
        raise ValueError("search_tool must be provided for agentic search")

    return self
class AgentDocumentCitations(BaseModel): class AgentDocumentCitations(BaseModel):
document_id: str document_id: str

View File

@@ -743,12 +743,19 @@ def stream_chat_message_objects(
# we should construct this unconditionally inside Answer instead # we should construct this unconditionally inside Answer instead
# Leaving it here for the time being to avoid breaking changes # Leaving it here for the time being to avoid breaking changes
search_tools = [tool for tool in tools if isinstance(tool, SearchTool)] search_tools = [tool for tool in tools if isinstance(tool, SearchTool)]
if len(search_tools) == 0: search_tool: SearchTool | None = None
raise ValueError("No search tool found")
elif len(search_tools) > 1: if len(search_tools) > 1:
# TODO: handle multiple search tools # TODO: handle multiple search tools
raise ValueError("Multiple search tools found") logger.warning("Multiple search tools found, using first one")
search_tool = search_tools[0] search_tool = search_tools[0]
elif len(search_tools) == 1:
search_tool = search_tools[0]
else:
logger.warning("No search tool found")
if new_msg_req.use_agentic_search:
raise ValueError("No search tool found, cannot use agentic search")
using_tool_calling_llm = explicit_tool_calling_supported( using_tool_calling_llm = explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name llm.config.model_provider, llm.config.model_name
) )