mirror of
https://github.com/ollama/ollama.git
synced 2025-11-11 11:57:34 +01:00
move thinking logic into its own package (#10990)
move thinking logic into its own package
This commit is contained in:
@@ -26,6 +26,7 @@ import (
|
|||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
|
"github.com/ollama/ollama/thinking"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
@@ -113,7 +114,7 @@ func (m *Model) Capabilities() []model.Capability {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check for thinking capability
|
// Check for thinking capability
|
||||||
openingTag, closingTag := inferThinkingTags(m.Template.Template)
|
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||||
if openingTag != "" && closingTag != "" {
|
if openingTag != "" && closingTag != "" {
|
||||||
capabilities = append(capabilities, model.CapabilityThinking)
|
capabilities = append(capabilities, model.CapabilityThinking)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ import (
|
|||||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||||
"github.com/ollama/ollama/server/internal/registry"
|
"github.com/ollama/ollama/server/internal/registry"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
|
"github.com/ollama/ollama/thinking"
|
||||||
"github.com/ollama/ollama/tools"
|
"github.com/ollama/ollama/tools"
|
||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
@@ -282,10 +283,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
prompt = b.String()
|
prompt = b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
var thinkingState *ThinkingParser
|
var thinkingState *thinking.Parser
|
||||||
openingTag, closingTag := inferThinkingTags(m.Template.Template)
|
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||||
if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
|
if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
|
||||||
thinkingState = &ThinkingParser{
|
thinkingState = &thinking.Parser{
|
||||||
OpeningTag: openingTag,
|
OpeningTag: openingTag,
|
||||||
ClosingTag: closingTag,
|
ClosingTag: closingTag,
|
||||||
}
|
}
|
||||||
@@ -1522,10 +1523,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var thinkingState *ThinkingParser
|
var thinkingState *thinking.Parser
|
||||||
openingTag, closingTag := inferThinkingTags(m.Template.Template)
|
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||||
if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
|
if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
|
||||||
thinkingState = &ThinkingParser{
|
thinkingState = &thinking.Parser{
|
||||||
OpeningTag: openingTag,
|
OpeningTag: openingTag,
|
||||||
ClosingTag: closingTag,
|
ClosingTag: closingTag,
|
||||||
}
|
}
|
||||||
@@ -1676,7 +1677,7 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
|
|||||||
// change the user output), we should probably perform this filtering
|
// change the user output), we should probably perform this filtering
|
||||||
// for all thinking models (not just qwen3 & deepseek-r1) since it tends
|
// for all thinking models (not just qwen3 & deepseek-r1) since it tends
|
||||||
// to save tokens and improve quality.
|
// to save tokens and improve quality.
|
||||||
thinkingState := &ThinkingParser{
|
thinkingState := &thinking.Parser{
|
||||||
OpeningTag: "<think>",
|
OpeningTag: "<think>",
|
||||||
ClosingTag: "</think>",
|
ClosingTag: "</think>",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
package server
|
package thinking
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
|
||||||
"text/template/parse"
|
|
||||||
"unicode"
|
"unicode"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -46,7 +44,7 @@ func (s thinkingState) String() string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type ThinkingParser struct {
|
type Parser struct {
|
||||||
state thinkingState
|
state thinkingState
|
||||||
OpeningTag string
|
OpeningTag string
|
||||||
ClosingTag string
|
ClosingTag string
|
||||||
@@ -56,7 +54,7 @@ type ThinkingParser struct {
|
|||||||
// AddContent returns the thinking content and the non-thinking content that
|
// AddContent returns the thinking content and the non-thinking content that
|
||||||
// should be immediately sent to the user. It will internally buffer if it needs
|
// should be immediately sent to the user. It will internally buffer if it needs
|
||||||
// to see more raw content to disambiguate
|
// to see more raw content to disambiguate
|
||||||
func (s *ThinkingParser) AddContent(content string) (string, string) {
|
func (s *Parser) AddContent(content string) (string, string) {
|
||||||
s.acc.WriteString(content)
|
s.acc.WriteString(content)
|
||||||
|
|
||||||
var thinkingSb, remainingSb strings.Builder
|
var thinkingSb, remainingSb strings.Builder
|
||||||
@@ -76,7 +74,7 @@ func (s *ThinkingParser) AddContent(content string) (string, string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// the additional bool return is true iff we should continue eating
|
// the additional bool return is true iff we should continue eating
|
||||||
func eat(s *ThinkingParser) (string, string, bool) {
|
func eat(s *Parser) (string, string, bool) {
|
||||||
switch s.state {
|
switch s.state {
|
||||||
case thinkingState_LookingForOpening:
|
case thinkingState_LookingForOpening:
|
||||||
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
|
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
|
||||||
@@ -171,130 +169,3 @@ func overlap(s, delim string) int {
|
|||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func templateVisit(n parse.Node, enterFn func(parse.Node) bool, exitFn func(parse.Node)) {
|
|
||||||
if n == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
shouldContinue := enterFn(n)
|
|
||||||
if !shouldContinue {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
switch x := n.(type) {
|
|
||||||
case *parse.ListNode:
|
|
||||||
for _, c := range x.Nodes {
|
|
||||||
templateVisit(c, enterFn, exitFn)
|
|
||||||
}
|
|
||||||
case *parse.BranchNode:
|
|
||||||
if x.Pipe != nil {
|
|
||||||
templateVisit(x.Pipe, enterFn, exitFn)
|
|
||||||
}
|
|
||||||
if x.List != nil {
|
|
||||||
templateVisit(x.List, enterFn, exitFn)
|
|
||||||
}
|
|
||||||
if x.ElseList != nil {
|
|
||||||
templateVisit(x.ElseList, enterFn, exitFn)
|
|
||||||
}
|
|
||||||
case *parse.ActionNode:
|
|
||||||
templateVisit(x.Pipe, enterFn, exitFn)
|
|
||||||
case *parse.WithNode:
|
|
||||||
templateVisit(&x.BranchNode, enterFn, exitFn)
|
|
||||||
case *parse.RangeNode:
|
|
||||||
templateVisit(&x.BranchNode, enterFn, exitFn)
|
|
||||||
case *parse.IfNode:
|
|
||||||
templateVisit(&x.BranchNode, enterFn, exitFn)
|
|
||||||
case *parse.TemplateNode:
|
|
||||||
templateVisit(x.Pipe, enterFn, exitFn)
|
|
||||||
case *parse.PipeNode:
|
|
||||||
for _, c := range x.Cmds {
|
|
||||||
templateVisit(c, enterFn, exitFn)
|
|
||||||
}
|
|
||||||
case *parse.CommandNode:
|
|
||||||
for _, a := range x.Args {
|
|
||||||
templateVisit(a, enterFn, exitFn)
|
|
||||||
}
|
|
||||||
// text, field, number, etc. are leaves – nothing to recurse into
|
|
||||||
}
|
|
||||||
if exitFn != nil {
|
|
||||||
exitFn(n)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// We use a heuristic to infer the tags that surround thinking traces:
|
|
||||||
// We look for a range node that iterates over "Messages" and then look for a
|
|
||||||
// reference to "Thinking" like `{{.Thinking}}`. We then go up to the nearest
|
|
||||||
// ListNode and take the first and last TextNodes as the opening and closing
|
|
||||||
// tags.
|
|
||||||
func inferThinkingTags(t *template.Template) (string, string) {
|
|
||||||
ancestors := []parse.Node{}
|
|
||||||
|
|
||||||
openingTag := ""
|
|
||||||
closingTag := ""
|
|
||||||
|
|
||||||
enterFn := func(n parse.Node) bool {
|
|
||||||
ancestors = append(ancestors, n)
|
|
||||||
|
|
||||||
switch x := n.(type) {
|
|
||||||
case *parse.FieldNode:
|
|
||||||
if len(x.Ident) > 0 && x.Ident[0] == "Thinking" {
|
|
||||||
var mostRecentRange *parse.RangeNode
|
|
||||||
for i := len(ancestors) - 1; i >= 0; i-- {
|
|
||||||
if r, ok := ancestors[i].(*parse.RangeNode); ok {
|
|
||||||
mostRecentRange = r
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if mostRecentRange == nil || !rangeUsesField(mostRecentRange, "Messages") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(drifkin): to be more robust, check that it's in the action
|
|
||||||
// part, not the `if`'s pipeline part. We do match on the nearest list
|
|
||||||
// that starts and ends with text nodes, which makes this not strictly
|
|
||||||
// necessary for our heuristic
|
|
||||||
|
|
||||||
// go up to the nearest ancestor that is a *parse.ListNode
|
|
||||||
for i := len(ancestors) - 1; i >= 0; i-- {
|
|
||||||
if l, ok := ancestors[i].(*parse.ListNode); ok {
|
|
||||||
firstNode := l.Nodes[0]
|
|
||||||
if t, ok := firstNode.(*parse.TextNode); ok {
|
|
||||||
openingTag = strings.TrimSpace(t.String())
|
|
||||||
}
|
|
||||||
lastNode := l.Nodes[len(l.Nodes)-1]
|
|
||||||
if t, ok := lastNode.(*parse.TextNode); ok {
|
|
||||||
closingTag = strings.TrimSpace(t.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
exitFn := func(n parse.Node) {
|
|
||||||
ancestors = ancestors[:len(ancestors)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
templateVisit(t.Root, enterFn, exitFn)
|
|
||||||
|
|
||||||
return openingTag, closingTag
|
|
||||||
}
|
|
||||||
|
|
||||||
// checks to see if the given field name is present in the pipeline of the given range node
|
|
||||||
func rangeUsesField(rangeNode *parse.RangeNode, field string) bool {
|
|
||||||
found := false
|
|
||||||
enterFn := func(n parse.Node) bool {
|
|
||||||
switch x := n.(type) {
|
|
||||||
case *parse.FieldNode:
|
|
||||||
if x.Ident[0] == field {
|
|
||||||
found = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
templateVisit(rangeNode.BranchNode.Pipe, enterFn, nil)
|
|
||||||
return found
|
|
||||||
}
|
|
||||||
@@ -1,8 +1,7 @@
|
|||||||
package server
|
package thinking
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
"text/template"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestExtractThinking(t *testing.T) {
|
func TestExtractThinking(t *testing.T) {
|
||||||
@@ -26,7 +25,7 @@ func TestExtractThinking(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
parser := ThinkingParser{
|
parser := Parser{
|
||||||
OpeningTag: "<think>",
|
OpeningTag: "<think>",
|
||||||
ClosingTag: "</think>",
|
ClosingTag: "</think>",
|
||||||
}
|
}
|
||||||
@@ -259,7 +258,7 @@ func TestThinkingStreaming(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, c := range cases {
|
for _, c := range cases {
|
||||||
parser := ThinkingParser{
|
parser := Parser{
|
||||||
OpeningTag: "<think>",
|
OpeningTag: "<think>",
|
||||||
ClosingTag: "</think>",
|
ClosingTag: "</think>",
|
||||||
}
|
}
|
||||||
@@ -277,127 +276,3 @@ func TestThinkingStreaming(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInferThinkingTags(t *testing.T) {
|
|
||||||
cases := []struct {
|
|
||||||
desc string
|
|
||||||
tmplString string
|
|
||||||
wantOpeningTag string
|
|
||||||
wantClosingTag string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "basic",
|
|
||||||
tmplString: `
|
|
||||||
{{ if .Thinking}}
|
|
||||||
/think
|
|
||||||
{{ end }}
|
|
||||||
{{- range $i, $_ := .Messages }}
|
|
||||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
|
||||||
{{ if and $last .Thinking }}
|
|
||||||
<think>{{ .Thinking }}</think>
|
|
||||||
{{ end }}
|
|
||||||
{{ end }}
|
|
||||||
`,
|
|
||||||
wantOpeningTag: "<think>",
|
|
||||||
wantClosingTag: "</think>",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "doubly nested range",
|
|
||||||
tmplString: `
|
|
||||||
{{ if .Thinking}}
|
|
||||||
/think
|
|
||||||
{{ end }}
|
|
||||||
{{- range $i, $_ := .Messages }}
|
|
||||||
{{- range $j, $_ := .NotMessages }}
|
|
||||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
|
||||||
{{ if and $last .Thinking }}
|
|
||||||
<think>{{ .Thinking }}</think>
|
|
||||||
{{ end }}
|
|
||||||
{{ end }}
|
|
||||||
{{ end }}
|
|
||||||
`,
|
|
||||||
wantOpeningTag: "",
|
|
||||||
wantClosingTag: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "whitespace is trimmed",
|
|
||||||
tmplString: `
|
|
||||||
{{ if .Thinking}}
|
|
||||||
/think
|
|
||||||
{{ end }}
|
|
||||||
{{- range $i, $_ := .Messages }}
|
|
||||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
|
||||||
{{ if and $last .Thinking }}
|
|
||||||
Some text before {{ .Thinking }} Some text after
|
|
||||||
{{ end }}
|
|
||||||
{{ end }}
|
|
||||||
`,
|
|
||||||
wantOpeningTag: "Some text before",
|
|
||||||
wantClosingTag: "Some text after",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "qwen3",
|
|
||||||
tmplString: `
|
|
||||||
{{- if or .System .Tools .Thinking }}<|im_start|>system
|
|
||||||
{{- if .System }}
|
|
||||||
{{ .System }}
|
|
||||||
{{- end }}
|
|
||||||
{{- if .Tools }}
|
|
||||||
|
|
||||||
# Tools
|
|
||||||
|
|
||||||
You may call one or more functions to assist with the user query.
|
|
||||||
|
|
||||||
You are provided with function signatures within <tools></tools> XML tags:
|
|
||||||
<tools>
|
|
||||||
{{- range .Tools }}
|
|
||||||
{"type": "function", "function": {{ .Function }}}
|
|
||||||
{{- end }}
|
|
||||||
</tools>
|
|
||||||
|
|
||||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
|
||||||
<tool_call>
|
|
||||||
{"name": <function-name>, "arguments": <args-json-object>}
|
|
||||||
</tool_call>
|
|
||||||
{{- end }}
|
|
||||||
{{- if .Thinking }}
|
|
||||||
/think
|
|
||||||
{{- else }}
|
|
||||||
/no_think
|
|
||||||
{{- end }}<|im_end|>
|
|
||||||
{{ end }}
|
|
||||||
{{- range $i, $_ := .Messages }}
|
|
||||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
|
||||||
{{- if eq .Role "user" }}<|im_start|>user
|
|
||||||
{{ .Content }}<|im_end|>
|
|
||||||
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
|
||||||
{{ if and $last .Thinking }}
|
|
||||||
<think>{{ .Thinking }}</think>
|
|
||||||
{{ end }}
|
|
||||||
{{ if .Content }}{{ .Content }}
|
|
||||||
{{- else if .ToolCalls }}<tool_call>
|
|
||||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
|
||||||
{{ end }}</tool_call>
|
|
||||||
{{- end }}{{ if not $last }}<|im_end|>
|
|
||||||
{{ end }}
|
|
||||||
{{- else if eq .Role "tool" }}<|im_start|>user
|
|
||||||
<tool_response>
|
|
||||||
{{ .Content }}
|
|
||||||
</tool_response><|im_end|>
|
|
||||||
{{ end }}
|
|
||||||
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
|
||||||
{{ end }}
|
|
||||||
{{- end }}
|
|
||||||
`,
|
|
||||||
wantOpeningTag: "<think>",
|
|
||||||
wantClosingTag: "</think>",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, c := range cases {
|
|
||||||
tmpl := template.Must(template.New("test").Parse(c.tmplString))
|
|
||||||
openingTag, closingTag := inferThinkingTags(tmpl)
|
|
||||||
if openingTag != c.wantOpeningTag || closingTag != c.wantClosingTag {
|
|
||||||
t.Errorf("case %q: got (%q,%q), want (%q,%q)", c.desc, openingTag, closingTag, c.wantOpeningTag, c.wantClosingTag)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
134
thinking/template.go
Normal file
134
thinking/template.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
package thinking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"text/template"
|
||||||
|
"text/template/parse"
|
||||||
|
)
|
||||||
|
|
||||||
|
func templateVisit(n parse.Node, enterFn func(parse.Node) bool, exitFn func(parse.Node)) {
|
||||||
|
if n == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
shouldContinue := enterFn(n)
|
||||||
|
if !shouldContinue {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch x := n.(type) {
|
||||||
|
case *parse.ListNode:
|
||||||
|
for _, c := range x.Nodes {
|
||||||
|
templateVisit(c, enterFn, exitFn)
|
||||||
|
}
|
||||||
|
case *parse.BranchNode:
|
||||||
|
if x.Pipe != nil {
|
||||||
|
templateVisit(x.Pipe, enterFn, exitFn)
|
||||||
|
}
|
||||||
|
if x.List != nil {
|
||||||
|
templateVisit(x.List, enterFn, exitFn)
|
||||||
|
}
|
||||||
|
if x.ElseList != nil {
|
||||||
|
templateVisit(x.ElseList, enterFn, exitFn)
|
||||||
|
}
|
||||||
|
case *parse.ActionNode:
|
||||||
|
templateVisit(x.Pipe, enterFn, exitFn)
|
||||||
|
case *parse.WithNode:
|
||||||
|
templateVisit(&x.BranchNode, enterFn, exitFn)
|
||||||
|
case *parse.RangeNode:
|
||||||
|
templateVisit(&x.BranchNode, enterFn, exitFn)
|
||||||
|
case *parse.IfNode:
|
||||||
|
templateVisit(&x.BranchNode, enterFn, exitFn)
|
||||||
|
case *parse.TemplateNode:
|
||||||
|
templateVisit(x.Pipe, enterFn, exitFn)
|
||||||
|
case *parse.PipeNode:
|
||||||
|
for _, c := range x.Cmds {
|
||||||
|
templateVisit(c, enterFn, exitFn)
|
||||||
|
}
|
||||||
|
case *parse.CommandNode:
|
||||||
|
for _, a := range x.Args {
|
||||||
|
templateVisit(a, enterFn, exitFn)
|
||||||
|
}
|
||||||
|
// text, field, number, etc. are leaves – nothing to recurse into
|
||||||
|
}
|
||||||
|
if exitFn != nil {
|
||||||
|
exitFn(n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InferTags uses a heuristic to infer the tags that surround thinking traces:
|
||||||
|
// We look for a range node that iterates over "Messages" and then look for a
|
||||||
|
// reference to "Thinking" like `{{.Thinking}}`. We then go up to the nearest
|
||||||
|
// ListNode and take the first and last TextNodes as the opening and closing
|
||||||
|
// tags.
|
||||||
|
func InferTags(t *template.Template) (string, string) {
|
||||||
|
ancestors := []parse.Node{}
|
||||||
|
|
||||||
|
openingTag := ""
|
||||||
|
closingTag := ""
|
||||||
|
|
||||||
|
enterFn := func(n parse.Node) bool {
|
||||||
|
ancestors = append(ancestors, n)
|
||||||
|
|
||||||
|
switch x := n.(type) {
|
||||||
|
case *parse.FieldNode:
|
||||||
|
if len(x.Ident) > 0 && x.Ident[0] == "Thinking" {
|
||||||
|
var mostRecentRange *parse.RangeNode
|
||||||
|
for i := len(ancestors) - 1; i >= 0; i-- {
|
||||||
|
if r, ok := ancestors[i].(*parse.RangeNode); ok {
|
||||||
|
mostRecentRange = r
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mostRecentRange == nil || !rangeUsesField(mostRecentRange, "Messages") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(drifkin): to be more robust, check that it's in the action
|
||||||
|
// part, not the `if`'s pipeline part. We do match on the nearest list
|
||||||
|
// that starts and ends with text nodes, which makes this not strictly
|
||||||
|
// necessary for our heuristic
|
||||||
|
|
||||||
|
// go up to the nearest ancestor that is a *parse.ListNode
|
||||||
|
for i := len(ancestors) - 1; i >= 0; i-- {
|
||||||
|
if l, ok := ancestors[i].(*parse.ListNode); ok {
|
||||||
|
firstNode := l.Nodes[0]
|
||||||
|
if t, ok := firstNode.(*parse.TextNode); ok {
|
||||||
|
openingTag = strings.TrimSpace(t.String())
|
||||||
|
}
|
||||||
|
lastNode := l.Nodes[len(l.Nodes)-1]
|
||||||
|
if t, ok := lastNode.(*parse.TextNode); ok {
|
||||||
|
closingTag = strings.TrimSpace(t.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
exitFn := func(n parse.Node) {
|
||||||
|
ancestors = ancestors[:len(ancestors)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
templateVisit(t.Root, enterFn, exitFn)
|
||||||
|
|
||||||
|
return openingTag, closingTag
|
||||||
|
}
|
||||||
|
|
||||||
|
// checks to see if the given field name is present in the pipeline of the given range node
|
||||||
|
func rangeUsesField(rangeNode *parse.RangeNode, field string) bool {
|
||||||
|
found := false
|
||||||
|
enterFn := func(n parse.Node) bool {
|
||||||
|
switch x := n.(type) {
|
||||||
|
case *parse.FieldNode:
|
||||||
|
if x.Ident[0] == field {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
templateVisit(rangeNode.BranchNode.Pipe, enterFn, nil)
|
||||||
|
return found
|
||||||
|
}
|
||||||
130
thinking/template_test.go
Normal file
130
thinking/template_test.go
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
package thinking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"text/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInferThinkingTags(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
desc string
|
||||||
|
tmplString string
|
||||||
|
wantOpeningTag string
|
||||||
|
wantClosingTag string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "basic",
|
||||||
|
tmplString: `
|
||||||
|
{{ if .Thinking}}
|
||||||
|
/think
|
||||||
|
{{ end }}
|
||||||
|
{{- range $i, $_ := .Messages }}
|
||||||
|
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||||
|
{{ if and $last .Thinking }}
|
||||||
|
<think>{{ .Thinking }}</think>
|
||||||
|
{{ end }}
|
||||||
|
{{ end }}
|
||||||
|
`,
|
||||||
|
wantOpeningTag: "<think>",
|
||||||
|
wantClosingTag: "</think>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "doubly nested range",
|
||||||
|
tmplString: `
|
||||||
|
{{ if .Thinking}}
|
||||||
|
/think
|
||||||
|
{{ end }}
|
||||||
|
{{- range $i, $_ := .Messages }}
|
||||||
|
{{- range $j, $_ := .NotMessages }}
|
||||||
|
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||||
|
{{ if and $last .Thinking }}
|
||||||
|
<think>{{ .Thinking }}</think>
|
||||||
|
{{ end }}
|
||||||
|
{{ end }}
|
||||||
|
{{ end }}
|
||||||
|
`,
|
||||||
|
wantOpeningTag: "",
|
||||||
|
wantClosingTag: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "whitespace is trimmed",
|
||||||
|
tmplString: `
|
||||||
|
{{ if .Thinking}}
|
||||||
|
/think
|
||||||
|
{{ end }}
|
||||||
|
{{- range $i, $_ := .Messages }}
|
||||||
|
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||||
|
{{ if and $last .Thinking }}
|
||||||
|
Some text before {{ .Thinking }} Some text after
|
||||||
|
{{ end }}
|
||||||
|
{{ end }}
|
||||||
|
`,
|
||||||
|
wantOpeningTag: "Some text before",
|
||||||
|
wantClosingTag: "Some text after",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "qwen3",
|
||||||
|
tmplString: `
|
||||||
|
{{- if or .System .Tools .Thinking }}<|im_start|>system
|
||||||
|
{{- if .System }}
|
||||||
|
{{ .System }}
|
||||||
|
{{- end }}
|
||||||
|
{{- if .Tools }}
|
||||||
|
|
||||||
|
# Tools
|
||||||
|
|
||||||
|
You may call one or more functions to assist with the user query.
|
||||||
|
|
||||||
|
You are provided with function signatures within <tools></tools> XML tags:
|
||||||
|
<tools>
|
||||||
|
{{- range .Tools }}
|
||||||
|
{"type": "function", "function": {{ .Function }}}
|
||||||
|
{{- end }}
|
||||||
|
</tools>
|
||||||
|
|
||||||
|
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||||
|
<tool_call>
|
||||||
|
{"name": <function-name>, "arguments": <args-json-object>}
|
||||||
|
</tool_call>
|
||||||
|
{{- end }}
|
||||||
|
{{- if .Thinking }}
|
||||||
|
/think
|
||||||
|
{{- else }}
|
||||||
|
/no_think
|
||||||
|
{{- end }}<|im_end|>
|
||||||
|
{{ end }}
|
||||||
|
{{- range $i, $_ := .Messages }}
|
||||||
|
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||||
|
{{- if eq .Role "user" }}<|im_start|>user
|
||||||
|
{{ .Content }}<|im_end|>
|
||||||
|
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
||||||
|
{{ if and $last .Thinking }}
|
||||||
|
<think>{{ .Thinking }}</think>
|
||||||
|
{{ end }}
|
||||||
|
{{ if .Content }}{{ .Content }}
|
||||||
|
{{- else if .ToolCalls }}<tool_call>
|
||||||
|
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||||
|
{{ end }}</tool_call>
|
||||||
|
{{- end }}{{ if not $last }}<|im_end|>
|
||||||
|
{{ end }}
|
||||||
|
{{- else if eq .Role "tool" }}<|im_start|>user
|
||||||
|
<tool_response>
|
||||||
|
{{ .Content }}
|
||||||
|
</tool_response><|im_end|>
|
||||||
|
{{ end }}
|
||||||
|
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
||||||
|
{{ end }}
|
||||||
|
{{- end }}
|
||||||
|
`,
|
||||||
|
wantOpeningTag: "<think>",
|
||||||
|
wantClosingTag: "</think>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
tmpl := template.Must(template.New("test").Parse(c.tmplString))
|
||||||
|
openingTag, closingTag := InferTags(tmpl)
|
||||||
|
if openingTag != c.wantOpeningTag || closingTag != c.wantClosingTag {
|
||||||
|
t.Errorf("case %q: got (%q,%q), want (%q,%q)", c.desc, openingTag, closingTag, c.wantOpeningTag, c.wantClosingTag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user