diff --git a/tools/tools.go b/tools/tools.go index f9a2d3b9b6..7b8d726b0e 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -125,7 +125,7 @@ func (p *Parser) parseToolCall() *api.ToolCall { } var args map[string]any - if found, i := findArguments(p.buffer); found == nil { + if found, i := findArguments(tool, p.buffer); found == nil { return nil } else { args = found @@ -219,7 +219,7 @@ func findTool(tools []api.Tool, buf []byte) (*api.Tool, int) { // objects for functions that have all-optional parameters // e.g. `{"name": "get_conditions", "arguments": {}}` will work but // `{"name": "get_conditions"}` will not currently work -func findArguments(buffer []byte) (map[string]any, int) { +func findArguments(tool *api.Tool, buffer []byte) (map[string]any, int) { if len(buffer) == 0 { return nil, 0 } @@ -269,27 +269,30 @@ func findArguments(buffer []byte) (map[string]any, int) { var findObject func(obj map[string]any) (map[string]any, bool) findObject = func(obj map[string]any) (map[string]any, bool) { + findMap := func(name string, obj map[string]any) (map[string]any, bool) { + if args, ok := obj[name].(map[string]any); ok { + return args, true + } + if argsStr, ok := obj[name].(string); ok { + var argsData map[string]interface{} + if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil { + return argsData, ok + } + } + return nil, false + } if _, hasName := obj["name"]; hasName { - if args, ok := obj["arguments"].(map[string]any); ok { + if args, ok := findMap("arguments", obj); ok { return args, true } - if argsStr, ok := obj["arguments"].(string); ok { - var argsData map[string]interface{} - if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil { - return argsData, ok - } - } - if args, ok := obj["parameters"].(map[string]any); ok { + if args, ok := findMap("parameters", obj); ok { return args, true } - if argsStr, ok := obj["parameters"].(string); ok { - var argsData map[string]interface{} - if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil { - return argsData, ok - } - } return nil, true } + if args, ok := findMap(tool.Function.Name, obj); ok { + return args, true + } for _, v := range obj { switch child := v.(type) { diff --git a/tools/tools_test.go b/tools/tools_test.go index 288fa73c55..b849e21944 100644 --- a/tools/tools_test.go +++ b/tools/tools_test.go @@ -1033,6 +1033,7 @@ func TestFindArguments(t *testing.T) { name string buffer []byte want map[string]any + tool string }{ { name: "empty string", @@ -1290,11 +1291,29 @@ func TestFindArguments(t *testing.T) { "location": "San Francisco, CA", }, }, + { + name: "simple tool call", + tool: "get_temperature", + buffer: []byte(`{"get_temperature": {"format": "fahrenheit", "location": "San Francisco, CA"}}`), + want: map[string]any{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, + { + name: "stringified simple tool call", + tool: "get_temperature", + buffer: []byte(`{"get_temperature": "{\"format\": \"fahrenheit\", \"location\": \"San Francisco, CA\"}"}`), + want: map[string]any{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, _ := findArguments(tt.buffer) + got, _ := findArguments(&api.Tool{Function: api.ToolFunction{Name: tt.tool}}, tt.buffer) if diff := cmp.Diff(got, tt.want); diff != "" { t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)