mirror of
https://github.com/ollama/ollama.git
synced 2025-07-28 15:03:10 +02:00
tools: loosen tool parsing to allow for more formats (#11030)
This commit is contained in:
470
tools/tools.go
470
tools/tools.go
@@ -1,253 +1,287 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
gotmpl "text/template"
|
||||
"text/template"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidToolCall = errors.New("invalid tool call format")
|
||||
errAccumulateMore = errors.New("need to accumulate more content")
|
||||
type toolsState int
|
||||
|
||||
const (
|
||||
toolsState_LookingForTag toolsState = iota
|
||||
toolsState_ToolCalling
|
||||
toolsState_Done
|
||||
)
|
||||
|
||||
type Parser struct {
|
||||
greedyParseJSON bool
|
||||
prefix string
|
||||
prefixFound bool
|
||||
tmpl gotmpl.Template
|
||||
sb strings.Builder
|
||||
index int
|
||||
name string
|
||||
arguments string
|
||||
tag string
|
||||
names []string
|
||||
properties []string
|
||||
|
||||
state toolsState
|
||||
buffer []byte
|
||||
n int
|
||||
}
|
||||
|
||||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||
//
|
||||
// Parameters:
|
||||
// - s: The string to parse
|
||||
// - name: The field name from template that identifies the tool call name
|
||||
// - arguments: The field name from template that identifies the tool call arguments
|
||||
//
|
||||
// Returns:
|
||||
// - []api.ToolCall: The parsed tool calls if successful
|
||||
// - error: ErrAccumulateMore if braces unbalanced, ErrInvalidToolCall if invalid, or nil if successful
|
||||
func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.ToolCall, error) {
|
||||
// Check for balanced braces before attempting to parse
|
||||
braceCount := 0
|
||||
squareCount := 0
|
||||
startIndex := -1
|
||||
var rawToolCalls []string
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
// Only track these if we don't have a prefix as it will be cut off from the prefix. Also track in the parseLeadingJSON case.
|
||||
trackSquareBrackets := prefix == "" || !strings.HasSuffix(prefix, "[") || strings.HasPrefix(s, "[")
|
||||
for i, c := range s {
|
||||
switch c {
|
||||
case '{':
|
||||
braceCount++
|
||||
if startIndex == -1 {
|
||||
startIndex = i
|
||||
}
|
||||
case '}':
|
||||
braceCount--
|
||||
if braceCount == 0 {
|
||||
rawToolCalls = append(rawToolCalls, s[startIndex:i+1])
|
||||
startIndex = -1
|
||||
}
|
||||
case '[':
|
||||
if trackSquareBrackets {
|
||||
squareCount++
|
||||
}
|
||||
case ']':
|
||||
if trackSquareBrackets {
|
||||
squareCount--
|
||||
}
|
||||
}
|
||||
|
||||
// Negative means we have an extra closing brace/bracket
|
||||
if braceCount < 0 || squareCount < 0 {
|
||||
return nil, errInvalidToolCall
|
||||
}
|
||||
}
|
||||
|
||||
// If braces/brackets aren't balanced, need more input
|
||||
if braceCount > 0 || squareCount > 0 {
|
||||
return nil, errAccumulateMore
|
||||
}
|
||||
|
||||
t := strings.TrimSpace(s)
|
||||
if len(t) == 0 {
|
||||
return nil, errAccumulateMore
|
||||
}
|
||||
// If the input is a single square bracket, it's not a valid tool call
|
||||
if t[0] == '[' && len(t) == 1 {
|
||||
return nil, errAccumulateMore
|
||||
}
|
||||
|
||||
// Attempt full unmarshal of the JSON
|
||||
var toolCalls []api.ToolCall
|
||||
for _, rawToolCall := range rawToolCalls {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal([]byte(rawToolCall), &resp); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Collect nested objects that could contain tool calls
|
||||
objs := collect(resp)
|
||||
if len(objs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract tool calls from objects
|
||||
for _, kv := range objs {
|
||||
n, nok := kv[name].(string)
|
||||
a, aok := kv[arguments].(map[string]any)
|
||||
if nok && aok {
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: n,
|
||||
Arguments: a,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
slog.Debug("No valid tool call found in object.", "object", kv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Valid JSON, no tool calls found
|
||||
if len(toolCalls) == 0 {
|
||||
slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls)
|
||||
return nil, errInvalidToolCall
|
||||
}
|
||||
|
||||
return toolCalls, nil
|
||||
// NewParser creates a new tool call parser from a model's chat
|
||||
// template and a list of provided tools.
|
||||
func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
|
||||
return NewParserWithTag(tools, parseTag(tmpl))
|
||||
}
|
||||
|
||||
// checkPrefix processes a string to find and handle a prefix pattern.
|
||||
//
|
||||
// Returns:
|
||||
// - The processed string with prefix removed if found
|
||||
// - error: ErrAccumulateMore if prefix is incomplete, or nil if successful
|
||||
func (p *Parser) checkPrefix(s string) (string, error) {
|
||||
if s == "" || p.prefix == "" {
|
||||
return s, nil
|
||||
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
|
||||
var p Parser
|
||||
for _, t := range tools {
|
||||
p.names = append(p.names, t.Function.Name)
|
||||
for r := range t.Function.Parameters.Properties {
|
||||
p.properties = append(p.properties, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for prefix at start of string
|
||||
if cut, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix {
|
||||
// Found prefix at start - accumulate for potential tool
|
||||
p.prefixFound = true
|
||||
return cut, nil
|
||||
}
|
||||
|
||||
// Check if prefix overlaps end of string
|
||||
if idx := suffixOverlap(s, p.prefix); idx != -1 {
|
||||
// Return everything except overlapping portion
|
||||
p.sb.Reset()
|
||||
p.sb.WriteString(s[idx:])
|
||||
return s[:idx], errAccumulateMore
|
||||
}
|
||||
|
||||
// Check if prefix appears in middle of string
|
||||
if idx := strings.Index(s, p.prefix); idx != -1 {
|
||||
// Save remainder starting at prefix for next pass
|
||||
p.sb.Reset()
|
||||
p.sb.WriteString(strings.TrimSpace(s[idx:]))
|
||||
// Return everything before prefix
|
||||
return s[:idx], errAccumulateMore
|
||||
}
|
||||
|
||||
// No partial prefix found
|
||||
return s, nil
|
||||
p.tag = tag
|
||||
return &p
|
||||
}
|
||||
|
||||
// Add processes a string input to parse tool calls and content.
|
||||
// It handles prefix detection and JSON parsing to extract tool calls.
|
||||
//
|
||||
// Returns:
|
||||
// - tools: Any parsed tool calls
|
||||
// - content: Non-tool call content
|
||||
func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
|
||||
p.sb.WriteString(s)
|
||||
s = p.sb.String()
|
||||
|
||||
// Check for prefix pattern in input
|
||||
s, err := p.checkPrefix(s)
|
||||
if err != nil {
|
||||
// Need more input to complete prefix
|
||||
// Add processes a string input to parse tool calls and content that
|
||||
// should be sent back to the user.
|
||||
func (p *Parser) Add(s string) (calls []api.ToolCall, content string) {
|
||||
if p.state == toolsState_Done {
|
||||
return nil, s
|
||||
}
|
||||
|
||||
// Exit if prefix exists in template, greedy parsing is off, and prefix not found
|
||||
if !p.greedyParseJSON && !p.prefixFound {
|
||||
p.sb.Reset()
|
||||
return nil, s
|
||||
p.buffer = append(p.buffer, s...)
|
||||
|
||||
if p.state == toolsState_LookingForTag {
|
||||
i, found := p.findTag()
|
||||
if i == -1 {
|
||||
content = string(p.buffer)
|
||||
p.buffer = []byte{}
|
||||
} else {
|
||||
content = string(p.buffer[:i])
|
||||
p.buffer = p.buffer[i:]
|
||||
}
|
||||
|
||||
// for models where { or [ are used as tool calling
|
||||
// tags, we only support parsing tools if the first non-
|
||||
// whitespace character is { or [
|
||||
if p.tag == "{" || p.tag == "[" {
|
||||
if strings.TrimSpace(content) != "" {
|
||||
p.state = toolsState_Done
|
||||
return nil, content + string(p.buffer)
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return nil, content
|
||||
}
|
||||
|
||||
p.state = toolsState_ToolCalling
|
||||
}
|
||||
|
||||
toolCalls, err := parseJSONToolCalls(s, p.name, p.arguments, p.prefix)
|
||||
if err != nil {
|
||||
if errors.Is(err, errAccumulateMore) {
|
||||
return nil, ""
|
||||
for {
|
||||
call := p.parseToolCall()
|
||||
if call == nil {
|
||||
break
|
||||
}
|
||||
p.sb.Reset()
|
||||
// Only do greedy JSON parsing if there is no prefix from template
|
||||
if p.prefix != "" {
|
||||
p.greedyParseJSON = false
|
||||
}
|
||||
if p.index != 0 && p.prefix == "" {
|
||||
return nil, ""
|
||||
}
|
||||
if p.prefixFound {
|
||||
// Drop tokens since prefix was found
|
||||
return nil, ""
|
||||
}
|
||||
return nil, s
|
||||
|
||||
calls = append(calls, *call)
|
||||
}
|
||||
|
||||
for _, tc := range toolCalls {
|
||||
tc.Function.Index = p.index
|
||||
p.index++
|
||||
if p.done() {
|
||||
p.state = toolsState_Done
|
||||
content = string(p.buffer)
|
||||
p.buffer = []byte{}
|
||||
}
|
||||
|
||||
p.sb.Reset()
|
||||
return toolCalls, ""
|
||||
return calls, content
|
||||
}
|
||||
|
||||
// NewParser creates a new tool call parser from a template. It extracts the tool call format,
|
||||
// prefix, and field names from the template to use for parsing tool calls from model output.
|
||||
//
|
||||
// Returns an error if the template does not contain valid tool call formatting.
|
||||
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
|
||||
parsed, err := template.Parse(templateToProcess.Root.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// findTag searches the buffer to find and handle a tool calling tag
|
||||
// returning true if the tag was found and false otherwise, and
|
||||
// a string content signaling any content that should be sent back to the user
|
||||
func (p *Parser) findTag() (int, bool) {
|
||||
// First check for complete substring anywhere in s
|
||||
if i := bytes.Index(p.buffer, []byte(p.tag)); i > -1 {
|
||||
return i, true
|
||||
}
|
||||
|
||||
tt, err := toolTemplate(parsed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Then check for partial suffix overlap
|
||||
max := min(len(p.buffer), len(p.tag))
|
||||
for i := max; i > 0; i-- {
|
||||
if bytes.HasSuffix(p.buffer, []byte(p.tag[:i])) {
|
||||
return len(p.buffer) - i, false
|
||||
}
|
||||
}
|
||||
|
||||
tp := toolPrefix(templateToProcess)
|
||||
|
||||
name, arguments, err := extractToolArgs(tt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Parser{
|
||||
tmpl: *tt,
|
||||
sb: strings.Builder{},
|
||||
prefix: tp,
|
||||
greedyParseJSON: true,
|
||||
name: name,
|
||||
arguments: arguments,
|
||||
}, nil
|
||||
return -1, false
|
||||
}
|
||||
|
||||
// parseToolCall finds the next complete tool call in the buffer
|
||||
// incrementing n and advancing the buffer.
|
||||
func (p *Parser) parseToolCall() *api.ToolCall {
|
||||
var name string
|
||||
var args map[string]any
|
||||
var end int = len(p.buffer)
|
||||
|
||||
// find tool name
|
||||
var i int
|
||||
for _, n := range p.names {
|
||||
if i = bytes.Index(p.buffer, []byte(n)); i != -1 {
|
||||
if i+len(n) < end {
|
||||
name = n
|
||||
end = i + len(n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if args, i = p.findArguments(); args == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if i > end {
|
||||
end = i
|
||||
}
|
||||
|
||||
tc := &api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
Index: p.n,
|
||||
},
|
||||
}
|
||||
|
||||
p.n++
|
||||
p.buffer = p.buffer[end:]
|
||||
return tc
|
||||
}
|
||||
|
||||
// findArguments returns the first object that appears to be
|
||||
// arguments and the position where the arguments end, returning nil and 0 if
|
||||
// an invalid JSON object or non-arguments object is found first
|
||||
func (p *Parser) findArguments() (map[string]any, int) {
|
||||
if len(p.buffer) == 0 {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
var braces int
|
||||
var start int = -1
|
||||
var end int
|
||||
var object []byte
|
||||
|
||||
// find any outer json object
|
||||
for i, c := range p.buffer {
|
||||
if c == '{' {
|
||||
braces++
|
||||
if start == -1 {
|
||||
start = i
|
||||
}
|
||||
}
|
||||
|
||||
if c == '}' {
|
||||
braces--
|
||||
if braces == 0 && start != -1 {
|
||||
end = i + 1
|
||||
object = p.buffer[start:end]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if braces > 0 {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
|
||||
// not valid json
|
||||
if err := json.Unmarshal(object, &data); err != nil {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
var find func(obj any) map[string]any
|
||||
find = func(obj any) map[string]any {
|
||||
switch v := obj.(type) {
|
||||
case map[string]any:
|
||||
// check if the object keys are valid tool properties
|
||||
// TODO (jmorganca): check only sets of properties that
|
||||
// go together instead of the entire set
|
||||
for _, prop := range p.properties {
|
||||
if _, exists := v[prop]; exists {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
for _, value := range v {
|
||||
if result := find(value); result != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
case []any:
|
||||
for _, item := range v {
|
||||
if result := find(item); result != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
result := find(data)
|
||||
if result != nil {
|
||||
return result, end
|
||||
}
|
||||
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
// done checks if the parser is done parsing by looking
|
||||
// for closing tag. currently only } and ] are supported
|
||||
// for closing tags as {} or [] pairs may not always
|
||||
// represent tool calls and we need to send the content back
|
||||
func (p *Parser) done() bool {
|
||||
var open, close rune
|
||||
switch p.tag {
|
||||
case "{":
|
||||
open, close = '{', '}'
|
||||
case "[":
|
||||
open, close = '[', ']'
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
var count int
|
||||
for _, c := range p.buffer {
|
||||
if c == byte(open) {
|
||||
count++
|
||||
} else if c == byte(close) {
|
||||
count--
|
||||
if count == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Content returns any remaining content that
|
||||
// should be sent to the user. This should be the empty string
|
||||
// string unless the tag is { or [ and a tool call was not found
|
||||
func (p *Parser) Content() string {
|
||||
if p.n > 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if p.tag == "{" || p.tag == "[" {
|
||||
return string(p.buffer)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
Reference in New Issue
Block a user