mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-27 20:38:32 +02:00
allowed empty Search Tool for non-agentic search
This commit is contained in:
@@ -33,6 +33,8 @@ def doc_reranking(
|
|||||||
|
|
||||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||||
question = state.question if state.question else agent_a_config.search_request.query
|
question = state.question if state.question else agent_a_config.search_request.query
|
||||||
|
if agent_a_config.search_tool is None:
|
||||||
|
raise ValueError("search_tool must be provided for agentic search")
|
||||||
with get_session_context_manager() as db_session:
|
with get_session_context_manager() as db_session:
|
||||||
_search_query = retrieval_preprocessing(
|
_search_query = retrieval_preprocessing(
|
||||||
search_request=SearchRequest(query=question),
|
search_request=SearchRequest(query=question),
|
||||||
|
@@ -54,6 +54,8 @@ def doc_retrieval(state: RetrievalInput, config: RunnableConfig) -> DocRetrieval
|
|||||||
)
|
)
|
||||||
|
|
||||||
query_info = None
|
query_info = None
|
||||||
|
if search_tool is None:
|
||||||
|
raise ValueError("search_tool must be provided for agentic search")
|
||||||
# new db session to avoid concurrency issues
|
# new db session to avoid concurrency issues
|
||||||
with get_session_context_manager() as db_session:
|
with get_session_context_manager() as db_session:
|
||||||
for tool_response in search_tool.run(
|
for tool_response in search_tool.run(
|
||||||
|
@@ -45,6 +45,8 @@ def format_results(
|
|||||||
# the top 3 for that one. We may want to revisit this.
|
# the top 3 for that one. We may want to revisit this.
|
||||||
stream_documents = state.expanded_retrieval_results[-1].search_results[:3]
|
stream_documents = state.expanded_retrieval_results[-1].search_results[:3]
|
||||||
|
|
||||||
|
if agent_a_config.search_tool is None:
|
||||||
|
raise ValueError("search_tool must be provided for agentic search")
|
||||||
for tool_response in yield_search_responses(
|
for tool_response in yield_search_responses(
|
||||||
query=state.question,
|
query=state.question,
|
||||||
reranked_sections=state.retrieved_documents, # TODO: rename params. (sections pre-merging here.)
|
reranked_sections=state.retrieved_documents, # TODO: rename params. (sections pre-merging here.)
|
||||||
|
@@ -30,6 +30,7 @@ def route_initial_tool_choice(
|
|||||||
if state.tool_choice is not None:
|
if state.tool_choice is not None:
|
||||||
if (
|
if (
|
||||||
agent_config.use_agentic_search
|
agent_config.use_agentic_search
|
||||||
|
and agent_config.search_tool is not None
|
||||||
and state.tool_choice.tool.name == agent_config.search_tool.name
|
and state.tool_choice.tool.name == agent_config.search_tool.name
|
||||||
):
|
):
|
||||||
return "agent_search_start"
|
return "agent_search_start"
|
||||||
|
@@ -61,6 +61,7 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
|
|||||||
persona_id = agent_a_config.search_request.persona.id
|
persona_id = agent_a_config.search_request.persona.id
|
||||||
|
|
||||||
user_id = None
|
user_id = None
|
||||||
|
if agent_a_config.search_tool is not None:
|
||||||
user = agent_a_config.search_tool.user
|
user = agent_a_config.search_tool.user
|
||||||
if user:
|
if user:
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
|
@@ -32,6 +32,8 @@ def agent_search_start(
|
|||||||
# Initial search to inform decomposition. Just get top 3 fits
|
# Initial search to inform decomposition. Just get top 3 fits
|
||||||
|
|
||||||
search_tool = agent_a_config.search_tool
|
search_tool = agent_a_config.search_tool
|
||||||
|
if search_tool is None:
|
||||||
|
raise ValueError("search_tool must be provided for agentic search")
|
||||||
retrieved_docs: list[InferenceSection] = retrieve_search_docs(search_tool, question)
|
retrieved_docs: list[InferenceSection] = retrieve_search_docs(search_tool, question)
|
||||||
|
|
||||||
exploratory_search_results = retrieved_docs[:AGENT_EXPLORATORY_SEARCH_RESULTS]
|
exploratory_search_results = retrieved_docs[:AGENT_EXPLORATORY_SEARCH_RESULTS]
|
||||||
|
@@ -99,7 +99,8 @@ def generate_initial_answer(
|
|||||||
else:
|
else:
|
||||||
# Use the query info from the base document retrieval
|
# Use the query info from the base document retrieval
|
||||||
query_info = get_query_info(state.original_question_retrieval_results)
|
query_info = get_query_info(state.original_question_retrieval_results)
|
||||||
|
if agent_a_config.search_tool is None:
|
||||||
|
raise ValueError("search_tool must be provided for agentic search")
|
||||||
for tool_response in yield_search_responses(
|
for tool_response in yield_search_responses(
|
||||||
query=question,
|
query=question,
|
||||||
reranked_sections=relevant_docs,
|
reranked_sections=relevant_docs,
|
||||||
|
@@ -71,6 +71,8 @@ def generate_refined_answer(
|
|||||||
combined_documents = dedup_inference_sections(initial_documents, revised_documents)
|
combined_documents = dedup_inference_sections(initial_documents, revised_documents)
|
||||||
|
|
||||||
query_info = get_query_info(state.original_question_retrieval_results)
|
query_info = get_query_info(state.original_question_retrieval_results)
|
||||||
|
if agent_a_config.search_tool is None:
|
||||||
|
raise ValueError("search_tool must be provided for agentic search")
|
||||||
# stream refined answer docs
|
# stream refined answer docs
|
||||||
for tool_response in yield_search_responses(
|
for tool_response in yield_search_responses(
|
||||||
query=question,
|
query=question,
|
||||||
|
@@ -25,7 +25,6 @@ class AgentSearchConfig:
|
|||||||
|
|
||||||
primary_llm: LLM
|
primary_llm: LLM
|
||||||
fast_llm: LLM
|
fast_llm: LLM
|
||||||
search_tool: SearchTool
|
|
||||||
|
|
||||||
# Whether to force use of a tool, or to
|
# Whether to force use of a tool, or to
|
||||||
# force tool args IF the tool is used
|
# force tool args IF the tool is used
|
||||||
@@ -37,6 +36,8 @@ class AgentSearchConfig:
|
|||||||
# single_message_history: str | None = None
|
# single_message_history: str | None = None
|
||||||
prompt_builder: AnswerPromptBuilder
|
prompt_builder: AnswerPromptBuilder
|
||||||
|
|
||||||
|
search_tool: SearchTool | None = None
|
||||||
|
|
||||||
use_agentic_search: bool = False
|
use_agentic_search: bool = False
|
||||||
|
|
||||||
# For persisting agent search data
|
# For persisting agent search data
|
||||||
@@ -79,6 +80,12 @@ class AgentSearchConfig:
|
|||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_search_tool(self) -> "AgentSearchConfig":
|
||||||
|
if self.use_agentic_search and self.search_tool is None:
|
||||||
|
raise ValueError("search_tool must be provided for agentic search")
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class AgentDocumentCitations(BaseModel):
|
class AgentDocumentCitations(BaseModel):
|
||||||
document_id: str
|
document_id: str
|
||||||
|
@@ -743,12 +743,19 @@ def stream_chat_message_objects(
|
|||||||
# we should construct this unconditionally inside Answer instead
|
# we should construct this unconditionally inside Answer instead
|
||||||
# Leaving it here for the time being to avoid breaking changes
|
# Leaving it here for the time being to avoid breaking changes
|
||||||
search_tools = [tool for tool in tools if isinstance(tool, SearchTool)]
|
search_tools = [tool for tool in tools if isinstance(tool, SearchTool)]
|
||||||
if len(search_tools) == 0:
|
search_tool: SearchTool | None = None
|
||||||
raise ValueError("No search tool found")
|
|
||||||
elif len(search_tools) > 1:
|
if len(search_tools) > 1:
|
||||||
# TODO: handle multiple search tools
|
# TODO: handle multiple search tools
|
||||||
raise ValueError("Multiple search tools found")
|
logger.warning("Multiple search tools found, using first one")
|
||||||
search_tool = search_tools[0]
|
search_tool = search_tools[0]
|
||||||
|
elif len(search_tools) == 1:
|
||||||
|
search_tool = search_tools[0]
|
||||||
|
else:
|
||||||
|
logger.warning("No search tool found")
|
||||||
|
if new_msg_req.use_agentic_search:
|
||||||
|
raise ValueError("No search tool found, cannot use agentic search")
|
||||||
|
|
||||||
using_tool_calling_llm = explicit_tool_calling_supported(
|
using_tool_calling_llm = explicit_tool_calling_supported(
|
||||||
llm.config.model_provider, llm.config.model_name
|
llm.config.model_provider, llm.config.model_name
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user