allowed empty Search Tool for non-agentic search

This commit is contained in:
Evan Lohn 2025-01-25 16:06:11 -08:00
parent 139374966f
commit 7d494cd65e
10 changed files with 37 additions and 10 deletions

View File

@ -33,6 +33,8 @@ def doc_reranking(
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = state.question if state.question else agent_a_config.search_request.query
if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
with get_session_context_manager() as db_session:
_search_query = retrieval_preprocessing(
search_request=SearchRequest(query=question),

View File

@ -54,6 +54,8 @@ def doc_retrieval(state: RetrievalInput, config: RunnableConfig) -> DocRetrieval
)
query_info = None
if search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
# new db session to avoid concurrency issues
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(

View File

@ -45,6 +45,8 @@ def format_results(
# the top 3 for that one. We may want to revisit this.
stream_documents = state.expanded_retrieval_results[-1].search_results[:3]
if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
for tool_response in yield_search_responses(
query=state.question,
reranked_sections=state.retrieved_documents, # TODO: rename params. (sections pre-merging here.)

View File

@ -30,6 +30,7 @@ def route_initial_tool_choice(
if state.tool_choice is not None:
if (
agent_config.use_agentic_search
and agent_config.search_tool is not None
and state.tool_choice.tool.name == agent_config.search_tool.name
):
return "agent_search_start"

View File

@ -61,9 +61,10 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
persona_id = agent_a_config.search_request.persona.id
user_id = None
user = agent_a_config.search_tool.user
if user:
user_id = user.id
if agent_a_config.search_tool is not None:
user = agent_a_config.search_tool.user
if user:
user_id = user.id
# log the agent metrics
if agent_a_config.db_session is not None:

View File

@ -32,6 +32,8 @@ def agent_search_start(
# Initial search to inform decomposition. Just get top 3 fits
search_tool = agent_a_config.search_tool
if search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
retrieved_docs: list[InferenceSection] = retrieve_search_docs(search_tool, question)
exploratory_search_results = retrieved_docs[:AGENT_EXPLORATORY_SEARCH_RESULTS]

View File

@ -99,7 +99,8 @@ def generate_initial_answer(
else:
# Use the query info from the base document retrieval
query_info = get_query_info(state.original_question_retrieval_results)
if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
for tool_response in yield_search_responses(
query=question,
reranked_sections=relevant_docs,

View File

@ -71,6 +71,8 @@ def generate_refined_answer(
combined_documents = dedup_inference_sections(initial_documents, revised_documents)
query_info = get_query_info(state.original_question_retrieval_results)
if agent_a_config.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
# stream refined answer docs
for tool_response in yield_search_responses(
query=question,

View File

@ -25,7 +25,6 @@ class AgentSearchConfig:
primary_llm: LLM
fast_llm: LLM
search_tool: SearchTool
# Whether to force use of a tool, or to
# force tool args IF the tool is used
@ -37,6 +36,8 @@ class AgentSearchConfig:
# single_message_history: str | None = None
prompt_builder: AnswerPromptBuilder
search_tool: SearchTool | None = None
use_agentic_search: bool = False
# For persisting agent search data
@ -79,6 +80,12 @@ class AgentSearchConfig:
)
return self
@model_validator(mode="after")
def validate_search_tool(self) -> "AgentSearchConfig":
if self.use_agentic_search and self.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
return self
class AgentDocumentCitations(BaseModel):
document_id: str

View File

@ -743,12 +743,19 @@ def stream_chat_message_objects(
# we should construct this unconditionally inside Answer instead
# Leaving it here for the time being to avoid breaking changes
search_tools = [tool for tool in tools if isinstance(tool, SearchTool)]
if len(search_tools) == 0:
raise ValueError("No search tool found")
elif len(search_tools) > 1:
search_tool: SearchTool | None = None
if len(search_tools) > 1:
# TODO: handle multiple search tools
raise ValueError("Multiple search tools found")
search_tool = search_tools[0]
logger.warning("Multiple search tools found, using first one")
search_tool = search_tools[0]
elif len(search_tools) == 1:
search_tool = search_tools[0]
else:
logger.warning("No search tool found")
if new_msg_req.use_agentic_search:
raise ValueError("No search tool found, cannot use agentic search")
using_tool_calling_llm = explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
)