feat: native tool calling support
This commit is contained in:
parent 7766a08b70
commit 314b674f32
@@ -57,6 +57,7 @@ from open_webui.utils.task import (
    tools_function_calling_generation_template,
)
from open_webui.utils.misc import (
    deep_update,
    get_message_list,
    add_or_update_system_message,
    add_or_update_user_message,
@@ -1126,8 +1127,18 @@ async def process_chat_response(
        for block in content_blocks:
            if block["type"] == "text":
                content = f"{content}{block['content'].strip()}\n"
            elif block["type"] == "tool":
                pass
            elif block["type"] == "tool_calls":
                attributes = block.get("attributes", {})

                block_content = block.get("content", [])
                results = block.get("results", [])

                if results:
                    if not raw:
                        content = f'{content}\n<details type="tool_calls" done="true" results="{html.escape(json.dumps(results))}">\n<summary>Tool Executed</summary>\n```json\n{block_content}\n```\n```json\n{results}\n```\n</details>\n'
                else:
                    if not raw:
                        content = f'{content}\n<details type="tool_calls" done="false">\n<summary>Tool Executing...</summary>\n```json\n{block_content}\n```\n</details>\n'

            elif block["type"] == "reasoning":
                reasoning_display_content = "\n".join(
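
For illustration, a minimal, self-contained sketch of what the serializer above emits for a finished tool_calls block; the example block is hypothetical, while the <details> markup and the HTML-escaped results attribute follow the code in this hunk:

    import html
    import json

    # hypothetical block, shaped the way serialize_content_blocks expects
    block = {
        "type": "tool_calls",
        "content": [
            {
                "id": "call_0",
                "function": {
                    "name": "get_weather",
                    "arguments": '{"city": "Berlin"}',
                },
            }
        ],
        "results": [{"tool_call_id": "call_0", "content": "18C, cloudy"}],
    }

    results = block["results"]
    # done="true" plus the escaped results attribute is what lets the
    # frontend render a collapsed "Tool Executed" panel
    details = (
        f'<details type="tool_calls" done="true" '
        f'results="{html.escape(json.dumps(results))}">\n'
        f"<summary>Tool Executed</summary>\n"
        f"</details>"
    )
    print(details)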
@@ -1254,6 +1265,7 @@ async def process_chat_response(
            metadata["chat_id"], metadata["message_id"]
        )

        tool_calls = []
        content = message.get("content", "") if message else ""
        content_blocks = [
            {
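
For orientation, a sketch of what content_blocks can hold by the end of a tool round, using only the block types and keys that appear in this diff (the concrete values are hypothetical):

    content_blocks = [
        {"type": "text", "content": "Let me check the weather."},
        {
            "type": "tool_calls",
            # the accumulated response_tool_calls for this round
            "content": [
                {
                    "id": "call_0",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"city": "Berlin"}',
                    },
                }
            ],
            "results": [{"tool_call_id": "call_0", "content": "18C, cloudy"}],
        },
        {"type": "text", "content": "It is 18C and cloudy in Berlin."},
    ]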
@@ -1293,6 +1305,8 @@ async def process_chat_response(
            nonlocal content
            nonlocal content_blocks

            response_tool_calls = []

            async for line in response.body_iterator:
                line = line.decode("utf-8") if isinstance(line, bytes) else line
                data = line
@@ -1326,7 +1340,42 @@ async def process_chat_response(
                    if not choices:
                        continue

                    delta = choices[0].get("delta", {})
                    delta_tool_calls = delta.get("tool_calls", None)

                    if delta_tool_calls:
                        for delta_tool_call in delta_tool_calls:
                            tool_call_index = delta_tool_call.get("index")

                            if tool_call_index is not None:
                                # first fragment for an unseen index is kept
                                # whole; later fragments are concatenated
                                if len(response_tool_calls) <= tool_call_index:
                                    response_tool_calls.append(delta_tool_call)
                                else:
                                    delta_name = delta_tool_call.get(
                                        "function", {}
                                    ).get("name")
                                    delta_arguments = delta_tool_call.get(
                                        "function", {}
                                    ).get("arguments")

                                    if delta_name:
                                        response_tool_calls[tool_call_index][
                                            "function"
                                        ]["name"] += delta_name

                                    if delta_arguments:
                                        response_tool_calls[tool_call_index][
                                            "function"
                                        ]["arguments"] += delta_arguments

                    value = delta.get("content")

                    if value:
                        content = f"{content}{value}"
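
Taken on its own, the accumulation scheme above works like this: providers stream each tool call in fragments keyed by "index"; the first fragment for an index arrives with the call id and is appended whole, and every later fragment concatenates onto the accumulated "name" and "arguments" strings. A runnable sketch with hypothetical chunk payloads in the OpenAI streaming-delta shape:

    response_tool_calls = []

    chunks = [
        [{"index": 0, "id": "call_0", "function": {"name": "get_", "arguments": ""}}],
        [{"index": 0, "function": {"name": "weather", "arguments": '{"city": '}}],
        [{"index": 0, "function": {"arguments": '"Berlin"}'}}],
    ]

    for delta_tool_calls in chunks:
        for delta_tool_call in delta_tool_calls:
            i = delta_tool_call.get("index")
            if i is None:
                continue
            if len(response_tool_calls) <= i:
                # first fragment for this index: keep it whole
                response_tool_calls.append(delta_tool_call)
            else:
                fn = delta_tool_call.get("function", {})
                if fn.get("name"):
                    response_tool_calls[i]["function"]["name"] += fn["name"]
                if fn.get("arguments"):
                    response_tool_calls[i]["function"]["arguments"] += fn["arguments"]

    assert response_tool_calls[0]["function"]["name"] == "get_weather"
    assert response_tool_calls[0]["function"]["arguments"] == '{"city": "Berlin"}'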
@@ -1398,6 +1447,29 @@ async def process_chat_response(
            if not content_blocks[-1]["content"]:
                content_blocks.pop()

            if response_tool_calls:
                tool_calls.append(response_tool_calls)

            if response.background:
                await response.background()

        await stream_body_handler(response)

        MAX_TOOL_CALL_RETRIES = 5
        tool_call_retries = 0

        while len(tool_calls) > 0 and tool_call_retries < MAX_TOOL_CALL_RETRIES:
            tool_call_retries += 1

            response_tool_calls = tool_calls.pop(0)

            content_blocks.append(
                {
                    "type": "tool_calls",
                    "content": response_tool_calls,
                }
            )

            await event_emitter(
                {
                    "type": "chat:completion",
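
The new loop above is effectively a bounded work queue: stream_body_handler appends at most one batch of tool calls per streamed response, each iteration pops one batch, executes it, and streams a follow-up completion that may enqueue the next batch, and MAX_TOOL_CALL_RETRIES caps the rounds so a model that keeps requesting tools cannot loop forever. A standalone skeleton of that pattern, with hypothetical execute_batch and stream_followup stand-ins:

    MAX_TOOL_CALL_RETRIES = 5

    async def run_tool_call_rounds(tool_calls, execute_batch, stream_followup):
        rounds = 0
        while len(tool_calls) > 0 and rounds < MAX_TOOL_CALL_RETRIES:
            rounds += 1
            batch = tool_calls.pop(0)  # one batch per streamed response
            results = await execute_batch(batch)
            # streaming the follow-up may append a new batch to tool_calls
            await stream_followup(batch, results)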
@@ -1407,10 +1479,103 @@ async def process_chat_response(
                }
            )

            tools = metadata.get("tools", {})

            results = []
            for tool_call in response_tool_calls:
                tool_call_id = tool_call.get("id", "")
                tool_name = tool_call.get("function", {}).get("name", "")

                tool_function_params = {}
                try:
                    tool_function_params = json.loads(
                        tool_call.get("function", {}).get("arguments", "{}")
                    )
                except Exception as e:
                    log.debug(e)

                tool_result = None

                if tool_name in tools:
                    tool = tools[tool_name]
                    spec = tool.get("spec", {})

                    try:
                        required_params = spec.get("parameters", {}).get(
                            "required", []
                        )
                        tool_function = tool["callable"]
                        # keep only the parameters the spec marks as required
                        tool_function_params = {
                            k: v
                            for k, v in tool_function_params.items()
                            if k in required_params
                        }
                        tool_result = await tool_function(**tool_function_params)
                    except Exception as e:
                        tool_result = str(e)

                results.append(
                    {
                        "tool_call_id": tool_call_id,
                        "content": tool_result,
                    }
                )

            content_blocks[-1]["results"] = results

            content_blocks.append(
                {
                    "type": "text",
                    "content": "",
                }
            )

            await event_emitter(
                {
                    "type": "chat:completion",
                    "data": {
                        "content": serialize_content_blocks(content_blocks),
                    },
                }
            )

            # feed the tool results back to the model and stream the follow-up
            try:
                res = await generate_chat_completion(
                    request,
                    {
                        "model": model_id,
                        "stream": True,
                        "messages": [
                            *form_data["messages"],
                            {
                                "role": "assistant",
                                "content": serialize_content_blocks(
                                    content_blocks, raw=True
                                ),
                                "tool_calls": response_tool_calls,
                            },
                            *[
                                {
                                    "role": "tool",
                                    "tool_call_id": result["tool_call_id"],
                                    "content": result["content"],
                                }
                                for result in results
                            ],
                        ],
                    },
                    user,
                )

                if isinstance(res, StreamingResponse):
                    await stream_body_handler(res)
                else:
                    break
            except Exception as e:
                log.debug(e)
                break

        if DETECT_CODE_INTERPRETER:
            MAX_RETRIES = 5
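
Two details worth noting in the executor above: tool_function_params is filtered down to the spec's required parameters, so optional arguments the model supplies are dropped before the call, and a failing tool surfaces its exception string as the result rather than aborting the round. For clarity, here is a sketch of the follow-up messages array the generate_chat_completion call assembles; the values are hypothetical, but the shape is the standard OpenAI tool-calling convention of an assistant message carrying tool_calls followed by one role "tool" message per result:

    messages = [
        {"role": "user", "content": "What's the weather in Berlin?"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_0",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"city": "Berlin"}',
                    },
                }
            ],
        },
        # one tool message per executed call, linked back by tool_call_id
        {"role": "tool", "tool_call_id": "call_0", "content": "18C, cloudy"},
    ]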
@@ -1472,6 +1637,7 @@ async def process_chat_response(
                    output = str(e)

                content_blocks[-1]["output"] = output

                content_blocks.append(
                    {
                        "type": "text",
open_webui/utils/misc.py:

@@ -7,6 +7,18 @@ from pathlib import Path
from typing import Callable, Optional

import collections.abc


def deep_update(d, u):
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            d[k] = deep_update(d.get(k, {}), v)
        else:
            d[k] = v
    return d


def get_message_list(messages, message_id):
    """
    Reconstructs a list of messages in order up to the specified message_id.
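
A quick usage sketch for deep_update as defined above (the dicts are illustrative): it merges nested mappings key by key instead of replacing them wholesale, mutating d in place and returning it:

    base = {"params": {"temperature": 0.7, "stream": True}, "model": "m1"}
    patch = {"params": {"temperature": 0.2}}

    merged = deep_update(base, patch)
    # merged is base, updated in place:
    # {"params": {"temperature": 0.2, "stream": True}, "model": "m1"}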
@@ -187,6 +199,7 @@ def openai_chat_chunk_message_template(
    template = openai_chat_message_template(model)
    template["object"] = "chat.completion.chunk"

    template["choices"][0]["index"] = 0
    template["choices"][0]["delta"] = {}

    if content:
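
The template above follows the OpenAI chat.completion.chunk shape. A hypothetical example of a single chunk whose delta carries a tool-call fragment, the kind of payload the streaming accumulator in process_chat_response consumes (field values are illustrative):

    chunk = {
        "id": "chatcmpl-123",
        "object": "chat.completion.chunk",
        "model": "m1",
        "choices": [
            {
                "index": 0,
                "delta": {
                    "tool_calls": [
                        {"index": 0, "function": {"arguments": '{"city": '}}
                    ]
                },
                "finish_reason": None,
            }
        ],
    }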