From 44b17d2bfa0073e012679152421c0b69671d380e Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Mon, 30 Jun 2025 08:59:03 -0700 Subject: [PATCH] tools: fix parsing tool calls with empty arguments, missing required fields (#11233) --- tools/tools.go | 50 ++++++++++++++++++++------------- tools/tools_test.go | 68 +++++++++++++++++++++++++++++++++++++-------- 2 files changed, 86 insertions(+), 32 deletions(-) diff --git a/tools/tools.go b/tools/tools.go index 8a983e19f..f883bf284 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -134,16 +134,16 @@ func (p *Parser) parseToolCall() *api.ToolCall { return nil } - // only look for arguments if the tool has parameters + // only look for arguments after the tool name if the tool has parameters + // TODO (jmorganca): while probably uncommon, this doesn't support + // parsing arguments before the tool name, which may be needed in the future args := map[string]any{} if len(tool.Function.Parameters.Properties) > 0 { - if args, i = p.findArguments(*tool); args == nil { + if args, i = findArguments(*tool, p.buffer[end:]); args == nil { return nil } - if i > end { - end = i - } + end += i } tc := &api.ToolCall{ @@ -160,14 +160,14 @@ func (p *Parser) parseToolCall() *api.ToolCall { } // findArguments returns the first object that appears to be -// arguments for the provided tool, returning nil -func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) { - if len(p.buffer) == 0 { - return nil, 0 - } - - // no arguments to parse - if len(tool.Function.Parameters.Properties) == 0 { +// arguments for the provided tool in the provided buffer, +// returning nil if no arguments are found. +// TODO (jmorganca): this does not support parsing omitted arguments +// objects for functions that have all-optional parameters +// e.g. `{"name": "get_conditions", "arguments": {}}` will work but +// `{"name": "get_conditions"}` will not currently work +func findArguments(tool api.Tool, buffer []byte) (map[string]any, int) { + if len(buffer) == 0 { return nil, 0 } @@ -177,7 +177,7 @@ func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) { var object []byte // find any outer json object - for i, c := range p.buffer { + for i, c := range buffer { if c == '{' { braces++ if start == -1 { @@ -190,7 +190,7 @@ func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) { braces-- if braces == 0 { end = i + 1 - object = p.buffer[start:end] + object = buffer[start:end] break } } @@ -202,8 +202,6 @@ func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) { } var data map[string]any - - // not valid json if err := json.Unmarshal(object, &data); err != nil { return nil, 0 } @@ -212,15 +210,27 @@ func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) { find = func(obj any) map[string]any { switch obj := obj.(type) { case map[string]any: - found := true + valid := true + // check if all keys in the object exist in the tool's parameters for key := range obj { if _, exists := tool.Function.Parameters.Properties[key]; !exists { - found = false + valid = false break } } - if found { + // check for required parameters + // TODO (jmorganca): this should error instead of silently failing + if valid { + for _, required := range tool.Function.Parameters.Required { + if _, exists := obj[required]; !exists { + valid = false + break + } + } + } + + if valid { return obj } diff --git a/tools/tools_test.go b/tools/tools_test.go index 35b583438..8418ab6c3 100644 --- a/tools/tools_test.go +++ b/tools/tools_test.go @@ -52,7 +52,8 @@ func TestParser(t *testing.T) { Enum []any `json:"enum,omitempty"` } `json:"properties"` }{ - Type: "object", + Type: "object", + Required: []string{"city"}, Properties: map[string]struct { Type api.PropertyType `json:"type"` Items any `json:"items,omitempty"` @@ -159,8 +160,23 @@ func TestParser(t *testing.T) { calls: nil, }, { - name: "missing args", - inputs: []string{`{"name": "get_conditions"}`}, + name: "empty args", + inputs: []string{`{"name": "get_conditions", "arguments": {}}`}, + content: "", + tmpl: qwen, + calls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Index: 0, + Name: "get_conditions", + Arguments: api.ToolCallFunctionArguments{}, + }, + }, + }, + }, + { + name: "missing required args", + inputs: []string{`{"name": "get_temperature", "arguments": {}}`}, content: "", tmpl: qwen, calls: nil, @@ -259,9 +275,9 @@ func TestParser(t *testing.T) { }, }, { - name: "qwen two tool calls one with no args", - inputs: []string{`Let me check the weather. {"name": "say_hello"}{"name": "get_conditions", "arguments": {"location": "Tokyo"}}`}, - content: "Let me check the weather. ", + name: "empty args followed by args", + inputs: []string{`Let me say hello and check the weather. {"name": "say_hello", "arguments": {}}{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}`}, + content: "Let me say hello and check the weather. ", tmpl: qwen, calls: []api.ToolCall{ { @@ -271,6 +287,31 @@ func TestParser(t *testing.T) { Arguments: api.ToolCallFunctionArguments{}, }, }, + { + Function: api.ToolCallFunction{ + Index: 1, + Name: "get_temperature", + Arguments: api.ToolCallFunctionArguments{ + "city": "London", + "format": "fahrenheit", + }, + }, + }, + }, + }, + { + name: "qwen empty followed by args", + inputs: []string{`Let me check the weather. {"name": "get_conditions", "arguments": {}}{"name": "get_conditions", "arguments": {"location": "Tokyo"}}`}, + content: "Let me check the weather. ", + tmpl: qwen, + calls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Index: 0, + Name: "get_conditions", + Arguments: api.ToolCallFunctionArguments{}, + }, + }, { Function: api.ToolCallFunction{ Index: 1, @@ -1035,16 +1076,19 @@ func TestFindArguments(t *testing.T) { }, tool: tool, }, + { + name: "deepseek", + buffer: []byte(`", "arguments": {"location": "Tokyo"}}`), + want: map[string]any{ + "location": "Tokyo", + }, + tool: tool, + }, } for _, tt := range tests { - parser := &Parser{ - buffer: tt.buffer, - tools: []api.Tool{tool, tool2}, - } - t.Run(tt.name, func(t *testing.T) { - got, _ := parser.findArguments(tool) + got, _ := findArguments(tt.tool, tt.buffer) if diff := cmp.Diff(got, tt.want); diff != "" { t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)