From 4f8a0166ccc540346dd160796dacdaceac1fde73 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Wed, 23 Jul 2025 21:21:29 -0700 Subject: [PATCH] tools: loosen tool argument parsing (#11509) --- tools/tools.go | 125 +++++++++++----------------- tools/tools_test.go | 197 +++++++------------------------------------- 2 files changed, 78 insertions(+), 244 deletions(-) diff --git a/tools/tools.go b/tools/tools.go index c149885f6d..f473ab6a63 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -120,16 +120,14 @@ func (p *Parser) parseToolCall() *api.ToolCall { return nil } - // only look for arguments after the tool name if the tool has parameters - // TODO (jmorganca): while probably uncommon, this doesn't support - // parsing arguments before the tool name, which may be needed in the future - args := map[string]any{} - if len(tool.Function.Parameters.Properties) > 0 { - var i int - if args, i = findArguments(*tool, p.buffer[end:]); args == nil { - return nil + var args map[string]any + if found, i := findArguments(p.buffer); found == nil { + return nil + } else { + args = found + if i > end { + end = i } - end += i } tc := &api.ToolCall{ @@ -217,93 +215,70 @@ func findTool(tools []api.Tool, buf []byte) (*api.Tool, int) { // objects for functions that have all-optional parameters // e.g. `{"name": "get_conditions", "arguments": {}}` will work but // `{"name": "get_conditions"}` will not currently work -func findArguments(tool api.Tool, buffer []byte) (map[string]any, int) { +func findArguments(buffer []byte) (map[string]any, int) { if len(buffer) == 0 { return nil, 0 } var braces int var start int = -1 - var end int - var object []byte - // find any outer json object for i, c := range buffer { if c == '{' { - braces++ - if start == -1 { + if braces == 0 { start = i } - } + braces++ + } else if c == '}' && braces > 0 { + braces-- + if braces == 0 && start != -1 { + object := buffer[start : i+1] - if c == '}' { - if start != -1 { - braces-- - if braces == 0 { - end = i + 1 - object = buffer[start:end] - break + var data map[string]any + if err := json.Unmarshal(object, &data); err != nil { + start = -1 + continue } - } - } - } - if braces > 0 { - return nil, 0 - } - - var data map[string]any - if err := json.Unmarshal(object, &data); err != nil { - return nil, 0 - } - - var find func(obj any) map[string]any - find = func(obj any) map[string]any { - switch obj := obj.(type) { - case map[string]any: - valid := true - // check if all keys in the object exist in the tool's parameters - for key := range obj { - if _, exists := tool.Function.Parameters.Properties[key]; !exists { - valid = false - break - } - } - - // check for required parameters - // TODO (jmorganca): this should error instead of silently failing - if valid { - for _, required := range tool.Function.Parameters.Required { - if _, exists := obj[required]; !exists { - valid = false - break + var findObject func(obj map[string]any) (map[string]any, bool) + findObject = func(obj map[string]any) (map[string]any, bool) { + if _, hasName := obj["name"]; hasName { + if args, ok := obj["arguments"].(map[string]any); ok { + return args, true + } + if args, ok := obj["parameters"].(map[string]any); ok { + return args, true + } + return nil, true } - } - } - if valid { - return obj - } + for _, v := range obj { + switch child := v.(type) { + case map[string]any: + if result, found := findObject(child); found { + return result, true + } + case []any: + for _, item := range child { + if childObj, ok := item.(map[string]any); ok { + if result, found := findObject(childObj); found { + return result, true + } + } + } + } + } - for _, value := range obj { - if result := find(value); result != nil { - return result + return nil, false } - } - case []any: - for _, item := range obj { - if result := find(item); result != nil { - return result + + if args, found := findObject(data); found { + return args, i } + + return data, i } } - - return nil - } - - result := find(data) - if result != nil { - return result, end } return nil, 0 diff --git a/tools/tools_test.go b/tools/tools_test.go index 092ae32332..a0f7b6b00c 100644 --- a/tools/tools_test.go +++ b/tools/tools_test.go @@ -227,13 +227,6 @@ func TestParser(t *testing.T) { }, }, }, - { - name: "invalid arguments", - inputs: []string{`{"name": "get_conditions", "arguments": {"city": "San Francisco"}}`}, - content: "", - tmpl: qwen, - calls: nil, - }, { name: "empty args", inputs: []string{`{"name": "get_conditions", "arguments": {}}`}, @@ -249,13 +242,6 @@ func TestParser(t *testing.T) { }, }, }, - { - name: "missing required args", - inputs: []string{`{"name": "get_temperature", "arguments": {}}`}, - content: "", - tmpl: qwen, - calls: nil, - }, { name: "text before tool call", inputs: []string{`Let me check the weather. {"name": "get_temperature", "arguments": {"city": "New York"}}`}, @@ -273,21 +259,6 @@ func TestParser(t *testing.T) { }, }, }, - { - name: "qwen no args tool call", - inputs: []string{`Let me say hello to the user. I'll use the say_hello tool {"name": "say_hello"}`}, - content: "Let me say hello to the user. I'll use the say_hello tool ", - tmpl: qwen, - calls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Index: 0, - Name: "say_hello", - Arguments: api.ToolCallFunctionArguments{}, - }, - }, - }, - }, { name: "qwen no args with text", inputs: []string{"Let me say hello to the user. I'll use the say_hello tool. "}, @@ -521,52 +492,6 @@ func TestParser(t *testing.T) { content: "for { fmt.Println(\"hello\") }", tmpl: json, }, - { - name: "json no args tool call", - inputs: []string{ - "{\"name\": \"say_hello\"}", - }, - content: "", - tmpl: json, - calls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Index: 0, - Name: "say_hello", - Arguments: api.ToolCallFunctionArguments{}, - }, - }, - }, - }, - { - name: "json no args no tool call", - inputs: []string{ - "I'll use the say_hello tool to say hello to the user.", - }, - content: "I'll use the say_hello tool to say hello to the user.", - tmpl: json, - calls: nil, - }, - - // TODO (jmorganca): this is a false positive, we should - // not be parsing this as a tool call - { - name: "json no args false positive", - inputs: []string{ - `{say_hello!!!}`, - }, - content: "", - tmpl: json, - calls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Index: 0, - Name: "say_hello", - Arguments: api.ToolCallFunctionArguments{}, - }, - }, - }, - }, { name: "list multiple", inputs: []string{ @@ -684,26 +609,6 @@ func TestParser(t *testing.T) { tmpl: list, calls: nil, }, - { - name: "list with no arguments", - inputs: []string{ - "[", - "{", - "\"name\": \"say_hello\"", - "}", - }, - content: "", - tmpl: list, - calls: []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Index: 0, - Name: "say_hello", - Arguments: api.ToolCallFunctionArguments{}, - }, - }, - }, - }, { name: "tool name with collision", inputs: []string{ @@ -711,7 +616,7 @@ func TestParser(t *testing.T) { "{", "\"name\": \"say_hello", "_world\",", - "}", + "\"arguments\": {}}", "}", }, content: "", @@ -733,13 +638,13 @@ func TestParser(t *testing.T) { "{", "\"name\": \"say_hello", "_world\",", - "}", + "\"arguments\": {}}", "", "", "{", "\"name\": \"say_hello", "\",", - "}", + "\"arguments\": {}}", "", }, content: "", @@ -773,7 +678,7 @@ func TestParser(t *testing.T) { { name: "tool name with collision non streaming multiple", inputs: []string{ - `{"name": "say_hello"}{"name": "say_hello_world"}`, + `{"name": "say_hello", "arguments": {}}{"name": "say_hello_world", "arguments": {}}`, }, content: "", tmpl: qwen, @@ -797,7 +702,7 @@ func TestParser(t *testing.T) { { name: "tool name with collision non streaming shorter", inputs: []string{ - `{"name": "say_hello"}`, + `{"name": "say_hello", "arguments": {}}`, }, content: "", tmpl: qwen, @@ -814,7 +719,7 @@ func TestParser(t *testing.T) { { name: "tool name with collision non streaming longer", inputs: []string{ - `{"name": "say_hello_world"}`, + `{"name": "say_hello_world", "arguments": {}}`, }, content: "", tmpl: qwen, @@ -871,6 +776,26 @@ func TestParser(t *testing.T) { }, }, }, + { + name: "args before name", + inputs: []string{ + `{"arguments": {"a": "5", "b": "10"}, "name": "add"}`, + }, + content: "", + tmpl: qwen, + calls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Index: 0, + Name: "add", + Arguments: api.ToolCallFunctionArguments{ + "a": "5", + "b": "10", + }, + }, + }, + }, + }, } for _, tt := range tests { @@ -1167,75 +1092,25 @@ func TestFindTag(t *testing.T) { } func TestFindArguments(t *testing.T) { - tool := api.Tool{ - Type: "function", - Function: api.ToolFunction{ - Name: "get_temperature", - Description: "Retrieve the temperature for a given location", - Parameters: struct { - Type string `json:"type"` - Defs any `json:"$defs,omitempty"` - Items any `json:"items,omitempty"` - Required []string `json:"required"` - Properties map[string]struct { - Type api.PropertyType `json:"type"` - Items any `json:"items,omitempty"` - Description string `json:"description"` - Enum []any `json:"enum,omitempty"` - } `json:"properties"` - }{ - Type: "object", - Properties: map[string]struct { - Type api.PropertyType `json:"type"` - Items any `json:"items,omitempty"` - Description string `json:"description"` - Enum []any `json:"enum,omitempty"` - }{ - "format": { - Type: api.PropertyType{"string"}, - Description: "The format to return the temperature in", - Enum: []any{"fahrenheit", "celsius"}, - }, - "location": { - Type: api.PropertyType{"string"}, - Description: "The location to get the temperature for", - }, - }, - }, - }, - } - - tool2 := api.Tool{ - Type: "function", - Function: api.ToolFunction{ - Name: "say_hello", - Description: "Say hello to the user", - }, - } - tests := []struct { name string buffer []byte want map[string]any - tool api.Tool }{ { name: "empty string", buffer: []byte{}, want: nil, - tool: tool, }, { name: "whitespace only", buffer: []byte(" \n\t "), want: nil, - tool: tool, }, { name: "unbalanced braces - missing closing", buffer: []byte(`{"format": "fahrenheit", "location": "San Francisco"`), want: nil, - tool: tool, }, { name: "unbalanced braces - extra closing", @@ -1243,13 +1118,11 @@ func TestFindArguments(t *testing.T) { want: map[string]any{ "format": "fahrenheit", }, - tool: tool, }, { name: "invalid JSON", buffer: []byte(`{format: fahrenheit, location: "San Francisco"}`), want: nil, - tool: tool, }, { name: "valid json", @@ -1258,7 +1131,6 @@ func TestFindArguments(t *testing.T) { "format": "fahrenheit", "location": "San Francisco, CA", }, - tool: tool, }, { name: "valid arguments with special tokens", @@ -1267,16 +1139,14 @@ func TestFindArguments(t *testing.T) { "format": "fahrenheit", "location": "San Francisco, CA", }, - tool: tool, }, { name: "valid arguments in array", - buffer: []byte(`[{"arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}`), + buffer: []byte(`[{"name": "get_temperature", "arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}`), want: map[string]any{ "format": "fahrenheit", "location": "San Francisco, CA", }, - tool: tool, }, { name: "nested deep", @@ -1285,7 +1155,6 @@ func TestFindArguments(t *testing.T) { "format": "fahrenheit", "location": "San Francisco, CA", }, - tool: tool, }, { name: "one arg", @@ -1293,7 +1162,6 @@ func TestFindArguments(t *testing.T) { want: map[string]any{ "location": "San Francisco, CA", }, - tool: tool, }, { name: "two args", @@ -1302,13 +1170,6 @@ func TestFindArguments(t *testing.T) { "location": "San Francisco, CA", "format": "fahrenheit", }, - tool: tool, - }, - { - name: "no args", - buffer: []byte(`{"name": "say_hello"}`), - want: nil, - tool: tool2, }, { name: "deepseek", @@ -1316,7 +1177,6 @@ func TestFindArguments(t *testing.T) { want: map[string]any{ "location": "Tokyo", }, - tool: tool, }, { name: "deepseek", @@ -1324,13 +1184,12 @@ func TestFindArguments(t *testing.T) { want: map[string]any{ "location": "Tokyo", }, - tool: tool, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, _ := findArguments(tt.tool, tt.buffer) + got, _ := findArguments(tt.buffer) if diff := cmp.Diff(got, tt.want); diff != "" { t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)