From 7d494cd65e75207c4ac7e02c7195fb3a47bd7e38 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Sat, 25 Jan 2025 16:06:11 -0800 Subject: [PATCH] allowed empty Search Tool for non-agentic search --- .../expanded_retrieval/nodes/doc_reranking.py | 2 ++ .../expanded_retrieval/nodes/doc_retrieval.py | 2 ++ .../expanded_retrieval/nodes/format_results.py | 2 ++ .../agent_search/deep_search_a/main/edges.py | 1 + .../deep_search_a/main/nodes/agent_logging.py | 7 ++++--- .../main/nodes/agent_search_start.py | 2 ++ .../main/nodes/generate_initial_answer.py | 3 ++- .../main/nodes/generate_refined_answer.py | 2 ++ backend/onyx/agents/agent_search/models.py | 9 ++++++++- backend/onyx/chat/process_message.py | 17 ++++++++++++----- 10 files changed, 37 insertions(+), 10 deletions(-) diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_reranking.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_reranking.py index bb868489a..34ba56b86 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_reranking.py +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_reranking.py @@ -33,6 +33,8 @@ def doc_reranking( agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) question = state.question if state.question else agent_a_config.search_request.query + if agent_a_config.search_tool is None: + raise ValueError("search_tool must be provided for agentic search") with get_session_context_manager() as db_session: _search_query = retrieval_preprocessing( search_request=SearchRequest(query=question), diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_retrieval.py index fa3f904fb..620150206 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_retrieval.py +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/doc_retrieval.py @@ -54,6 +54,8 @@ def doc_retrieval(state: RetrievalInput, config: RunnableConfig) -> DocRetrieval ) query_info = None + if search_tool is None: + raise ValueError("search_tool must be provided for agentic search") # new db session to avoid concurrency issues with get_session_context_manager() as db_session: for tool_response in search_tool.run( diff --git a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/format_results.py b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/format_results.py index e64bb1876..a7bd5769f 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/format_results.py +++ b/backend/onyx/agents/agent_search/deep_search_a/expanded_retrieval/nodes/format_results.py @@ -45,6 +45,8 @@ def format_results( # the top 3 for that one. We may want to revisit this. stream_documents = state.expanded_retrieval_results[-1].search_results[:3] + if agent_a_config.search_tool is None: + raise ValueError("search_tool must be provided for agentic search") for tool_response in yield_search_responses( query=state.question, reranked_sections=state.retrieved_documents, # TODO: rename params. (sections pre-merging here.) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/edges.py b/backend/onyx/agents/agent_search/deep_search_a/main/edges.py index 774c2d718..a2dc41b93 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/edges.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/edges.py @@ -30,6 +30,7 @@ def route_initial_tool_choice( if state.tool_choice is not None: if ( agent_config.use_agentic_search + and agent_config.search_tool is not None and state.tool_choice.tool.name == agent_config.search_tool.name ): return "agent_search_start" diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_logging.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_logging.py index 375f6dc63..75692c150 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_logging.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_logging.py @@ -61,9 +61,10 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput: persona_id = agent_a_config.search_request.persona.id user_id = None - user = agent_a_config.search_tool.user - if user: - user_id = user.id + if agent_a_config.search_tool is not None: + user = agent_a_config.search_tool.user + if user: + user_id = user.id # log the agent metrics if agent_a_config.db_session is not None: diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_search_start.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_search_start.py index 57c32cff7..21fcdd667 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_search_start.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/agent_search_start.py @@ -32,6 +32,8 @@ def agent_search_start( # Initial search to inform decomposition. Just get top 3 fits search_tool = agent_a_config.search_tool + if search_tool is None: + raise ValueError("search_tool must be provided for agentic search") retrieved_docs: list[InferenceSection] = retrieve_search_docs(search_tool, question) exploratory_search_results = retrieved_docs[:AGENT_EXPLORATORY_SEARCH_RESULTS] diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_answer.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_answer.py index 824475d03..1abc384c8 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_answer.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_initial_answer.py @@ -99,7 +99,8 @@ def generate_initial_answer( else: # Use the query info from the base document retrieval query_info = get_query_info(state.original_question_retrieval_results) - + if agent_a_config.search_tool is None: + raise ValueError("search_tool must be provided for agentic search") for tool_response in yield_search_responses( query=question, reranked_sections=relevant_docs, diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py index a121994fc..c6dff59f1 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py @@ -71,6 +71,8 @@ def generate_refined_answer( combined_documents = dedup_inference_sections(initial_documents, revised_documents) query_info = get_query_info(state.original_question_retrieval_results) + if agent_a_config.search_tool is None: + raise ValueError("search_tool must be provided for agentic search") # stream refined answer docs for tool_response in yield_search_responses( query=question, diff --git a/backend/onyx/agents/agent_search/models.py b/backend/onyx/agents/agent_search/models.py index d8ba053ee..dac5f70e9 100644 --- a/backend/onyx/agents/agent_search/models.py +++ b/backend/onyx/agents/agent_search/models.py @@ -25,7 +25,6 @@ class AgentSearchConfig: primary_llm: LLM fast_llm: LLM - search_tool: SearchTool # Whether to force use of a tool, or to # force tool args IF the tool is used @@ -37,6 +36,8 @@ class AgentSearchConfig: # single_message_history: str | None = None prompt_builder: AnswerPromptBuilder + search_tool: SearchTool | None = None + use_agentic_search: bool = False # For persisting agent search data @@ -79,6 +80,12 @@ class AgentSearchConfig: ) return self + @model_validator(mode="after") + def validate_search_tool(self) -> "AgentSearchConfig": + if self.use_agentic_search and self.search_tool is None: + raise ValueError("search_tool must be provided for agentic search") + return self + class AgentDocumentCitations(BaseModel): document_id: str diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index 7f55275fc..9037fb635 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -743,12 +743,19 @@ def stream_chat_message_objects( # we should construct this unconditionally inside Answer instead # Leaving it here for the time being to avoid breaking changes search_tools = [tool for tool in tools if isinstance(tool, SearchTool)] - if len(search_tools) == 0: - raise ValueError("No search tool found") - elif len(search_tools) > 1: + search_tool: SearchTool | None = None + + if len(search_tools) > 1: # TODO: handle multiple search tools - raise ValueError("Multiple search tools found") - search_tool = search_tools[0] + logger.warning("Multiple search tools found, using first one") + search_tool = search_tools[0] + elif len(search_tools) == 1: + search_tool = search_tools[0] + else: + logger.warning("No search tool found") + if new_msg_req.use_agentic_search: + raise ValueError("No search tool found, cannot use agentic search") + using_tool_calling_llm = explicit_tool_calling_supported( llm.config.model_provider, llm.config.model_name )