diff --git a/template/template_test.go b/template/template_test.go index 05eacf2d72..45101e5aeb 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -154,24 +154,55 @@ func TestTemplate(t *testing.T) { } func TestParse(t *testing.T) { - cases := []struct { + validCases := []struct { + name string template string vars []string }{ - {"{{ .Prompt }}", []string{"prompt", "response"}}, - {"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}}, - {"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}}, - {"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}}, - {"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}}, - {"{{ range .Messages }}{{ if eq .Role \"tool\" }}Tool Result: {{ .ToolName }} {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role", "toolname"}}, - {`{{- range .Messages }} + { + name: "PromptOnly", + template: "{{ .Prompt }}", + vars: []string{"prompt", "response"}, + }, + { + name: "SystemAndPrompt", + template: "{{ .System }} {{ .Prompt }}", + vars: []string{"prompt", "response", "system"}, + }, + { + name: "PromptResponseSystem", + template: "{{ .System }} {{ .Prompt }} {{ .Response }}", + vars: []string{"prompt", "response", "system"}, + }, + { + name: "ToolsBlock", + template: "{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", + vars: []string{"prompt", "response", "system", "tools"}, + }, + { + name: "MessagesRange", + template: "{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", + vars: []string{"content", "messages", "role"}, + }, + { + name: "ToolResultConditional", + template: "{{ range .Messages }}{{ if eq .Role \"tool\" }}Tool Result: {{ .ToolName }} {{ .Content }}{{ end }}{{ end }}", + vars: []string{"content", "messages", "role", "toolname"}, + }, + { + name: "MultilineSystemUserAssistant", + template: `{{- range .Messages }} {{- if eq .Role "system" }}SYSTEM: {{- else if eq .Role "user" }}USER: {{- else if eq .Role "assistant" }}ASSISTANT: -{{- else if eq .Role "tool" }}TOOL: +{{- else if eq .Role "tool" }}TOOL: {{- end }} {{ .Content }} -{{- end }}`, []string{"content", "messages", "role"}}, - {`{{- if .Messages }} +{{- end }}`, + vars: []string{"content", "messages", "role"}, + }, + { + name: "ChatMLLike", + template: `{{- if .Messages }} {{- range .Messages }}<|im_start|>{{ .Role }} {{ .Content }}<|im_end|> {{ end }}<|im_start|>assistant @@ -182,22 +213,60 @@ func TestParse(t *testing.T) { {{ .Prompt }}<|im_end|> {{ end }}<|im_start|>assistant {{ .Response }}<|im_end|> -{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}}, +{{- end -}}`, + vars: []string{"content", "messages", "prompt", "response", "role", "system"}, + }, } - for _, tt := range cases { - t.Run("", func(t *testing.T) { + for _, tt := range validCases { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tmpl, err := Parse(tt.template) if err != nil { - t.Fatal(err) + t.Fatalf("Parse returned unexpected error: %v", err) } - v, err := tmpl.Vars() + gotVars, err := tmpl.Vars() if err != nil { - t.Fatal(err) + t.Fatalf("Vars returned unexpected error: %v", err) } - if diff := cmp.Diff(v, tt.vars); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) + + if diff := cmp.Diff(gotVars, tt.vars); diff != "" { + t.Errorf("Vars mismatch (-got +want):\n%s", diff) + } + }) + } +} + +func TestParseError(t *testing.T) { + invalidCases := []struct { + name string + template string + errorStr string + }{ + { + "TemplateNotClosed", + "{{ .Prompt ", + "unclosed action", + }, + { + "Template", + `{{define "x"}}{{template "x"}}{{end}}{{template "x"}}`, + "undefined template specified", + }, + } + + for _, tt := range invalidCases { + t.Run(tt.name, func(t *testing.T) { + _, err := Parse(tt.template) + if err == nil { + t.Fatalf("expected Parse to return an error for an invalid template, got nil") + } + + if !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.errorStr)) { + t.Errorf("unexpected error message.\n got: %q\n want substring (case‑insensitive): %q", err.Error(), tt.errorStr) } }) }