diff --git a/tools/tools.go b/tools/tools.go index efeaeee0cf..a86163baa1 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -18,9 +18,8 @@ const ( ) type Parser struct { - tag string - names []string - properties []string + tag string + tools []api.Tool state toolsState buffer []byte @@ -34,15 +33,10 @@ func NewParser(tmpl *template.Template, tools []api.Tool) *Parser { } func NewParserWithTag(tools []api.Tool, tag string) *Parser { - var p Parser - for _, t := range tools { - p.names = append(p.names, t.Function.Name) - for r := range t.Function.Parameters.Properties { - p.properties = append(p.properties, r) - } + return &Parser{ + tag: tag, + tools: tools, } - p.tag = tag - return &p } // Add processes a string input to parse tool calls and content that @@ -121,36 +115,40 @@ func (p *Parser) findTag() (int, bool) { // parseToolCall finds the next complete tool call in the buffer // incrementing n and advancing the buffer. func (p *Parser) parseToolCall() *api.ToolCall { - var name string var args map[string]any + var tool *api.Tool var end int = len(p.buffer) - // find tool name var i int - for _, n := range p.names { + // find tool name + for _, t := range p.tools { + n := t.Function.Name if i = bytes.Index(p.buffer, []byte(n)); i != -1 { if i+len(n) < end { - name = n + tool = &t end = i + len(n) } } } - if name == "" { + if tool == nil { return nil } - if args, i = p.findArguments(); args == nil { - return nil - } + // only look for arguments if the tool has parameters + if len(tool.Function.Parameters.Properties) > 0 { + if args, i = p.findArguments(*tool); args == nil { + return nil + } - if i > end { - end = i + if i > end { + end = i + } } tc := &api.ToolCall{ Function: api.ToolCallFunction{ - Name: name, + Name: tool.Function.Name, Arguments: args, Index: p.n, }, @@ -162,13 +160,17 @@ func (p *Parser) parseToolCall() *api.ToolCall { } // findArguments returns the first object that appears to be -// arguments and the position where the arguments end, returning nil and 0 if -// an invalid JSON object or non-arguments object is found first -func (p *Parser) findArguments() (map[string]any, int) { +// arguments for the provided tool, returning nil +func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) { if len(p.buffer) == 0 { return nil, 0 } + // no arguments to parse + if len(tool.Function.Parameters.Properties) == 0 { + return nil, 0 + } + var braces int var start int = -1 var end int @@ -184,11 +186,13 @@ func (p *Parser) findArguments() (map[string]any, int) { } if c == '}' { - braces-- - if braces == 0 && start != -1 { - end = i + 1 - object = p.buffer[start:end] - break + if start != -1 { + braces-- + if braces == 0 { + end = i + 1 + object = p.buffer[start:end] + break + } } } } @@ -206,24 +210,27 @@ func (p *Parser) findArguments() (map[string]any, int) { var find func(obj any) map[string]any find = func(obj any) map[string]any { - switch v := obj.(type) { + switch obj := obj.(type) { case map[string]any: - // check if the object keys are valid tool properties - // TODO (jmorganca): check only sets of properties that - // go together instead of the entire set - for _, prop := range p.properties { - if _, exists := v[prop]; exists { - return v + found := true + for key := range obj { + if _, exists := tool.Function.Parameters.Properties[key]; !exists { + found = false + break } } - for _, value := range v { + if found { + return obj + } + + for _, value := range obj { if result := find(value); result != nil { return result } } case []any: - for _, item := range v { + for _, item := range obj { if result := find(item); result != nil { return result } diff --git a/tools/tools_test.go b/tools/tools_test.go index 678641684b..ebf4ad7dce 100644 --- a/tools/tools_test.go +++ b/tools/tools_test.go @@ -104,6 +104,13 @@ func TestParser(t *testing.T) { }, }, }, + { + Type: "function", + Function: api.ToolFunction{ + Name: "say_hello", + Description: "Say hello", + }, + }, } tests := []struct { @@ -144,6 +151,20 @@ func TestParser(t *testing.T) { }, }, }, + { + name: "invalid arguments", + inputs: []string{`{"name": "get_conditions", "arguments": {"city": "San Francisco"}}`}, + content: "", + tmpl: qwen, + calls: nil, + }, + { + name: "missing args", + inputs: []string{`{"name": "get_conditions"}`}, + content: "", + tmpl: qwen, + calls: nil, + }, { name: "text before tool call", inputs: []string{`Let me check the weather. {"name": "get_temperature", "arguments": {"city": "New York"}}`}, @@ -161,6 +182,27 @@ func TestParser(t *testing.T) { }, }, }, + { + name: "qwen no args tool call", + inputs: []string{`Let me say hello to the user. I'll use the say_hello tool {"name": "say_hello"}`}, + content: "Let me say hello to the user. I'll use the say_hello tool ", + tmpl: qwen, + calls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Index: 0, + Name: "say_hello", + }, + }, + }, + }, + { + name: "qwen no args with text", + inputs: []string{"Let me say hello to the user. I'll use the say_hello tool. "}, + content: "Let me say hello to the user. I'll use the say_hello tool. ", + tmpl: qwen, + calls: nil, + }, { name: "two tool calls in a list", inputs: []string{`[TOOL_CALLS] [{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}, {"name": "get_conditions", "arguments": {"location": "Tokyo"}}][/TOOL_CALLS]`}, @@ -189,7 +231,7 @@ func TestParser(t *testing.T) { }, }, { - name: "two tool calls", + name: "qwen two tool calls", inputs: []string{`Okay, let's call both tools! {"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}{"name": "get_conditions", "arguments": {"location": "Tokyo"}}`}, content: "Okay, let's call both tools! ", tmpl: qwen, @@ -215,6 +257,29 @@ func TestParser(t *testing.T) { }, }, }, + { + name: "qwen two tool calls one with no args", + inputs: []string{`Let me check the weather. {"name": "say_hello"}{"name": "get_conditions", "arguments": {"location": "Tokyo"}}`}, + content: "Let me check the weather. ", + tmpl: qwen, + calls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Index: 0, + Name: "say_hello", + }, + }, + { + Function: api.ToolCallFunction{ + Index: 1, + Name: "get_conditions", + Arguments: api.ToolCallFunctionArguments{ + "location": "Tokyo", + }, + }, + }, + }, + }, { name: "deepseek", inputs: []string{"Wait, I need to call a tool<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"city\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"}, @@ -338,6 +403,50 @@ func TestParser(t *testing.T) { content: "for { fmt.Println(\"hello\") }", tmpl: json, }, + { + name: "json no args tool call", + inputs: []string{ + "{\"name\": \"say_hello\"}", + }, + content: "", + tmpl: json, + calls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Index: 0, + Name: "say_hello", + }, + }, + }, + }, + { + name: "json no args no tool call", + inputs: []string{ + "I'll use the say_hello tool to say hello to the user.", + }, + content: "I'll use the say_hello tool to say hello to the user.", + tmpl: json, + calls: nil, + }, + + // TODO (jmorganca): this is a false positive, we should + // not be parsing this as a tool call + { + name: "json no args false positive", + inputs: []string{ + `{say_hello!!!}`, + }, + content: "", + tmpl: json, + calls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Index: 0, + Name: "say_hello", + }, + }, + }, + }, { name: "list multiple", inputs: []string{ @@ -380,6 +489,30 @@ func TestParser(t *testing.T) { }, { name: "list partial", + inputs: []string{ + "[{", + "\"name\": \"get_conditions\", ", + "\"arguments\": {", + "\"location\": \"Tokyo\"", + "}", + "}", + }, + content: "", + tmpl: list, + calls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Index: 0, + Name: "get_conditions", + Arguments: api.ToolCallFunctionArguments{ + "location": "Tokyo", + }, + }, + }, + }, + }, + { + name: "list invalid", inputs: []string{ "[", "{", @@ -393,6 +526,33 @@ func TestParser(t *testing.T) { tmpl: list, calls: nil, }, + { + name: "list trailing ]", + inputs: []string{ + "[", + "{", + "\"name\": \"get_conditions\", ", + "\"arguments\": {", + "\"location\": \"Tokyo\"", + "}", + "}", + "]", + "]", + }, + content: "", + tmpl: list, + calls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Index: 0, + Name: "get_conditions", + Arguments: api.ToolCallFunctionArguments{ + "location": "Tokyo", + }, + }, + }, + }, + }, { name: "list not a tool call", inputs: []string{ @@ -404,6 +564,25 @@ func TestParser(t *testing.T) { tmpl: list, calls: nil, }, + { + name: "list with no arguments", + inputs: []string{ + "[", + "{", + "\"name\": \"say_hello\"", + "}", + }, + content: "", + tmpl: list, + calls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Index: 0, + Name: "say_hello", + }, + }, + }, + }, } for _, tt := range tests { @@ -700,25 +879,75 @@ func TestFindTag(t *testing.T) { } func TestFindArguments(t *testing.T) { + tool := api.Tool{ + Type: "function", + Function: api.ToolFunction{ + Name: "get_temperature", + Description: "Retrieve the temperature for a given location", + Parameters: struct { + Type string `json:"type"` + Defs any `json:"$defs,omitempty"` + Items any `json:"items,omitempty"` + Required []string `json:"required"` + Properties map[string]struct { + Type api.PropertyType `json:"type"` + Items any `json:"items,omitempty"` + Description string `json:"description"` + Enum []any `json:"enum,omitempty"` + } `json:"properties"` + }{ + Type: "object", + Properties: map[string]struct { + Type api.PropertyType `json:"type"` + Items any `json:"items,omitempty"` + Description string `json:"description"` + Enum []any `json:"enum,omitempty"` + }{ + "format": { + Type: api.PropertyType{"string"}, + Description: "The format to return the temperature in", + Enum: []any{"fahrenheit", "celsius"}, + }, + "location": { + Type: api.PropertyType{"string"}, + Description: "The location to get the temperature for", + }, + }, + }, + }, + } + + tool2 := api.Tool{ + Type: "function", + Function: api.ToolFunction{ + Name: "say_hello", + Description: "Say hello to the user", + }, + } + tests := []struct { name string buffer []byte want map[string]any + tool api.Tool }{ { name: "empty string", buffer: []byte{}, want: nil, + tool: tool, }, { name: "whitespace only", buffer: []byte(" \n\t "), want: nil, + tool: tool, }, { name: "unbalanced braces - missing closing", buffer: []byte(`{"format": "fahrenheit", "location": "San Francisco"`), want: nil, + tool: tool, }, { name: "unbalanced braces - extra closing", @@ -726,11 +955,13 @@ func TestFindArguments(t *testing.T) { want: map[string]any{ "format": "fahrenheit", }, + tool: tool, }, { name: "invalid JSON", buffer: []byte(`{format: fahrenheit, location: "San Francisco"}`), want: nil, + tool: tool, }, { name: "valid json", @@ -739,6 +970,7 @@ func TestFindArguments(t *testing.T) { "format": "fahrenheit", "location": "San Francisco, CA", }, + tool: tool, }, { name: "valid arguments with special tokens", @@ -747,6 +979,7 @@ func TestFindArguments(t *testing.T) { "format": "fahrenheit", "location": "San Francisco, CA", }, + tool: tool, }, { name: "valid arguments in array", @@ -755,6 +988,7 @@ func TestFindArguments(t *testing.T) { "format": "fahrenheit", "location": "San Francisco, CA", }, + tool: tool, }, { name: "nested deep", @@ -763,39 +997,49 @@ func TestFindArguments(t *testing.T) { "format": "fahrenheit", "location": "San Francisco, CA", }, + tool: tool, }, { name: "one arg", - buffer: []byte(`get_weather({"location": "San Francisco, CA"})`), + buffer: []byte(`get_temperature({"location": "San Francisco, CA"})`), want: map[string]any{ "location": "San Francisco, CA", }, + tool: tool, }, { name: "two args", - buffer: []byte(`[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}]`), + buffer: []byte(`[{"name": "get_temperature", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}]`), want: map[string]any{ "location": "San Francisco, CA", "format": "fahrenheit", }, + tool: tool, + }, + { + name: "no args", + buffer: []byte(`{"name": "say_hello"}`), + want: nil, + tool: tool2, }, { name: "deepseek", - buffer: []byte("<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"), + buffer: []byte("<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"), want: map[string]any{ "location": "Tokyo", }, + tool: tool, }, } for _, tt := range tests { parser := &Parser{ - buffer: tt.buffer, - properties: []string{"format", "location"}, + buffer: tt.buffer, + tools: []api.Tool{tool, tool2}, } t.Run(tt.name, func(t *testing.T) { - got, _ := parser.findArguments() + got, _ := parser.findArguments(tool) if diff := cmp.Diff(got, tt.want); diff != "" { t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)