From 1ed2881ef05cd62d97f3fc3687301f9c69249e3b Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Thu, 2 Oct 2025 17:25:55 -0700 Subject: [PATCH] templates: fix crash in improperly defined templates (#12483) --- server/images.go | 8 +++-- template/template.go | 71 +++++++++++++++++++++++++++++---------- template/template_test.go | 6 +++- 3 files changed, 65 insertions(+), 20 deletions(-) diff --git a/server/images.go b/server/images.go index 9466b7fb47..d3bd9ffaf3 100644 --- a/server/images.go +++ b/server/images.go @@ -105,12 +105,16 @@ func (m *Model) Capabilities() []model.Capability { builtinParser := parsers.ParserForName(m.Config.Parser) // Check for tools capability - if slices.Contains(m.Template.Vars(), "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) { + v, err := m.Template.Vars() + if err != nil { + slog.Warn("model template contains errors", "error", err) + } + if slices.Contains(v, "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) { capabilities = append(capabilities, model.CapabilityTools) } // Check for insert capability - if slices.Contains(m.Template.Vars(), "suffix") { + if slices.Contains(v, "suffix") { capabilities = append(capabilities, model.CapabilityInsert) } diff --git a/template/template.go b/template/template.go index f2775b91b3..c90190d7ac 100644 --- a/template/template.go +++ b/template/template.go @@ -148,7 +148,12 @@ func Parse(s string) (*Template, error) { } t := Template{Template: tmpl, raw: s} - if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") { + vars, err := t.Vars() + if err != nil { + return nil, err + } + + if !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") { // touch up the template and append {{ .Response }} tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response) } @@ -160,11 +165,15 @@ func (t *Template) String() string { return t.raw } -func (t *Template) Vars() []string { +func (t *Template) Vars() ([]string, error) { var vars []string for _, tt := range t.Templates() { for _, n := range tt.Root.Nodes { - vars = append(vars, Identifiers(n)...) + v, err := Identifiers(n) + if err != nil { + return vars, err + } + vars = append(vars, v...) } } @@ -173,7 +182,7 @@ func (t *Template) Vars() []string { set[strings.ToLower(n)] = struct{}{} } - return slices.Sorted(maps.Keys(set)) + return slices.Sorted(maps.Keys(set)), nil } func (t *Template) Contains(s string) bool { @@ -244,6 +253,10 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template { func (t *Template) Execute(w io.Writer, v Values) error { system, messages := collate(v.Messages) + vars, err := t.Vars() + if err != nil { + return err + } if v.Prompt != "" && v.Suffix != "" { return t.Template.Execute(w, map[string]any{ "Prompt": v.Prompt, @@ -253,7 +266,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { "ThinkLevel": v.ThinkLevel, "IsThinkSet": v.IsThinkSet, }) - } else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { + } else if !v.forceLegacy && slices.Contains(vars, "messages") { return t.Template.Execute(w, map[string]any{ "System": system, "Messages": messages, @@ -329,7 +342,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { return err } - _, err := io.Copy(w, &b) + _, err = io.Copy(w, &b) return err } @@ -358,27 +371,47 @@ func collate(msgs []api.Message) (string, []*api.Message) { } // Identifiers walks the node tree returning any identifiers it finds along the way -func Identifiers(n parse.Node) []string { +func Identifiers(n parse.Node) ([]string, error) { switch n := n.(type) { case *parse.ListNode: var names []string for _, n := range n.Nodes { - names = append(names, Identifiers(n)...) + i, err := Identifiers(n) + if err != nil { + return names, err + } + names = append(names, i...) } - return names + return names, nil case *parse.TemplateNode: + if n.Pipe == nil { + return nil, errors.New("undefined template specified") + } return Identifiers(n.Pipe) case *parse.ActionNode: + if n.Pipe == nil { + return nil, errors.New("undefined action in template") + } return Identifiers(n.Pipe) case *parse.BranchNode: - names := Identifiers(n.Pipe) + if n.Pipe == nil { + return nil, errors.New("undefined branch") + } + names, err := Identifiers(n.Pipe) + if err != nil { + return names, err + } for _, n := range []*parse.ListNode{n.List, n.ElseList} { if n != nil { - names = append(names, Identifiers(n)...) + i, err := Identifiers(n) + if err != nil { + return names, err + } + names = append(names, i...) } } - return names + return names, nil case *parse.IfNode: return Identifiers(&n.BranchNode) case *parse.RangeNode: @@ -389,17 +422,21 @@ func Identifiers(n parse.Node) []string { var names []string for _, c := range n.Cmds { for _, a := range c.Args { - names = append(names, Identifiers(a)...) + i, err := Identifiers(a) + if err != nil { + return names, err + } + names = append(names, i...) } } - return names + return names, nil case *parse.FieldNode: - return n.Ident + return n.Ident, nil case *parse.VariableNode: - return n.Ident + return n.Ident, nil } - return nil + return nil, nil } // deleteNode walks the node list and deletes nodes that match the predicate diff --git a/template/template_test.go b/template/template_test.go index 3d4eb99149..05eacf2d72 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -192,7 +192,11 @@ func TestParse(t *testing.T) { t.Fatal(err) } - if diff := cmp.Diff(tmpl.Vars(), tt.vars); diff != "" { + v, err := tmpl.Vars() + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(v, tt.vars); diff != "" { t.Errorf("mismatch (-got +want):\n%s", diff) } })