mirror of
https://github.com/ollama/ollama.git
synced 2025-04-11 21:29:32 +02:00
template: disable func checking
This commit is contained in:
parent
2ddc32d5c5
commit
aac17d5f15
@ -302,7 +302,7 @@ func parseObjects(s string) []map[string]any {
|
||||
// mxyng: this only really works if the input contains tool calls in some JSON format
|
||||
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||
// create a subtree from the node that ranges over .ToolCalls
|
||||
tmpl := m.Template.Subtree(func(n parse.Node) bool {
|
||||
tmpl := m.Template.Sub(func(n parse.Node) bool {
|
||||
if t, ok := n.(*parse.RangeNode); ok {
|
||||
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||
}
|
||||
@ -315,7 +315,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
|
||||
if err := tmpl.Template().Execute(&b, map[string][]api.ToolCall{
|
||||
"ToolCalls": {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
|
@ -93,8 +93,8 @@ func Named(s string) (*named, error) {
|
||||
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
|
||||
|
||||
type Template struct {
|
||||
*template.Template
|
||||
raw string
|
||||
tree *parse.Tree
|
||||
raw string
|
||||
}
|
||||
|
||||
// response is a template node that can be added to templates that don't already have one
|
||||
@ -124,17 +124,18 @@ var funcs = template.FuncMap{
|
||||
}
|
||||
|
||||
func Parse(s string) (*Template, error) {
|
||||
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
|
||||
tree := parse.New("")
|
||||
tree.Mode = tree.Mode | parse.SkipFuncCheck
|
||||
|
||||
tmpl, err := tmpl.Parse(s)
|
||||
tree, err := tree.Parse(s, "", "", map[string]*parse.Tree{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := Template{Template: tmpl, raw: s}
|
||||
t := Template{tree, s}
|
||||
if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
|
||||
// touch up the template and append {{ .Response }}
|
||||
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
|
||||
t.tree.Root.Nodes = append(t.tree.Root.Nodes, &response)
|
||||
}
|
||||
|
||||
return &t, nil
|
||||
@ -146,10 +147,8 @@ func (t *Template) String() string {
|
||||
|
||||
func (t *Template) Vars() []string {
|
||||
var vars []string
|
||||
for _, tt := range t.Templates() {
|
||||
for _, n := range tt.Root.Nodes {
|
||||
vars = append(vars, Identifiers(n)...)
|
||||
}
|
||||
for _, n := range t.tree.Root.Nodes {
|
||||
vars = append(vars, Identifiers(n)...)
|
||||
}
|
||||
|
||||
set := make(map[string]struct{})
|
||||
@ -172,7 +171,8 @@ type Values struct {
|
||||
forceLegacy bool
|
||||
}
|
||||
|
||||
func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
|
||||
// Sub returns a new template with the subtree that matches the predicate
|
||||
func (t *Template) Sub(fn func(parse.Node) bool) *Template {
|
||||
var walk func(parse.Node) parse.Node
|
||||
walk = func(n parse.Node) parse.Node {
|
||||
if fn(n) {
|
||||
@ -205,29 +205,34 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
|
||||
return nil
|
||||
}
|
||||
|
||||
if n := walk(t.Tree.Root); n != nil {
|
||||
return (&template.Template{
|
||||
Tree: &parse.Tree{
|
||||
if n := walk(t.tree.Root); n != nil {
|
||||
return &Template{
|
||||
tree: &parse.Tree{
|
||||
Root: &parse.ListNode{
|
||||
Nodes: []parse.Node{n},
|
||||
},
|
||||
},
|
||||
}).Funcs(funcs)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Template) Template() *template.Template {
|
||||
return template.Must(template.New("").Option("missingkey=zero").Funcs(funcs).AddParseTree("", t.tree))
|
||||
}
|
||||
|
||||
func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
tmpl := t.Template()
|
||||
system, messages := collate(v.Messages)
|
||||
if v.Prompt != "" && v.Suffix != "" {
|
||||
return t.Template.Execute(w, map[string]any{
|
||||
return tmpl.Execute(w, map[string]any{
|
||||
"Prompt": v.Prompt,
|
||||
"Suffix": v.Suffix,
|
||||
"Response": "",
|
||||
})
|
||||
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
||||
return t.Template.Execute(w, map[string]any{
|
||||
return tmpl.Execute(w, map[string]any{
|
||||
"System": system,
|
||||
"Messages": messages,
|
||||
"Tools": v.Tools,
|
||||
@ -240,7 +245,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
var prompt, response string
|
||||
for _, m := range messages {
|
||||
execute := func() error {
|
||||
if err := t.Template.Execute(&b, map[string]any{
|
||||
if err := tmpl.Execute(&b, map[string]any{
|
||||
"System": system,
|
||||
"Prompt": prompt,
|
||||
"Response": response,
|
||||
@ -275,7 +280,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
}
|
||||
|
||||
var cut bool
|
||||
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
|
||||
nodes := deleteNode(t.tree.Root.Copy(), func(n parse.Node) bool {
|
||||
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
|
||||
cut = true
|
||||
return false
|
||||
@ -285,7 +290,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
})
|
||||
|
||||
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
|
||||
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
|
||||
if err := template.Must(tmpl.AddParseTree("", &tree)).Execute(&b, map[string]any{
|
||||
"System": system,
|
||||
"Prompt": prompt,
|
||||
"Response": response,
|
||||
|
@ -54,7 +54,7 @@ func TestNamed(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if tmpl.Tree.Root.String() == "" {
|
||||
if tmpl.tree.Root.String() == "" {
|
||||
t.Errorf("empty %s template", k)
|
||||
}
|
||||
})
|
||||
@ -153,7 +153,7 @@ func TestTemplate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
func TestParseVars(t *testing.T) {
|
||||
cases := []struct {
|
||||
template string
|
||||
vars []string
|
||||
@ -181,6 +181,9 @@ func TestParse(t *testing.T) {
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ .Response }}<|im_end|>
|
||||
{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
|
||||
{"{{ json .Messages }}", []string{"messages"}},
|
||||
// undefined functions should not error
|
||||
{"{{ undefined }}", []string{"response"}},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
@ -197,6 +200,30 @@ func TestParse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseExecute(t *testing.T) {
|
||||
t.Run("undefined function", func(t *testing.T) {
|
||||
tmpl, err := Parse(`{{- if .Suffix }}{{ .Prompt }} {{ .Suffix }}{{- else }}{{ undefined }}{{- end }}`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := tmpl.Execute(&b, Values{Prompt: "def add(", Suffix: " return c"}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(b.String(), "def add( return c"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if err := tmpl.Execute(io.Discard, Values{}); err == nil {
|
||||
t.Fatal("expected error")
|
||||
} else if !strings.Contains(err.Error(), "\"undefined\" is not a defined function") {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExecuteWithMessages(t *testing.T) {
|
||||
type template struct {
|
||||
name string
|
||||
|
Loading…
x
Reference in New Issue
Block a user