From 4053c489b45432bffd8725f1ca32d6c0a517fa30 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Thu, 27 Mar 2025 14:51:39 -0700 Subject: [PATCH] server: enable content streaming with tools --- server/images.go | 41 ++++++++++++++++ server/model_test.go | 111 ++++++++++++++++++++++++++++++++++++------- server/routes.go | 29 ++++++++--- 3 files changed, 159 insertions(+), 22 deletions(-) diff --git a/server/images.go b/server/images.go index bd6d92a6c..e779193c7 100644 --- a/server/images.go +++ b/server/images.go @@ -16,10 +16,12 @@ import ( "net/url" "os" "path/filepath" + "regexp" "runtime" "slices" "strconv" "strings" + "text/template/parse" "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" @@ -62,6 +64,7 @@ type Model struct { Digest string Options map[string]any Messages []api.Message + ToolPrefix string Template *template.Template } @@ -350,9 +353,47 @@ func GetModel(name string) (*Model, error) { } } + if model.Template != nil && model.CheckCapabilities(CapabilityTools) == nil { + model.addToolPrefix() + } + return model, nil } +// HasToolPrefix checks if the completion starts with the tool prefix, ignoring whitespace +func (m *Model) HasToolPrefix(sb strings.Builder) bool { + text := regexp.MustCompile(`\s+`).ReplaceAllString(sb.String(), "") + toolString := regexp.MustCompile(`\s+`).ReplaceAllString(m.ToolPrefix, "") + + if len(text) < len(toolString) { + return text == toolString[:len(text)] + } + return text[:len(toolString)] == toolString +} + +// Figure out what's between the start of the tools block, and the json response, and use it as a marker. Usually that's +// {- if .ToolCalls}this text{ range .ToolCalls}or maybe this text{{.name}} +func (m *Model) addToolPrefix() { + // create a subtree from the node that ranges over .ToolCalls + var previousNode parse.Node + toolCallsTemplate := m.Template.Subtree(func(node parse.Node) bool { + if rangeNode, ok := node.(*parse.RangeNode); ok { + return slices.Contains(template.Identifiers(rangeNode.Pipe), "ToolCalls") + } + previousNode = node + return false + }) + if textNode, ok := previousNode.(*parse.TextNode); ok { + m.ToolPrefix = strings.TrimSpace(textNode.String()) + } + if len(m.ToolPrefix) == 0 && len(toolCallsTemplate.Root.Nodes) > 0 { + rangeNode, ok := toolCallsTemplate.Root.Nodes[0].(*parse.RangeNode) + if ok && len(rangeNode.List.Nodes) > 0 { + m.ToolPrefix = rangeNode.List.Nodes[0].String() + } + } +} + func CopyModel(src, dst model.Name) error { if !dst.IsFullyQualified() { return model.Unqualified(dst) diff --git a/server/model_test.go b/server/model_test.go index e5c2f2bb2..0f050011f 100644 --- a/server/model_test.go +++ b/server/model_test.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -28,19 +29,20 @@ func readFile(t *testing.T, base, name string) *bytes.Buffer { func TestExecuteWithTools(t *testing.T) { p := filepath.Join("testdata", "tools") cases := []struct { - model string - output string - ok bool + model string + output string + ok bool + wellFormed bool }{ - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] + {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true, true}, + {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] -The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true}, - {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false}, +The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true, false}, + {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false, false}, {"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: - [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, - {"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, + [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true, false}, + {"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false, false}, {"command-r-plus", "Action: ```json" + ` [ { @@ -58,16 +60,17 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, } } ] -` + "```", true}, - {"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, - {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, - {"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false}, +` + "```", true, true}, + {"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false, false}, + {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true, true}, + {"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false, false}, {"llama3-groq-tool-use", ` {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} -`, true}, - {"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true}, - {"nemotron", `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} `, true}, +`, true, true}, + {"xlam", `### Response: +{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true, true}, + {"nemotron", ` {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} `, true, true}, } var tools []api.Tool @@ -119,6 +122,21 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, } }) + t.Run("prefix", func(t *testing.T) { + m := &Model{Template: tmpl} + m.addToolPrefix() + + if tt.wellFormed { + if len(m.ToolPrefix) == 0 { + t.Fatalf("No tool prefix detected") + } + + if !strings.HasPrefix(strings.TrimSpace(tt.output), m.ToolPrefix) { + t.Fatalf("incorrect tool prefix: \"%s\", \"%s\"", m.ToolPrefix, tt.output) + } + } + }) + t.Run("parse", func(t *testing.T) { m := &Model{Template: tmpl} actual, ok := m.parseToolCalls(tt.output) @@ -177,3 +195,64 @@ func TestParseObjects(t *testing.T) { }) } } + +func TestAddToolPrefix(t *testing.T) { + tests := []struct { + name string + template string + want string + }{ + { + name: "prefix_from_previous_text_node", + template: `Previous text node{{- range .ToolCalls}}{{.name}}{{end}}`, + want: "Previous text node", + }, + { + name: "prefix_from_range_node", + template: `{{- range .ToolCalls}}[TOOL_CALLS]{{.name}}{{end}}`, + want: "[TOOL_CALLS]", + }, + { + name: "prefix_with_extra_whitespace", + template: ` Previous text with spaces {{- range .ToolCalls}}{{.name}}{{end}}`, + want: "Previous text with spaces", + }, + { + name: "prefix_with_newlines", + template: "First line\nSecond line\n{{- range .ToolCalls}}{{.name}}{{end}}", + want: "First line\nSecond line", + }, + { + name: "tool_calls_json_template", + template: `{{ if .Content }}{{ .Content }}{{- else if .ToolCalls }} +{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }} +{{ end }}`, + want: ``, + }, + { + name: "mistral_tool_calls_template", + template: `{{- if .Content }} {{ .Content }} +{{- else if .ToolCalls }}[TOOL_CALLS] [ +{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} +{{- end }}] +{{- end }}`, + want: "[TOOL_CALLS] [", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := template.Parse(tt.template) + if err != nil { + t.Fatalf("failed to parse template: %v", err) + } + + m := &Model{Template: tmpl} + m.addToolPrefix() + + if m.ToolPrefix != tt.want { + t.Errorf("incorrect tool prefix:\ngot: %q\nwant: %q", m.ToolPrefix, tt.want) + } + }) + } +} diff --git a/server/routes.go b/server/routes.go index 906426b18..a10285797 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1526,6 +1526,8 @@ func (s *Server) ChatHandler(c *gin.Context) { defer close(ch) var sb strings.Builder var toolCallIndex int = 0 + var mightBeTools bool = true + buf := make([]api.ChatResponse, 0) if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, @@ -1551,18 +1553,29 @@ func (s *Server) ChatHandler(c *gin.Context) { res.LoadDuration = checkpointLoaded.Sub(checkpointStart) } - // TODO: tool call checking and filtering should be moved outside of this callback once streaming - // however this was a simple change for now without reworking streaming logic of this (and other) - // handlers - if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 { + // If we know we're not streaming + if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 || !mightBeTools { ch <- res return } + sb.WriteString(r.Content) + + // Buffer up responses while we're unsure whether to stream. + buf = append(buf, res) + + // not a tools response, continue streaming. + if !m.HasToolPrefix(sb) { + mightBeTools = false + for _, item := range buf { + ch <- item + } + return + } + // Streaming tool calls: // If tools are recognized, use a flag to track the sending of a tool downstream // This ensures that content is cleared from the message on the last chunk sent - sb.WriteString(r.Content) if toolCalls, ok := m.parseToolCalls(sb.String()); ok { res.Message.ToolCalls = toolCalls for i := range toolCalls { @@ -1573,8 +1586,12 @@ func (s *Server) ChatHandler(c *gin.Context) { sb.Reset() ch <- res return + } else { + if !strings.HasPrefix(sb.String(), "{") { + ch <- res + return + } } - if r.Done { // Send any remaining content if no tool calls were detected if toolCallIndex == 0 {