tools: fix parsing tool calls without any parameters (#11101)

Fixes an issue where tool calls that don't expect any parameters were
not being parsed. This also fixes two additional issues: one where two
or more tool calls would not be parsed correctly, and one where tool
calls with invalid parameters would still be parsed.
This commit is contained in:
Jeffrey Morgan
2025-06-17 10:51:43 -07:00
committed by GitHub
parent 9e125d884c
commit 6bda1d2479
2 changed files with 297 additions and 46 deletions

View File

@@ -18,9 +18,8 @@ const (
) )
type Parser struct { type Parser struct {
tag string tag string
names []string tools []api.Tool
properties []string
state toolsState state toolsState
buffer []byte buffer []byte
@@ -34,15 +33,10 @@ func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
} }
func NewParserWithTag(tools []api.Tool, tag string) *Parser { func NewParserWithTag(tools []api.Tool, tag string) *Parser {
var p Parser return &Parser{
for _, t := range tools { tag: tag,
p.names = append(p.names, t.Function.Name) tools: tools,
for r := range t.Function.Parameters.Properties {
p.properties = append(p.properties, r)
}
} }
p.tag = tag
return &p
} }
// Add processes a string input to parse tool calls and content that // Add processes a string input to parse tool calls and content that
@@ -121,36 +115,40 @@ func (p *Parser) findTag() (int, bool) {
// parseToolCall finds the next complete tool call in the buffer // parseToolCall finds the next complete tool call in the buffer
// incrementing n and advancing the buffer. // incrementing n and advancing the buffer.
func (p *Parser) parseToolCall() *api.ToolCall { func (p *Parser) parseToolCall() *api.ToolCall {
var name string
var args map[string]any var args map[string]any
var tool *api.Tool
var end int = len(p.buffer) var end int = len(p.buffer)
// find tool name
var i int var i int
for _, n := range p.names { // find tool name
for _, t := range p.tools {
n := t.Function.Name
if i = bytes.Index(p.buffer, []byte(n)); i != -1 { if i = bytes.Index(p.buffer, []byte(n)); i != -1 {
if i+len(n) < end { if i+len(n) < end {
name = n tool = &t
end = i + len(n) end = i + len(n)
} }
} }
} }
if name == "" { if tool == nil {
return nil return nil
} }
if args, i = p.findArguments(); args == nil { // only look for arguments if the tool has parameters
return nil if len(tool.Function.Parameters.Properties) > 0 {
} if args, i = p.findArguments(*tool); args == nil {
return nil
}
if i > end { if i > end {
end = i end = i
}
} }
tc := &api.ToolCall{ tc := &api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: name, Name: tool.Function.Name,
Arguments: args, Arguments: args,
Index: p.n, Index: p.n,
}, },
@@ -162,13 +160,17 @@ func (p *Parser) parseToolCall() *api.ToolCall {
} }
// findArguments returns the first object that appears to be // findArguments returns the first object that appears to be
// arguments and the position where the arguments end, returning nil and 0 if // arguments for the provided tool, returning nil
// an invalid JSON object or non-arguments object is found first func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) {
func (p *Parser) findArguments() (map[string]any, int) {
if len(p.buffer) == 0 { if len(p.buffer) == 0 {
return nil, 0 return nil, 0
} }
// no arguments to parse
if len(tool.Function.Parameters.Properties) == 0 {
return nil, 0
}
var braces int var braces int
var start int = -1 var start int = -1
var end int var end int
@@ -184,11 +186,13 @@ func (p *Parser) findArguments() (map[string]any, int) {
} }
if c == '}' { if c == '}' {
braces-- if start != -1 {
if braces == 0 && start != -1 { braces--
end = i + 1 if braces == 0 {
object = p.buffer[start:end] end = i + 1
break object = p.buffer[start:end]
break
}
} }
} }
} }
@@ -206,24 +210,27 @@ func (p *Parser) findArguments() (map[string]any, int) {
var find func(obj any) map[string]any var find func(obj any) map[string]any
find = func(obj any) map[string]any { find = func(obj any) map[string]any {
switch v := obj.(type) { switch obj := obj.(type) {
case map[string]any: case map[string]any:
// check if the object keys are valid tool properties found := true
// TODO (jmorganca): check only sets of properties that for key := range obj {
// go together instead of the entire set if _, exists := tool.Function.Parameters.Properties[key]; !exists {
for _, prop := range p.properties { found = false
if _, exists := v[prop]; exists { break
return v
} }
} }
for _, value := range v { if found {
return obj
}
for _, value := range obj {
if result := find(value); result != nil { if result := find(value); result != nil {
return result return result
} }
} }
case []any: case []any:
for _, item := range v { for _, item := range obj {
if result := find(item); result != nil { if result := find(item); result != nil {
return result return result
} }

View File

@@ -104,6 +104,13 @@ func TestParser(t *testing.T) {
}, },
}, },
}, },
{
Type: "function",
Function: api.ToolFunction{
Name: "say_hello",
Description: "Say hello",
},
},
} }
tests := []struct { tests := []struct {
@@ -144,6 +151,20 @@ func TestParser(t *testing.T) {
}, },
}, },
}, },
{
name: "invalid arguments",
inputs: []string{`<tool_call>{"name": "get_conditions", "arguments": {"city": "San Francisco"}}</tool_call>`},
content: "",
tmpl: qwen,
calls: nil,
},
{
name: "missing args",
inputs: []string{`<tool_call>{"name": "get_conditions"}</tool_call>`},
content: "",
tmpl: qwen,
calls: nil,
},
{ {
name: "text before tool call", name: "text before tool call",
inputs: []string{`Let me check the weather. <tool_call>{"name": "get_temperature", "arguments": {"city": "New York"}}</tool_call>`}, inputs: []string{`Let me check the weather. <tool_call>{"name": "get_temperature", "arguments": {"city": "New York"}}</tool_call>`},
@@ -161,6 +182,27 @@ func TestParser(t *testing.T) {
}, },
}, },
}, },
{
name: "qwen no args tool call",
inputs: []string{`Let me say hello to the user. I'll use the say_hello tool <tool_call>{"name": "say_hello"}</tool_call>`},
content: "Let me say hello to the user. I'll use the say_hello tool ",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
},
},
},
},
{
name: "qwen no args with text",
inputs: []string{"Let me say hello to the user. I'll use the say_hello tool. "},
content: "Let me say hello to the user. I'll use the say_hello tool. ",
tmpl: qwen,
calls: nil,
},
{ {
name: "two tool calls in a list", name: "two tool calls in a list",
inputs: []string{`[TOOL_CALLS] [{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}, {"name": "get_conditions", "arguments": {"location": "Tokyo"}}][/TOOL_CALLS]`}, inputs: []string{`[TOOL_CALLS] [{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}, {"name": "get_conditions", "arguments": {"location": "Tokyo"}}][/TOOL_CALLS]`},
@@ -189,7 +231,7 @@ func TestParser(t *testing.T) {
}, },
}, },
{ {
name: "two tool calls", name: "qwen two tool calls",
inputs: []string{`Okay, let's call both tools! <tool_call>{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}</tool_call><tool_call>{"name": "get_conditions", "arguments": {"location": "Tokyo"}}</tool_call>`}, inputs: []string{`Okay, let's call both tools! <tool_call>{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}</tool_call><tool_call>{"name": "get_conditions", "arguments": {"location": "Tokyo"}}</tool_call>`},
content: "Okay, let's call both tools! ", content: "Okay, let's call both tools! ",
tmpl: qwen, tmpl: qwen,
@@ -215,6 +257,29 @@ func TestParser(t *testing.T) {
}, },
}, },
}, },
{
name: "qwen two tool calls one with no args",
inputs: []string{`Let me check the weather. <tool_call>{"name": "say_hello"}</tool_call><tool_call>{"name": "get_conditions", "arguments": {"location": "Tokyo"}}`},
content: "Let me check the weather. ",
tmpl: qwen,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
},
},
},
},
{ {
name: "deepseek", name: "deepseek",
inputs: []string{"<think>Wait, I need to call a tool</think><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"city\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"}, inputs: []string{"<think>Wait, I need to call a tool</think><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"city\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"},
@@ -338,6 +403,50 @@ func TestParser(t *testing.T) {
content: "for { fmt.Println(\"hello\") }", content: "for { fmt.Println(\"hello\") }",
tmpl: json, tmpl: json,
}, },
{
name: "json no args tool call",
inputs: []string{
"{\"name\": \"say_hello\"}",
},
content: "",
tmpl: json,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
},
},
},
},
{
name: "json no args no tool call",
inputs: []string{
"I'll use the say_hello tool to say hello to the user.",
},
content: "I'll use the say_hello tool to say hello to the user.",
tmpl: json,
calls: nil,
},
// TODO (jmorganca): this is a false positive, we should
// not be parsing this as a tool call
{
name: "json no args false positive",
inputs: []string{
`{say_hello!!!}`,
},
content: "",
tmpl: json,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
},
},
},
},
{ {
name: "list multiple", name: "list multiple",
inputs: []string{ inputs: []string{
@@ -380,6 +489,30 @@ func TestParser(t *testing.T) {
}, },
{ {
name: "list partial", name: "list partial",
inputs: []string{
"[{",
"\"name\": \"get_conditions\", ",
"\"arguments\": {",
"\"location\": \"Tokyo\"",
"}",
"}",
},
content: "",
tmpl: list,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
},
},
},
},
{
name: "list invalid",
inputs: []string{ inputs: []string{
"[", "[",
"{", "{",
@@ -393,6 +526,33 @@ func TestParser(t *testing.T) {
tmpl: list, tmpl: list,
calls: nil, calls: nil,
}, },
{
name: "list trailing ]",
inputs: []string{
"[",
"{",
"\"name\": \"get_conditions\", ",
"\"arguments\": {",
"\"location\": \"Tokyo\"",
"}",
"}",
"]",
"]",
},
content: "",
tmpl: list,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
},
},
},
},
{ {
name: "list not a tool call", name: "list not a tool call",
inputs: []string{ inputs: []string{
@@ -404,6 +564,25 @@ func TestParser(t *testing.T) {
tmpl: list, tmpl: list,
calls: nil, calls: nil,
}, },
{
name: "list with no arguments",
inputs: []string{
"[",
"{",
"\"name\": \"say_hello\"",
"}",
},
content: "",
tmpl: list,
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
},
},
},
},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -700,25 +879,75 @@ func TestFindTag(t *testing.T) {
} }
func TestFindArguments(t *testing.T) { func TestFindArguments(t *testing.T) {
tool := api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: "get_temperature",
Description: "Retrieve the temperature for a given location",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Properties: map[string]struct {
Type api.PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
}{
"format": {
Type: api.PropertyType{"string"},
Description: "The format to return the temperature in",
Enum: []any{"fahrenheit", "celsius"},
},
"location": {
Type: api.PropertyType{"string"},
Description: "The location to get the temperature for",
},
},
},
},
}
tool2 := api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: "say_hello",
Description: "Say hello to the user",
},
}
tests := []struct { tests := []struct {
name string name string
buffer []byte buffer []byte
want map[string]any want map[string]any
tool api.Tool
}{ }{
{ {
name: "empty string", name: "empty string",
buffer: []byte{}, buffer: []byte{},
want: nil, want: nil,
tool: tool,
}, },
{ {
name: "whitespace only", name: "whitespace only",
buffer: []byte(" \n\t "), buffer: []byte(" \n\t "),
want: nil, want: nil,
tool: tool,
}, },
{ {
name: "unbalanced braces - missing closing", name: "unbalanced braces - missing closing",
buffer: []byte(`{"format": "fahrenheit", "location": "San Francisco"`), buffer: []byte(`{"format": "fahrenheit", "location": "San Francisco"`),
want: nil, want: nil,
tool: tool,
}, },
{ {
name: "unbalanced braces - extra closing", name: "unbalanced braces - extra closing",
@@ -726,11 +955,13 @@ func TestFindArguments(t *testing.T) {
want: map[string]any{ want: map[string]any{
"format": "fahrenheit", "format": "fahrenheit",
}, },
tool: tool,
}, },
{ {
name: "invalid JSON", name: "invalid JSON",
buffer: []byte(`{format: fahrenheit, location: "San Francisco"}`), buffer: []byte(`{format: fahrenheit, location: "San Francisco"}`),
want: nil, want: nil,
tool: tool,
}, },
{ {
name: "valid json", name: "valid json",
@@ -739,6 +970,7 @@ func TestFindArguments(t *testing.T) {
"format": "fahrenheit", "format": "fahrenheit",
"location": "San Francisco, CA", "location": "San Francisco, CA",
}, },
tool: tool,
}, },
{ {
name: "valid arguments with special tokens", name: "valid arguments with special tokens",
@@ -747,6 +979,7 @@ func TestFindArguments(t *testing.T) {
"format": "fahrenheit", "format": "fahrenheit",
"location": "San Francisco, CA", "location": "San Francisco, CA",
}, },
tool: tool,
}, },
{ {
name: "valid arguments in array", name: "valid arguments in array",
@@ -755,6 +988,7 @@ func TestFindArguments(t *testing.T) {
"format": "fahrenheit", "format": "fahrenheit",
"location": "San Francisco, CA", "location": "San Francisco, CA",
}, },
tool: tool,
}, },
{ {
name: "nested deep", name: "nested deep",
@@ -763,39 +997,49 @@ func TestFindArguments(t *testing.T) {
"format": "fahrenheit", "format": "fahrenheit",
"location": "San Francisco, CA", "location": "San Francisco, CA",
}, },
tool: tool,
}, },
{ {
name: "one arg", name: "one arg",
buffer: []byte(`get_weather({"location": "San Francisco, CA"})`), buffer: []byte(`get_temperature({"location": "San Francisco, CA"})`),
want: map[string]any{ want: map[string]any{
"location": "San Francisco, CA", "location": "San Francisco, CA",
}, },
tool: tool,
}, },
{ {
name: "two args", name: "two args",
buffer: []byte(`[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}]`), buffer: []byte(`[{"name": "get_temperature", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}]`),
want: map[string]any{ want: map[string]any{
"location": "San Francisco, CA", "location": "San Francisco, CA",
"format": "fahrenheit", "format": "fahrenheit",
}, },
tool: tool,
},
{
name: "no args",
buffer: []byte(`{"name": "say_hello"}`),
want: nil,
tool: tool2,
}, },
{ {
name: "deepseek", name: "deepseek",
buffer: []byte("<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"), buffer: []byte("<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"),
want: map[string]any{ want: map[string]any{
"location": "Tokyo", "location": "Tokyo",
}, },
tool: tool,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
parser := &Parser{ parser := &Parser{
buffer: tt.buffer, buffer: tt.buffer,
properties: []string{"format", "location"}, tools: []api.Tool{tool, tool2},
} }
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, _ := parser.findArguments() got, _ := parser.findArguments(tool)
if diff := cmp.Diff(got, tt.want); diff != "" { if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff) t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)