Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-19 03:58:30 +02:00)

Commit: basic search restructure: WIP on fixing tests
@@ -1,16 +1,12 @@
from collections import defaultdict
from collections.abc import Callable
from uuid import uuid4

from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall
from sqlalchemy.orm import Session

from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.llm_response_handler import LLMResponseHandlerManager
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream

@@ -18,18 +14,9 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.stream_processing.answer_response_handler import (
    CitationResponseHandler,
)
from onyx.chat.stream_processing.utils import (
    map_document_id_order,
)
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
from onyx.configs.constants import BASIC_KEY
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM

@@ -37,7 +24,6 @@ from onyx.llm.models import PreviousMessage
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger

@@ -52,7 +38,7 @@ class Answer:
        llm: LLM,
        prompt_config: PromptConfig,
        force_use_tool: ForceUseTool,
        pro_search_config: AgentSearchConfig,
        agent_search_config: AgentSearchConfig,
        # must be the same length as `docs`. If None, all docs are considered "relevant"
        message_history: list[PreviousMessage] | None = None,
        single_message_history: str | None = None,
@@ -114,7 +100,7 @@ class Answer:
            and not skip_explicit_tool_calling
        )

        self.pro_search_config = pro_search_config
        self.agent_search_config = agent_search_config
        self.db_session = db_session

    def _get_tools_list(self) -> list[Tool]:
@@ -132,130 +118,132 @@ class Answer:
        return [tool]

    # TODO: delete the function and move the full body to processed_streamed_output
    def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
        current_llm_call = llm_calls[-1]
    def _get_response(self) -> AnswerStream:
        # current_llm_call = llm_calls[-1]

        tool, tool_args = None, None
        # handle the case where no decision has to be made; we simply run the tool
        if (
            current_llm_call.force_use_tool.force_use
            and current_llm_call.force_use_tool.args is not None
        ):
            tool_name, tool_args = (
                current_llm_call.force_use_tool.tool_name,
                current_llm_call.force_use_tool.args,
            )
            tool = get_tool_by_name(current_llm_call.tools, tool_name)
        # tool, tool_args = None, None
        # # handle the case where no decision has to be made; we simply run the tool
        # if (
        #     current_llm_call.force_use_tool.force_use
        #     and current_llm_call.force_use_tool.args is not None
        # ):
        #     tool_name, tool_args = (
        #         current_llm_call.force_use_tool.tool_name,
        #         current_llm_call.force_use_tool.args,
        #     )
        #     tool = get_tool_by_name(current_llm_call.tools, tool_name)

        # special pre-logic for non-tool calling LLM case
        elif not self.using_tool_calling_llm and current_llm_call.tools:
            chosen_tool_and_args = (
                ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
                    current_llm_call, self.llm
                )
            )
            if chosen_tool_and_args:
                tool, tool_args = chosen_tool_and_args
        # # special pre-logic for non-tool calling LLM case
        # elif not self.using_tool_calling_llm and current_llm_call.tools:
        #     chosen_tool_and_args = (
        #         ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
        #             current_llm_call, self.llm
        #         )
        #     )
        #     if chosen_tool_and_args:
        #         tool, tool_args = chosen_tool_and_args

        if tool and tool_args:
            dummy_tool_call_chunk = AIMessageChunk(content="")
            dummy_tool_call_chunk.tool_calls = [
                ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
            ]
        # if tool and tool_args:
        #     dummy_tool_call_chunk = AIMessageChunk(content="")
        #     dummy_tool_call_chunk.tool_calls = [
        #         ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
        #     ]

            response_handler_manager = LLMResponseHandlerManager(
                ToolResponseHandler([tool]), None, self.is_cancelled
            )
            yield from response_handler_manager.handle_llm_response(
                iter([dummy_tool_call_chunk])
            )
        #     response_handler_manager = LLMResponseHandlerManager(
        #         ToolResponseHandler([tool]), None, self.is_cancelled
        #     )
        #     yield from response_handler_manager.handle_llm_response(
        #         iter([dummy_tool_call_chunk])
        #     )

            tmp_call = response_handler_manager.next_llm_call(current_llm_call)
            if tmp_call is None:
                return  # no more LLM calls to process
            current_llm_call = tmp_call
        #     tmp_call = response_handler_manager.next_llm_call(current_llm_call)
        #     if tmp_call is None:
        #         return  # no more LLM calls to process
        #     current_llm_call = tmp_call

        # if we're skipping gen ai answer generation, we should break
        # out unless we're forcing a tool call. If we don't, we might generate an
        # answer, which is a no-no!
        if (
            self.skip_gen_ai_answer_generation
            and not current_llm_call.force_use_tool.force_use
        ):
            return
        # # if we're skipping gen ai answer generation, we should break
        # # out unless we're forcing a tool call. If we don't, we might generate an
        # # answer, which is a no-no!
        # if (
        #     self.skip_gen_ai_answer_generation
        #     and not current_llm_call.force_use_tool.force_use
        # ):
        #     return

        # set up "handlers" to listen to the LLM response stream and
        # feed back the processed results + handle tool call requests
        # + figure out what the next LLM call should be
        tool_call_handler = ToolResponseHandler(current_llm_call.tools)
        # # set up "handlers" to listen to the LLM response stream and
        # # feed back the processed results + handle tool call requests
        # # + figure out what the next LLM call should be
        # tool_call_handler = ToolResponseHandler(current_llm_call.tools)

        final_search_results, displayed_search_results = SearchTool.get_search_result(
            current_llm_call
        ) or ([], [])
        # final_search_results, displayed_search_results = SearchTool.get_search_result(
        #     current_llm_call
        # ) or ([], [])

        # NEXT: we still want to handle the LLM response stream, but it is now:
        # 1. handle the tool call requests
        # 2. feed back the processed results
        # 3. handle the citations
        # # NEXT: we still want to handle the LLM response stream, but it is now:
        # # 1. handle the tool call requests
        # # 2. feed back the processed results
        # # 3. handle the citations

        answer_handler = CitationResponseHandler(
            context_docs=final_search_results,
            final_doc_id_to_rank_map=map_document_id_order(final_search_results),
            display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
        )
        # answer_handler = CitationResponseHandler(
        #     context_docs=final_search_results,
        #     final_doc_id_to_rank_map=map_document_id_order(final_search_results),
        #     display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
        # )

        # At the moment, this wrapper class passes streamed stuff through citation and tool handlers.
        # In the future, we'll want to handle citations and tool calls in the langgraph graph.
        response_handler_manager = LLMResponseHandlerManager(
            tool_call_handler, answer_handler, self.is_cancelled
        )
        # # At the moment, this wrapper class passes streamed stuff through citation and tool handlers.
        # # In the future, we'll want to handle citations and tool calls in the langgraph graph.
        # response_handler_manager = LLMResponseHandlerManager(
        #     tool_call_handler, answer_handler, self.is_cancelled
        # )

        # In langgraph, whether we do the basic thing (call llm stream) or pro search
        # is based on a flag in the pro search config

        if self.pro_search_config.use_agentic_search:
            if self.pro_search_config.search_request is None:
                raise ValueError("Search request must be provided for pro search")

            if self.db_session is None:
                raise ValueError("db_session must be provided for pro search")
            if self.fast_llm is None:
                raise ValueError("fast_llm must be provided for pro search")
        if self.agent_search_config.use_agentic_search:
            if (
                self.agent_search_config.db_session is None
                and self.agent_search_config.use_persistence
            ):
                raise ValueError(
                    "db_session must be provided for pro search when using persistence"
                )

            stream = run_main_graph(
                config=self.pro_search_config,
                config=self.agent_search_config,
            )
        else:
            stream = run_basic_graph(
                config=self.pro_search_config,
                last_llm_call=current_llm_call,
                response_handler_manager=response_handler_manager,
                config=self.agent_search_config,
            )

        processed_stream = []
        for packet in stream:
            if self.is_cancelled():
                packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
                yield packet
                break
            processed_stream.append(packet)
            yield packet
        self._processed_stream = processed_stream
        return
        # DEBUG: good breakpoint
        stream = self.llm.stream(
            # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
            # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
            prompt=current_llm_call.prompt_builder.build(),
            tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
            tool_choice=(
                "required"
                if current_llm_call.tools and current_llm_call.force_use_tool.force_use
                else None
            ),
            structured_response_format=self.answer_style_config.structured_response_format,
        )
        yield from response_handler_manager.handle_llm_response(stream)
        # stream = self.llm.stream(
        #     # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
        #     # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
        #     prompt=current_llm_call.prompt_builder.build(),
        #     tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
        #     tool_choice=(
        #         "required"
        #         if current_llm_call.tools and current_llm_call.force_use_tool.force_use
        #         else None
        #     ),
        #     structured_response_format=self.answer_style_config.structured_response_format,
        # )
        # yield from response_handler_manager.handle_llm_response(stream)

        new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
        if new_llm_call:
            yield from self._get_response(llm_calls + [new_llm_call])
        # new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
        # if new_llm_call:
        #     yield from self._get_response(llm_calls + [new_llm_call])

    @property
    def processed_streamed_output(self) -> AnswerStream:
@@ -263,33 +251,33 @@ class Answer:
            yield from self._processed_stream
            return

        prompt_builder = AnswerPromptBuilder(
            user_message=default_build_user_message(
                user_query=self.question,
                prompt_config=self.prompt_config,
                files=self.latest_query_files,
                single_message_history=self.single_message_history,
            ),
            message_history=self.message_history,
            llm_config=self.llm.config,
            raw_user_query=self.question,
            raw_user_uploaded_files=self.latest_query_files or [],
            single_message_history=self.single_message_history,
        )
        prompt_builder.update_system_prompt(
            default_build_system_message(self.prompt_config)
        )
        llm_call = LLMCall(
            prompt_builder=prompt_builder,
            tools=self._get_tools_list(),
            force_use_tool=self.force_use_tool,
            files=self.latest_query_files,
            tool_call_info=[],
            using_tool_calling_llm=self.using_tool_calling_llm,
        )
        # prompt_builder = AnswerPromptBuilder(
        #     user_message=default_build_user_message(
        #         user_query=self.question,
        #         prompt_config=self.prompt_config,
        #         files=self.latest_query_files,
        #         single_message_history=self.single_message_history,
        #     ),
        #     message_history=self.message_history,
        #     llm_config=self.llm.config,
        #     raw_user_query=self.question,
        #     raw_user_uploaded_files=self.latest_query_files or [],
        #     single_message_history=self.single_message_history,
        # )
        # prompt_builder.update_system_prompt(
        #     default_build_system_message(self.prompt_config)
        # )
        # llm_call = LLMCall(
        #     prompt_builder=prompt_builder,
        #     tools=self._get_tools_list(),
        #     force_use_tool=self.force_use_tool,
        #     files=self.latest_query_files,
        #     tool_call_info=[],
        #     using_tool_calling_llm=self.using_tool_calling_llm,
        # )

        processed_stream = []
        for processed_packet in self._get_response([llm_call]):
        for processed_packet in self._get_response():
            processed_stream.append(processed_packet)
            yield processed_packet
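
Read together, the hunks above interleave the removed single-pass tool-calling flow with the new graph-based flow. Stripped of the commented-out legacy code, the new _get_response body appears to reduce to roughly the sketch below. This is a reconstruction from the diff for readability, not the verbatim file contents.

# Sketch of the new Answer._get_response control flow, reconstructed from the diff above.
def _get_response(self) -> AnswerStream:
    # dispatch on the agentic-search flag carried by AgentSearchConfig
    if self.agent_search_config.use_agentic_search:
        if (
            self.agent_search_config.db_session is None
            and self.agent_search_config.use_persistence
        ):
            raise ValueError(
                "db_session must be provided for pro search when using persistence"
            )
        stream = run_main_graph(config=self.agent_search_config)
    else:
        stream = run_basic_graph(config=self.agent_search_config)

    # pass packets through, remembering them for processed_streamed_output,
    # and stop early with a CANCELLED marker if the caller cancels
    processed_stream = []
    for packet in stream:
        if self.is_cancelled():
            packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
            yield packet
            break
        processed_stream.append(packet)
        yield packet
    self._processed_stream = processed_stream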
@@ -49,6 +49,7 @@ def prepare_chat_message_request(
    rerank_settings: RerankingDetails | None,
    db_session: Session,
    use_agentic_search: bool = False,
    skip_gen_ai_answer_generation: bool = False,
) -> CreateChatMessageRequest:
    # Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
    new_chat_session = create_chat_session(
@@ -74,6 +75,7 @@ def prepare_chat_message_request(
        retrieval_options=retrieval_details,
        rerank_settings=rerank_settings,
        use_agentic_search=use_agentic_search,
        skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
    )

@@ -33,6 +33,9 @@ from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT

@@ -130,6 +133,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
    SECTION_RELEVANCE_LIST_ID,
)
from onyx.tools.tool_runner import ToolCallFinalResult
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry

@@ -137,7 +141,6 @@ from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


logger = setup_logger()

@@ -534,11 +537,8 @@ def stream_chat_message_objects(
        files = load_all_chat_files(
            history_msgs, new_msg_req.file_descriptors, db_session
        )
        latest_query_files = [
            file
            for file in files
            if file.file_id in [f["id"] for f in new_msg_req.file_descriptors]
        ]
        req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
        latest_query_files = [file for file in files if file.file_id in req_file_ids]

        if user_message:
            attach_files_to_chat_message(
@@ -748,17 +748,42 @@ def stream_chat_message_objects(
            # TODO: handle multiple search tools
            raise ValueError("Multiple search tools found")
        search_tool = search_tools[0]
        pro_search_config = AgentSearchConfig(
            use_agentic_search=new_msg_req.use_agentic_search,
            search_request=search_request,
            chat_session_id=chat_session_id,
            message_id=reserved_message_id,
        using_tool_calling_llm = explicit_tool_calling_supported(
            llm.config.model_provider, llm.config.model_name
        )
        force_use_tool = _get_force_search_settings(new_msg_req, tools)
        prompt_builder = AnswerPromptBuilder(
            user_message=default_build_user_message(
                user_query=final_msg.message,
                prompt_config=prompt_config,
                files=latest_query_files,
                single_message_history=single_message_history,
            ),
            system_message=default_build_system_message(prompt_config),
            message_history=message_history,
            llm_config=llm.config,
            raw_user_query=final_msg.message,
            raw_user_uploaded_files=latest_query_files or [],
            single_message_history=single_message_history,
        )
        agent_search_config = AgentSearchConfig(
            search_request=search_request,
            primary_llm=llm,
            fast_llm=fast_llm,
            search_tool=search_tool,
            structured_response_format=new_msg_req.structured_response_format,
            force_use_tool=force_use_tool,
            use_agentic_search=new_msg_req.use_agentic_search,
            chat_session_id=chat_session_id,
            message_id=reserved_message_id,
            use_persistence=True,
            allow_refinement=True,
            db_session=db_session,
            prompt_builder=prompt_builder,
            tools=tools,
            using_tool_calling_llm=using_tool_calling_llm,
            files=latest_query_files,
            structured_response_format=new_msg_req.structured_response_format,
            skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
        )

        # TODO: add previous messages, answer style config, tools, etc.
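
For orientation, the call above implies that AgentSearchConfig (imported from onyx.agents.agent_search.models) now bundles the full per-request state for both the basic and the agentic graph. The sketch below lists its fields as inferred from this call site only; the real model lives in that module and may differ in names, types, and defaults.

# Inferred shape of AgentSearchConfig, based solely on the keyword arguments passed
# above -- a reading aid, not the actual definition in onyx.agents.agent_search.models.
from dataclasses import dataclass


@dataclass
class AgentSearchConfigSketch:
    search_request: "SearchRequest"          # query/filters handed to the search graph
    primary_llm: "LLM"
    fast_llm: "LLM"
    search_tool: "SearchTool"
    force_use_tool: "ForceUseTool"
    use_agentic_search: bool                 # False -> run_basic_graph, True -> run_main_graph
    chat_session_id: "UUID | None"
    message_id: "int | None"                 # reserved_message_id for the in-flight message
    use_persistence: bool                    # when True, db_session must be provided
    allow_refinement: bool
    db_session: "Session | None"
    prompt_builder: "AnswerPromptBuilder"    # built up front instead of per-LLMCall
    tools: "list[Tool]"
    using_tool_calling_llm: bool
    files: "list[InMemoryChatFile]"
    structured_response_format: "dict | None"
    skip_gen_ai_answer_generation: bool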
@@ -785,9 +810,9 @@ def stream_chat_message_objects(
            fast_llm=fast_llm,
            message_history=message_history,
            tools=tools,
            force_use_tool=_get_force_search_settings(new_msg_req, tools),
            force_use_tool=force_use_tool,
            single_message_history=single_message_history,
            pro_search_config=pro_search_config,
            agent_search_config=agent_search_config,
            db_session=db_session,
        )

@@ -84,6 +84,7 @@ class AnswerPromptBuilder:
        raw_user_query: str,
        raw_user_uploaded_files: list[InMemoryChatFile],
        single_message_history: str | None = None,
        system_message: SystemMessage | None = None,
    ) -> None:
        self.max_tokens = compute_max_llm_input_tokens(llm_config)

@@ -108,7 +109,14 @@ class AnswerPromptBuilder:
            ),
        )

        self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
        self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = (
            (
                system_message,
                check_message_tokens(system_message, self.llm_tokenizer_encode_func),
            )
            if system_message
            else None
        )
        self.user_message_and_token_cnt = (
            user_message,
            check_message_tokens(
@@ -5,6 +5,7 @@ from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall

from onyx.chat.models import ResponsePart
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
@@ -50,56 +51,12 @@ class ToolResponseHandler:
    def get_tool_call_for_non_tool_calling_llm(
        cls, llm_call: LLMCall, llm: LLM
    ) -> tuple[Tool, dict] | None:
        if llm_call.force_use_tool.force_use:
            # if we are forcing a tool, we don't need to check which tools to run
            tool = get_tool_by_name(llm_call.tools, llm_call.force_use_tool.tool_name)

            tool_args = (
                llm_call.force_use_tool.args
                if llm_call.force_use_tool.args is not None
                else tool.get_args_for_non_tool_calling_llm(
                    query=llm_call.prompt_builder.raw_user_query,
                    history=llm_call.prompt_builder.raw_message_history,
                    llm=llm,
                    force_run=True,
                )
            )

            if tool_args is None:
                raise RuntimeError(f"Tool '{tool.name}' did not return args")

            return (tool, tool_args)
        else:
            tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
                tools=llm_call.tools,
                query=llm_call.prompt_builder.raw_user_query,
                history=llm_call.prompt_builder.raw_message_history,
                llm=llm,
            )

            available_tools_and_args = [
                (llm_call.tools[ind], args)
                for ind, args in enumerate(tool_options)
                if args is not None
            ]

            logger.info(
                f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
            )

            chosen_tool_and_args = (
                select_single_tool_for_non_tool_calling_llm(
                    tools_and_args=available_tools_and_args,
                    history=llm_call.prompt_builder.raw_message_history,
                    query=llm_call.prompt_builder.raw_user_query,
                    llm=llm,
                )
                if available_tools_and_args
                else None
            )

            logger.notice(f"Chosen tool: {chosen_tool_and_args}")
            return chosen_tool_and_args
        return get_tool_call_for_non_tool_calling_llm_impl(
            force_use_tool=llm_call.force_use_tool,
            tools=llm_call.tools,
            prompt_builder=llm_call.prompt_builder,
            llm=llm,
        )

    def _handle_tool_call(self) -> Generator[ResponsePart, None, None]:
        if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
@@ -196,3 +153,61 @@ class ToolResponseHandler:
                self.tool_final_result,
            ],
        )


def get_tool_call_for_non_tool_calling_llm_impl(
    force_use_tool: ForceUseTool,
    tools: list[Tool],
    prompt_builder: AnswerPromptBuilder,
    llm: LLM,
) -> tuple[Tool, dict] | None:
    if force_use_tool.force_use:
        # if we are forcing a tool, we don't need to check which tools to run
        tool = get_tool_by_name(tools, force_use_tool.tool_name)

        tool_args = (
            force_use_tool.args
            if force_use_tool.args is not None
            else tool.get_args_for_non_tool_calling_llm(
                query=prompt_builder.raw_user_query,
                history=prompt_builder.raw_message_history,
                llm=llm,
                force_run=True,
            )
        )

        if tool_args is None:
            raise RuntimeError(f"Tool '{tool.name}' did not return args")

        return (tool, tool_args)
    else:
        tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
            tools=tools,
            query=prompt_builder.raw_user_query,
            history=prompt_builder.raw_message_history,
            llm=llm,
        )

        available_tools_and_args = [
            (tools[ind], args)
            for ind, args in enumerate(tool_options)
            if args is not None
        ]

        logger.info(
            f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
        )

        chosen_tool_and_args = (
            select_single_tool_for_non_tool_calling_llm(
                tools_and_args=available_tools_and_args,
                history=prompt_builder.raw_message_history,
                query=prompt_builder.raw_user_query,
                llm=llm,
            )
            if available_tools_and_args
            else None
        )

        logger.notice(f"Chosen tool: {chosen_tool_and_args}")
        return chosen_tool_and_args
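
Because the body now lives in a module-level function that takes the force-use settings, tools, prompt builder, and LLM directly rather than an LLMCall, it can presumably be reused by callers that never build an LLMCall, such as the new graph flows. A hypothetical call site, using the AgentSearchConfig fields constructed earlier in this commit, might look like the following; it assumes only the signature shown above and that the function is exported from the same module as ToolResponseHandler.

# Hypothetical usage sketch -- not code from this commit.
from onyx.chat.tool_handling.tool_response_handler import (
    get_tool_call_for_non_tool_calling_llm_impl,
)

chosen_tool_and_args = get_tool_call_for_non_tool_calling_llm_impl(
    force_use_tool=agent_search_config.force_use_tool,
    tools=agent_search_config.tools,
    prompt_builder=agent_search_config.prompt_builder,
    llm=agent_search_config.primary_llm,
)
if chosen_tool_and_args is not None:
    tool, tool_args = chosen_tool_and_args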