feat: native tool calling support

Timothy Jaeryang Baek 2025-02-04 23:05:14 -08:00
parent 7766a08b70
commit 314b674f32
2 changed files with 185 additions and 6 deletions

backend/open_webui/utils/middleware.py

@@ -57,6 +57,7 @@ from open_webui.utils.task import (
     tools_function_calling_generation_template,
 )
 from open_webui.utils.misc import (
+    deep_update,
     get_message_list,
     add_or_update_system_message,
     add_or_update_user_message,
@@ -1126,8 +1127,18 @@ async def process_chat_response(
         for block in content_blocks:
             if block["type"] == "text":
                 content = f"{content}{block['content'].strip()}\n"
-            elif block["type"] == "tool":
-                pass
+            elif block["type"] == "tool_calls":
+                attributes = block.get("attributes", {})
+
+                block_content = block.get("content", [])
+                results = block.get("results", [])
+
+                if results:
+                    if not raw:
+                        content = f'{content}\n<details type="tool_calls" done="true" results="{html.escape(json.dumps(results))}">\n<summary>Tool Executed</summary>\n```json\n{block_content}\n```\n```json\n{results}\n```\n</details>\n'
+                else:
+                    if not raw:
+                        content = f'{content}\n<details type="tool_calls" done="false">\n<summary>Tool Executing...</summary>\n```json\n{block_content}\n```\n</details>\n'
             elif block["type"] == "reasoning":
                 reasoning_display_content = "\n".join(
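
Note: the serializer renders a tool_calls block as a collapsible <details> element, with the results HTML-escaped into a tag attribute and the raw payloads interpolated into ```json blocks. A minimal sketch of what the done="true" branch produces; the block/result values below are illustrative, not part of the commit:

    import html
    import json

    block_content = [
        {"id": "call_0", "function": {"name": "get_weather", "arguments": '{"city": "Berlin"}'}}
    ]
    results = [{"tool_call_id": "call_0", "content": "18°C, clear"}]

    # Mirrors the f-string above: results are escaped into the attribute,
    # and the payloads are interpolated (as Python reprs) into ```json blocks.
    detail = (
        f'<details type="tool_calls" done="true" '
        f'results="{html.escape(json.dumps(results))}">\n'
        f'<summary>Tool Executed</summary>\n'
        f'```json\n{block_content}\n```\n```json\n{results}\n```\n</details>\n'
    )
    print(detail)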
@@ -1254,6 +1265,7 @@ async def process_chat_response(
                 metadata["chat_id"], metadata["message_id"]
             )

+            tool_calls = []
             content = message.get("content", "") if message else ""
             content_blocks = [
                 {
@@ -1293,6 +1305,8 @@ async def process_chat_response(
                 nonlocal content
                 nonlocal content_blocks

+                response_tool_calls = []
+
                 async for line in response.body_iterator:
                     line = line.decode("utf-8") if isinstance(line, bytes) else line
                     data = line
@@ -1326,7 +1340,42 @@ async def process_chat_response(
                         if not choices:
                             continue

-                        value = choices[0].get("delta", {}).get("content")
+                        delta = choices[0].get("delta", {})
+                        delta_tool_calls = delta.get("tool_calls", None)
+
+                        if delta_tool_calls:
+                            for delta_tool_call in delta_tool_calls:
+                                tool_call_index = delta_tool_call.get("index")
+
+                                if tool_call_index is not None:
+                                    if (
+                                        len(response_tool_calls)
+                                        <= tool_call_index
+                                    ):
+                                        response_tool_calls.append(
+                                            delta_tool_call
+                                        )
+                                    else:
+                                        delta_name = delta_tool_call.get(
+                                            "function", {}
+                                        ).get("name")
+                                        delta_arguments = delta_tool_call.get(
+                                            "function", {}
+                                        ).get("arguments")
+
+                                        if delta_name:
+                                            response_tool_calls[
+                                                tool_call_index
+                                            ]["function"]["name"] += delta_name
+
+                                        if delta_arguments:
+                                            response_tool_calls[
+                                                tool_call_index
+                                            ]["function"][
+                                                "arguments"
+                                            ] += delta_arguments
+
+                        value = delta.get("content")

                         if value:
                             content = f"{content}{value}"
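
Note: this is the standard merge for OpenAI-style streamed tool calls: one call is split across many chunks, each delta.tool_calls entry carries an "index", the first fragment for an index seeds the entry, and later fragments append to function.name / function.arguments. A standalone sketch of the same accumulation; the helper name and sample chunks are illustrative:

    def merge_tool_call_deltas(deltas: list[dict]) -> list[dict]:
        # Accumulate streamed fragments into complete tool calls, keyed by index.
        calls: list[dict] = []
        for delta in deltas:
            index = delta.get("index")
            if index is None:
                continue
            if len(calls) <= index:
                # First fragment for this index seeds the call entry.
                calls.append(delta)
            else:
                function = delta.get("function", {})
                if function.get("name"):
                    calls[index]["function"]["name"] += function["name"]
                if function.get("arguments"):
                    calls[index]["function"]["arguments"] += function["arguments"]
        return calls

    # One call whose JSON arguments arrive split across two chunks.
    chunks = [
        {"index": 0, "id": "call_0", "function": {"name": "get_weather", "arguments": '{"ci'}},
        {"index": 0, "function": {"name": "", "arguments": 'ty": "Berlin"}'}},
    ]
    merged = merge_tool_call_deltas(chunks)
    assert merged[0]["function"]["arguments"] == '{"city": "Berlin"}'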
@@ -1398,6 +1447,29 @@ async def process_chat_response(
                     if not content_blocks[-1]["content"]:
                         content_blocks.pop()

+                if response_tool_calls:
+                    tool_calls.append(response_tool_calls)
+
+                if response.background:
+                    await response.background()
+
+            await stream_body_handler(response)
+
+            MAX_TOOL_CALL_RETRIES = 5
+            tool_call_retries = 0
+
+            while len(tool_calls) > 0 and tool_call_retries < MAX_TOOL_CALL_RETRIES:
+                tool_call_retries += 1
+
+                response_tool_calls = tool_calls.pop(0)
+
+                content_blocks.append(
+                    {
+                        "type": "tool_calls",
+                        "content": response_tool_calls,
+                    }
+                )
+
                 await event_emitter(
                     {
                         "type": "chat:completion",
@@ -1407,10 +1479,103 @@ async def process_chat_response(
                     }
                 )

-            if response.background:
-                await response.background()
+                tools = metadata.get("tools", {})

-            await stream_body_handler(response)
+                results = []
+                for tool_call in response_tool_calls:
+                    tool_call_id = tool_call.get("id", "")
+                    tool_name = tool_call.get("function", {}).get("name", "")
+
+                    tool_function_params = {}
+                    try:
+                        tool_function_params = json.loads(
+                            tool_call.get("function", {}).get("arguments", "{}")
+                        )
+                    except Exception as e:
+                        log.debug(e)
+
+                    tool_result = None
+                    if tool_name in tools:
+                        tool = tools[tool_name]
+                        spec = tool.get("spec", {})
+
+                        try:
+                            required_params = spec.get("parameters", {}).get(
+                                "required", []
+                            )
+                            tool_function = tool["callable"]
+                            tool_function_params = {
+                                k: v
+                                for k, v in tool_function_params.items()
+                                if k in required_params
+                            }
+                            tool_result = await tool_function(
+                                **tool_function_params
+                            )
+                        except Exception as e:
+                            tool_result = str(e)
+
+                    results.append(
+                        {
+                            "tool_call_id": tool_call_id,
+                            "content": tool_result,
+                        }
+                    )
+
+                content_blocks[-1]["results"] = results
+
+                content_blocks.append(
+                    {
+                        "type": "text",
+                        "content": "",
+                    }
+                )
+
+                await event_emitter(
+                    {
+                        "type": "chat:completion",
+                        "data": {
+                            "content": serialize_content_blocks(content_blocks),
+                        },
+                    }
+                )
+
+                try:
+                    res = await generate_chat_completion(
+                        request,
+                        {
+                            "model": model_id,
+                            "stream": True,
+                            "messages": [
+                                *form_data["messages"],
+                                {
+                                    "role": "assistant",
+                                    "content": serialize_content_blocks(
+                                        content_blocks, raw=True
+                                    ),
+                                    "tool_calls": response_tool_calls,
+                                },
+                                *[
+                                    {
+                                        "role": "tool",
+                                        "tool_call_id": result["tool_call_id"],
+                                        "content": result["content"],
+                                    }
+                                    for result in results
+                                ],
+                            ],
+                        },
+                        user,
+                    )
+
+                    if isinstance(res, StreamingResponse):
+                        await stream_body_handler(res)
+                    else:
+                        break
+                except Exception as e:
+                    log.debug(e)
+                    break

             if DETECT_CODE_INTERPRETER:
                 MAX_RETRIES = 5
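
Note: the retry loop closes the tool-calling round trip using the standard OpenAI message convention: the assistant turn that requested the calls is replayed with its tool_calls attached, followed by one role "tool" message per result, matched by tool_call_id, and the model is asked to continue. A sketch of the message list this builds; the user prompt and values are illustrative:

    response_tool_calls = [
        {
            "id": "call_0",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Berlin"}'},
        }
    ]
    results = [{"tool_call_id": "call_0", "content": "18°C, clear"}]

    messages = [
        {"role": "user", "content": "What is the weather in Berlin?"},
        # The assistant turn that issued the calls, replayed with tool_calls.
        {"role": "assistant", "content": "", "tool_calls": response_tool_calls},
        # One tool message per result, matched to its call by tool_call_id.
        *[
            {
                "role": "tool",
                "tool_call_id": result["tool_call_id"],
                "content": result["content"],
            }
            for result in results
        ],
    ]

If the follow-up response is itself a StreamingResponse, it is fed back through stream_body_handler, so a model that answers with further tool calls re-enters the loop, bounded by MAX_TOOL_CALL_RETRIES.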
@@ -1472,6 +1637,7 @@ async def process_chat_response(
                             output = str(e)

                         content_blocks[-1]["output"] = output
+
                         content_blocks.append(
                             {
                                 "type": "text",
backend/open_webui/utils/misc.py

@@ -7,6 +7,18 @@ from pathlib import Path
 from typing import Callable, Optional

+import collections.abc
+
+
+def deep_update(d, u):
+    for k, v in u.items():
+        if isinstance(v, collections.abc.Mapping):
+            d[k] = deep_update(d.get(k, {}), v)
+        else:
+            d[k] = v
+    return d
+
+
 def get_message_list(messages, message_id):
     """
     Reconstructs a list of messages in order up to the specified message_id.
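
Note: unlike dict.update, which replaces nested dictionaries wholesale, deep_update merges them recursively in place. A quick usage example, assuming the deep_update added above; the values are illustrative:

    base = {"params": {"temperature": 0.7, "stream": True}, "model": "llama3"}
    patch = {"params": {"temperature": 0.2}}

    # dict.update(base, patch) would drop "stream"; deep_update keeps it.
    deep_update(base, patch)
    assert base == {"params": {"temperature": 0.2, "stream": True}, "model": "llama3"}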
@@ -187,6 +199,7 @@ def openai_chat_chunk_message_template(
     template = openai_chat_message_template(model)
     template["object"] = "chat.completion.chunk"

     template["choices"][0]["index"] = 0
+    template["choices"][0]["delta"] = {}

     if content:
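
Note: pre-seeding an empty delta means every chunk exposes choices[0].delta even when it carries no text, which is exactly the case for native tool calling, where a chunk may hold only tool_calls. A sketch of the resulting chunk shape; the surrounding fields assume the OpenAI chat.completion.chunk format, and the tool_calls payload is illustrative:

    # Hypothetical minimal template, mirroring the OpenAI chunk shape.
    template = {
        "object": "chat.completion.chunk",
        "choices": [{"index": 0, "delta": {}, "finish_reason": None}],
    }

    # A content-less chunk can attach tool-call fragments to the
    # already-present delta instead of creating the key ad hoc.
    template["choices"][0]["delta"]["tool_calls"] = [
        {"index": 0, "id": "call_0", "function": {"name": "get_weather", "arguments": ""}}
    ]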