diff --git a/server/model.go b/server/model.go index 4926d6ce2..644b47b04 100644 --- a/server/model.go +++ b/server/model.go @@ -302,7 +302,7 @@ func parseObjects(s string) []map[string]any { // mxyng: this only really works if the input contains tool calls in some JSON format func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { // create a subtree from the node that ranges over .ToolCalls - tmpl := m.Template.Subtree(func(n parse.Node) bool { + tmpl := m.Template.Sub(func(n parse.Node) bool { if t, ok := n.(*parse.RangeNode); ok { return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") } @@ -315,7 +315,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) { } var b bytes.Buffer - if err := tmpl.Execute(&b, map[string][]api.ToolCall{ + if err := tmpl.Template().Execute(&b, map[string][]api.ToolCall{ "ToolCalls": { { Function: api.ToolCallFunction{ diff --git a/template/template.go b/template/template.go index 5c886cac4..85be483ca 100644 --- a/template/template.go +++ b/template/template.go @@ -93,8 +93,8 @@ func Named(s string) (*named, error) { var DefaultTemplate, _ = Parse("{{ .Prompt }}") type Template struct { - *template.Template - raw string + tree *parse.Tree + raw string } // response is a template node that can be added to templates that don't already have one @@ -124,17 +124,18 @@ var funcs = template.FuncMap{ } func Parse(s string) (*Template, error) { - tmpl := template.New("").Option("missingkey=zero").Funcs(funcs) + tree := parse.New("") + tree.Mode = tree.Mode | parse.SkipFuncCheck - tmpl, err := tmpl.Parse(s) + tree, err := tree.Parse(s, "", "", map[string]*parse.Tree{}) if err != nil { return nil, err } - t := Template{Template: tmpl, raw: s} + t := Template{tree, s} if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") { // touch up the template and append {{ .Response }} - tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response) + t.tree.Root.Nodes = append(t.tree.Root.Nodes, &response) } return &t, nil @@ -146,10 +147,8 @@ func (t *Template) String() string { func (t *Template) Vars() []string { var vars []string - for _, tt := range t.Templates() { - for _, n := range tt.Root.Nodes { - vars = append(vars, Identifiers(n)...) - } + for _, n := range t.tree.Root.Nodes { + vars = append(vars, Identifiers(n)...) } set := make(map[string]struct{}) @@ -172,7 +171,8 @@ type Values struct { forceLegacy bool } -func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template { +// Sub returns a new template with the subtree that matches the predicate +func (t *Template) Sub(fn func(parse.Node) bool) *Template { var walk func(parse.Node) parse.Node walk = func(n parse.Node) parse.Node { if fn(n) { @@ -205,29 +205,34 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template { return nil } - if n := walk(t.Tree.Root); n != nil { - return (&template.Template{ - Tree: &parse.Tree{ + if n := walk(t.tree.Root); n != nil { + return &Template{ + tree: &parse.Tree{ Root: &parse.ListNode{ Nodes: []parse.Node{n}, }, }, - }).Funcs(funcs) + } } return nil } +func (t *Template) Template() *template.Template { + return template.Must(template.New("").Option("missingkey=zero").Funcs(funcs).AddParseTree("", t.tree)) +} + func (t *Template) Execute(w io.Writer, v Values) error { + tmpl := t.Template() system, messages := collate(v.Messages) if v.Prompt != "" && v.Suffix != "" { - return t.Template.Execute(w, map[string]any{ + return tmpl.Execute(w, map[string]any{ "Prompt": v.Prompt, "Suffix": v.Suffix, "Response": "", }) } else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { - return t.Template.Execute(w, map[string]any{ + return tmpl.Execute(w, map[string]any{ "System": system, "Messages": messages, "Tools": v.Tools, @@ -240,7 +245,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { var prompt, response string for _, m := range messages { execute := func() error { - if err := t.Template.Execute(&b, map[string]any{ + if err := tmpl.Execute(&b, map[string]any{ "System": system, "Prompt": prompt, "Response": response, @@ -275,7 +280,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { } var cut bool - nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool { + nodes := deleteNode(t.tree.Root.Copy(), func(n parse.Node) bool { if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") { cut = true return false @@ -285,7 +290,7 @@ func (t *Template) Execute(w io.Writer, v Values) error { }) tree := parse.Tree{Root: nodes.(*parse.ListNode)} - if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{ + if err := template.Must(tmpl.AddParseTree("", &tree)).Execute(&b, map[string]any{ "System": system, "Prompt": prompt, "Response": response, diff --git a/template/template_test.go b/template/template_test.go index 616bef6a8..ce13186cd 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -54,7 +54,7 @@ func TestNamed(t *testing.T) { t.Fatal(err) } - if tmpl.Tree.Root.String() == "" { + if tmpl.tree.Root.String() == "" { t.Errorf("empty %s template", k) } }) @@ -153,7 +153,7 @@ func TestTemplate(t *testing.T) { } } -func TestParse(t *testing.T) { +func TestParseVars(t *testing.T) { cases := []struct { template string vars []string @@ -181,6 +181,9 @@ func TestParse(t *testing.T) { {{ end }}<|im_start|>assistant {{ .Response }}<|im_end|> {{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}}, + {"{{ json .Messages }}", []string{"messages"}}, + // undefined functions should not error + {"{{ undefined }}", []string{"response"}}, } for _, tt := range cases { @@ -197,6 +200,30 @@ func TestParse(t *testing.T) { } } +func TestParseExecute(t *testing.T) { + t.Run("undefined function", func(t *testing.T) { + tmpl, err := Parse(`{{- if .Suffix }}{{ .Prompt }} {{ .Suffix }}{{- else }}{{ undefined }}{{- end }}`) + if err != nil { + t.Fatal(err) + } + + var b bytes.Buffer + if err := tmpl.Execute(&b, Values{Prompt: "def add(", Suffix: " return c"}); err != nil { + t.Fatal(err) + } + + if diff := cmp.Diff(b.String(), "def add( return c"); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + + if err := tmpl.Execute(io.Discard, Values{}); err == nil { + t.Fatal("expected error") + } else if !strings.Contains(err.Error(), "\"undefined\" is not a defined function") { + t.Fatal(err) + } + }) +} + func TestExecuteWithMessages(t *testing.T) { type template struct { name string