second clean commit

This commit is contained in:
Evan Lohn
2025-01-19 18:24:26 -08:00
parent 715359c120
commit 4fd6e36c2f
96 changed files with 7558 additions and 255 deletions

View File

@@ -25,6 +25,13 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_tool_by_name(tools: list[Tool], tool_name: str) -> Tool:
for tool in tools:
if tool.name == tool_name:
return tool
raise RuntimeError(f"Tool '{tool_name}' not found")
class ToolResponseHandler:
def __init__(self, tools: list[Tool]):
self.tools = tools
@@ -45,18 +52,7 @@ class ToolResponseHandler:
) -> tuple[Tool, dict] | None:
if llm_call.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = next(
(
t
for t in llm_call.tools
if t.name == llm_call.force_use_tool.tool_name
),
None,
)
if not tool:
raise RuntimeError(
f"Tool '{llm_call.force_use_tool.tool_name}' not found"
)
tool = get_tool_by_name(llm_call.tools, llm_call.force_use_tool.tool_name)
tool_args = (
llm_call.force_use_tool.args
@@ -118,20 +114,17 @@ class ToolResponseHandler:
tool for tool in self.tools if tool.name == tool_call_request["name"]
]
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
continue
else:
if known_tools_by_name:
selected_tool = known_tools_by_name[0]
selected_tool_call_request = tool_call_request
if selected_tool and selected_tool_call_request:
break
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
if not selected_tool or not selected_tool_call_request:
return
@@ -157,8 +150,8 @@ class ToolResponseHandler:
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
response_item: BaseMessage | str | None,
previous_response_items: list[BaseMessage | str],
) -> Generator[ResponsePart, None, None]:
if response_item is None:
yield from self._handle_tool_call()
@@ -171,8 +164,6 @@ class ToolResponseHandler:
else:
self.tool_call_chunk += response_item # type: ignore
return
def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None:
if (
self.tool_runner is None