diff --git a/tools/tools.go b/tools/tools.go index 529bd3be3..914a5eaf0 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -17,15 +17,14 @@ var ( ) type Parser struct { - parseLeadingJSON bool - prefix string - prefixFound bool - tmpl gotmpl.Template - sb strings.Builder - index int - name string - arguments string - done bool + greedyParseJSON bool + prefix string + prefixFound bool + tmpl gotmpl.Template + sb strings.Builder + index int + name string + arguments string } // parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. @@ -176,14 +175,6 @@ func (p *Parser) checkPrefix(s string) (string, error) { // - tools: Any parsed tool calls // - content: Non-tool call content func (p *Parser) Add(s string) (tools []api.ToolCall, content string) { - if p.done { - if p.index == 0 { - // Return original string if no tool calls found at start - return nil, s - } - // Return empty if no tool calls found after start - return nil, "" - } p.sb.WriteString(s) s = p.sb.String() @@ -195,7 +186,7 @@ func (p *Parser) Add(s string) (tools []api.ToolCall, content string) { } // Exit if prefix exists in template, greedy parsing is off, and prefix not found - if !p.parseLeadingJSON && !p.prefixFound { + if !p.greedyParseJSON && !p.prefixFound { p.sb.Reset() return nil, s } @@ -206,10 +197,9 @@ func (p *Parser) Add(s string) (tools []api.ToolCall, content string) { return nil, "" } p.sb.Reset() - // Do not try parsing leading JSON if JSON not found - p.parseLeadingJSON = false - if p.prefix == "" { - p.done = true + // Only do greedy JSON parsing if there is no prefix from template + if p.prefix != "" { + p.greedyParseJSON = false } if p.index != 0 && p.prefix == "" { return nil, "" @@ -253,11 +243,11 @@ func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) { } return &Parser{ - tmpl: *tt, - sb: strings.Builder{}, - prefix: tp, - parseLeadingJSON: true, - name: name, - arguments: arguments, + tmpl: *tt, + sb: strings.Builder{}, + prefix: tp, + greedyParseJSON: true, + name: name, + arguments: arguments, }, nil } diff --git a/tools/tools_test.go b/tools/tools_test.go index 1ae3bff89..5fee8f57d 100644 --- a/tools/tools_test.go +++ b/tools/tools_test.go @@ -536,11 +536,18 @@ func TestParseToolCalls(t *testing.T) { expectedTokens: "", }, { - name: "model without prefix in template, prefix in output", + name: "model without prefix in template, prefix in output, multiple tool calls in list", model: "llama3.2", output: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{}, - expectedTokens: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: ``, + }, + { + name: "model without prefix in template, prefix in output, individual tool calls", + model: "llama3.2", + output: ` {"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: ``, }, { name: "model with prefix in template, no prefix in output, tokens before", @@ -567,15 +574,37 @@ func TestParseToolCalls(t *testing.T) { name: "model without prefix in template, no prefix in output, tokens before", model: "llama3.2", output: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{}, - expectedTokens: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: `some tokens before`, }, { - name: "model without prefix in template, prefix in output, tokens after", + name: "model without prefix in template, prefix in output, tokens after", + model: "llama3.2", + output: ` + [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: ``, + }, + { + name: "model without without prefix, match all jsons", model: "llama3.2", - output: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, + output: `model outputs some text [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "model outputs some text", + }, + { + name: "model flushes tokens if tool call doesn't match", + model: "llama3.2", + output: `{ "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}`, expectedToolCall: []api.ToolCall{}, - expectedTokens: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, + expectedTokens: `{ "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}`, + }, + { + name: "model flushes tokens if tool call doesn't match array", + model: "llama3.2", + output: `[ { "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `[ { "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}]`, }, }