diff --git a/api/types.go b/api/types.go
index 602f93da88..94d492006a 100644
--- a/api/types.go
+++ b/api/types.go
@@ -83,6 +83,12 @@ type GenerateRequest struct {
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
 	Options map[string]any `json:"options"`
+
+	// Think controls whether thinking/reasoning models will think before
+	// responding. Needs to be a pointer so we can distinguish between false
+	// (request that thinking _not_ be used) and unset (use the old behavior
+	// before this option was introduced)
+	Think *bool `json:"think,omitempty"`
 }
 
 // ChatRequest describes a request sent by [Client.Chat].
@@ -108,6 +114,10 @@ type ChatRequest struct {
 
 	// Options lists model-specific options.
 	Options map[string]any `json:"options"`
+
+	// Think controls whether thinking/reasoning models will think before
+	// responding
+	Think *bool `json:"think,omitempty"`
 }
 
 type Tools []Tool
@@ -126,8 +136,11 @@ func (t Tool) String() string {
 // role ("system", "user", or "assistant"), the content and an optional list
 // of images.
 type Message struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
+	Role    string `json:"role"`
+	Content string `json:"content"`
+	// Thinking contains the text that was inside thinking tags in the
+	// original model output when ChatRequest.Think is enabled.
+	Thinking  string      `json:"thinking,omitempty"`
 	Images    []ImageData `json:"images,omitempty"`
 	ToolCalls []ToolCall  `json:"tool_calls,omitempty"`
 }
@@ -478,6 +491,10 @@ type GenerateResponse struct {
 	// Response is the textual response itself.
 	Response string `json:"response"`
 
+	// Thinking contains the text that was inside thinking tags in the
+	// original model output when GenerateRequest.Think is enabled.
+	Thinking string `json:"thinking,omitempty"`
+
 	// Done specifies if the response is complete.
 	Done bool `json:"done"`
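
As a quick illustration of the three states the pointer encodes, a self-contained sketch using only the standard library (the local struct is a stand-in for api.GenerateRequest, not the real type):

package main

import (
	"encoding/json"
	"fmt"
)

type generateRequest struct {
	// mirrors api.GenerateRequest.Think: a *bool so false and unset differ
	Think *bool `json:"think,omitempty"`
}

func main() {
	for _, in := range []string{`{}`, `{"think":true}`, `{"think":false}`} {
		var req generateRequest
		if err := json.Unmarshal([]byte(in), &req); err != nil {
			panic(err)
		}
		switch {
		case req.Think == nil:
			fmt.Println(in, "=> unset (fall back to pre-thinking behavior)")
		case *req.Think:
			fmt.Println(in, "=> thinking requested")
		default:
			fmt.Println(in, "=> thinking explicitly disabled")
		}
	}
}
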
diff --git a/api/types_test.go b/api/types_test.go
index 1a6fc811cc..9c2fb1f119 100644
--- a/api/types_test.go
+++ b/api/types_test.go
@@ -372,3 +372,50 @@ func TestPropertyType_MarshalJSON(t *testing.T) {
 		})
 	}
 }
+
+func TestThinking_UnmarshalJSON(t *testing.T) {
+	trueVal := true
+	falseVal := false
+
+	tests := []struct {
+		name             string
+		input            string
+		expectedThinking *bool
+		expectedError    bool
+	}{
+		{
+			name:             "true",
+			input:            `{ "think": true }`,
+			expectedThinking: &trueVal,
+		},
+		{
+			name:             "false",
+			input:            `{ "think": false }`,
+			expectedThinking: &falseVal,
+		},
+		{
+			name:             "unset",
+			input:            `{ }`,
+			expectedThinking: nil,
+		},
+		{
+			name:             "invalid",
+			input:            `{ "think": "true" }`,
+			expectedThinking: nil,
+			expectedError:    true,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			var req GenerateRequest
+			err := json.Unmarshal([]byte(test.input), &req)
+			if test.expectedError {
+				require.Error(t, err)
+			} else {
+				require.NoError(t, err)
+				assert.Equal(t, test.expectedThinking, req.Think)
+			}
+		})
+	}
+}
diff --git a/cmd/cmd.go b/cmd/cmd.go
index b9047529d4..2d16537906 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -39,6 +39,7 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
+	"github.com/ollama/ollama/readline"
 	"github.com/ollama/ollama/runner"
 	"github.com/ollama/ollama/server"
 	"github.com/ollama/ollama/types/model"
@@ -46,6 +47,23 @@ import (
 	"github.com/ollama/ollama/version"
 )
 
+// ensureThinkingSupport emits a warning if the model does not advertise thinking support
+func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
+	if name == "" {
+		return
+	}
+	resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
+	if err != nil {
+		return
+	}
+	for _, cap := range resp.Capabilities {
+		if cap == model.CapabilityThinking {
+			return
+		}
+	}
+	fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
+}
+
 var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
 
 func getModelfileName(cmd *cobra.Command) (string, error) {
@@ -265,6 +283,9 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
 	req := &api.GenerateRequest{
 		Model:     opts.Model,
 		KeepAlive: opts.KeepAlive,
+
+		// pass Think here so we fail before getting to the chat prompt if the model doesn't support it
+		Think: opts.Think,
 	}
 
 	return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
@@ -299,6 +320,22 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	}
 	opts.Format = format
 
+	thinkFlag := cmd.Flags().Lookup("think")
+	if thinkFlag.Changed {
+		think, err := cmd.Flags().GetBool("think")
+		if err != nil {
+			return err
+		}
+		opts.Think = &think
+	} else {
+		opts.Think = nil
+	}
+	hidethinking, err := cmd.Flags().GetBool("hidethinking")
+	if err != nil {
+		return err
+	}
+	opts.HideThinking = hidethinking
+
 	keepAlive, err := cmd.Flags().GetString("keepalive")
 	if err != nil {
 		return err
@@ -362,6 +399,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
+	opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
+	if err != nil {
+		return err
+	}
+
 	opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
 
 	// TODO: remove the projector info and vision info checks below,
@@ -923,17 +965,19 @@ func PullHandler(cmd *cobra.Command, args []string) error {
 
 type generateContextKey string
 
 type runOptions struct {
-	Model       string
-	ParentModel string
-	Prompt      string
-	Messages    []api.Message
-	WordWrap    bool
-	Format      string
-	System      string
-	Images      []api.ImageData
-	Options     map[string]any
-	MultiModal  bool
-	KeepAlive   *api.Duration
+	Model        string
+	ParentModel  string
+	Prompt       string
+	Messages     []api.Message
+	WordWrap     bool
+	Format       string
+	System       string
+	Images       []api.ImageData
+	Options      map[string]any
+	MultiModal   bool
+	KeepAlive    *api.Duration
+	Think        *bool
+	HideThinking bool
 }
 
 type displayResponseState struct {
@@ -989,6 +1033,26 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
 	}
 }
 
+func thinkingOutputOpeningText(plainText bool) string {
+	text := "Thinking...\n"
+
+	if plainText {
+		return text
+	}
+
+	return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault + readline.ColorGrey
+}
+
+func thinkingOutputClosingText(plainText bool) string {
+	text := "...done thinking.\n\n"
+
+	if plainText {
+		return text
+	}
+
+	return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault
+}
+
 func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
@@ -1016,14 +1080,34 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 	var latest api.ChatResponse
 	var fullResponse strings.Builder
 	var role string
+	var thinkTagOpened bool = false
+	var thinkTagClosed bool = false
 
 	fn := func(response api.ChatResponse) error {
-		p.StopAndClear()
+		if response.Message.Content != "" || !opts.HideThinking {
+			p.StopAndClear()
+		}
 
 		latest = response
 		role = response.Message.Role
 
+		if response.Message.Thinking != "" && !opts.HideThinking {
+			if !thinkTagOpened {
+				fmt.Print(thinkingOutputOpeningText(false))
+				thinkTagOpened = true
+			}
+			displayResponse(response.Message.Thinking, opts.WordWrap, state)
+		}
+
 		content := response.Message.Content
+		if thinkTagOpened && !thinkTagClosed && content != "" {
+			fmt.Print(thinkingOutputClosingText(false))
+			thinkTagClosed = true
+		}
+		// purposefully not putting thinking blocks in the response, which would
+		// only be needed if we later added tool calling to the cli (they get
+		// filtered out anyway since current models don't expect them unless you're
+		// about to finish some tool calls)
 		fullResponse.WriteString(content)
 
 		displayResponse(content, opts.WordWrap, state)
@@ -1040,6 +1124,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 		Messages: opts.Messages,
 		Format:   json.RawMessage(opts.Format),
 		Options:  opts.Options,
+		Think:    opts.Think,
 	}
 
 	if opts.KeepAlive != nil {
@@ -1101,13 +1186,32 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 	}()
 
 	var state *displayResponseState = &displayResponseState{}
+	var thinkTagOpened bool = false
+	var thinkTagClosed bool = false
+
+	plainText := !term.IsTerminal(int(os.Stdout.Fd()))
 
 	fn := func(response api.GenerateResponse) error {
-		p.StopAndClear()
-
 		latest = response
 		content := response.Response
 
+		if response.Response != "" || !opts.HideThinking {
+			p.StopAndClear()
+		}
+
+		if response.Thinking != "" && !opts.HideThinking {
+			if !thinkTagOpened {
+				fmt.Print(thinkingOutputOpeningText(plainText))
+				thinkTagOpened = true
+			}
+			displayResponse(response.Thinking, opts.WordWrap, state)
+		}
+
+		if thinkTagOpened && !thinkTagClosed && content != "" {
+			fmt.Print(thinkingOutputClosingText(plainText))
+			thinkTagClosed = true
+		}
+
 		displayResponse(content, opts.WordWrap, state)
 
 		return nil
@@ -1133,6 +1237,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 		System:    opts.System,
 		Options:   opts.Options,
 		KeepAlive: opts.KeepAlive,
+		Think:     opts.Think,
 	}
 
 	if err := client.Generate(ctx, &request, fn); err != nil {
@@ -1348,6 +1453,8 @@ func NewCLI() *cobra.Command {
 	runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
 	runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
 	runCmd.Flags().String("format", "", "Response format (e.g. json)")
+	runCmd.Flags().Bool("think", false, "Whether to use thinking mode for supported models")
+	runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
 
 	stopCmd := &cobra.Command{
 		Use:   "stop MODEL",
@@ -1399,7 +1506,6 @@ func NewCLI() *cobra.Command {
 		PreRunE: checkServerHeartbeat,
 		RunE:    ListRunningHandler,
 	}
-
 	copyCmd := &cobra.Command{
 		Use:     "cp SOURCE DESTINATION",
 		Short:   "Copy a model",
@@ -1488,3 +1594,45 @@ func NewCLI() *cobra.Command {
 
 	return rootCmd
 }
+
+// If the user has explicitly set thinking options, either through the CLI or
+// through the `/set think` or `/set nothink` interactive options, then we
+// respect them. Otherwise, we check model capabilities to see if the model
+// supports thinking. If the model does support thinking, we enable it.
+// Otherwise, we unset the thinking option (which is different than setting it
+// to false).
+//
+// If capabilities are not provided, we fetch them from the server.
+func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicitlySetByUser bool) (*bool, error) {
+	if explicitlySetByUser {
+		return runOpts.Think, nil
+	}
+
+	if caps == nil {
+		client, err := api.ClientFromEnvironment()
+		if err != nil {
+			return nil, err
+		}
+		ret, err := client.Show(context.Background(), &api.ShowRequest{
+			Model: runOpts.Model,
+		})
+		if err != nil {
+			return nil, err
+		}
+		caps = &ret.Capabilities
+	}
+
+	thinkingSupported := false
+	for _, cap := range *caps {
+		if cap == model.CapabilityThinking {
+			thinkingSupported = true
+		}
+	}
+
+	if thinkingSupported {
+		thinking := true
+		return &thinking, nil
+	}
+
+	return nil, nil
+}
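
For reference, a minimal sketch of how a client of this API might consume the new fields end to end (it assumes a running local server and a thinking-capable model already pulled; the model name is just an example):

package main

import (
	"context"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		panic(err)
	}
	think := true // request thinking; leave Think nil to get the old behavior
	req := &api.ChatRequest{
		Model:    "deepseek-r1",
		Messages: []api.Message{{Role: "user", Content: "Why is the sky blue?"}},
		Think:    &think,
	}
	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		// thinking and content arrive as separate fields on each streamed chunk
		if resp.Message.Thinking != "" {
			fmt.Print(resp.Message.Thinking)
		}
		if resp.Message.Content != "" {
			fmt.Print(resp.Message.Content)
		}
		return nil
	})
	if err != nil {
		panic(err)
	}
}
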
'quiet' mode.") + case "think": + think := true + opts.Think = &think + thinkExplicitlySet = true + if client, err := api.ClientFromEnvironment(); err == nil { + ensureThinkingSupport(cmd.Context(), client, opts.Model) + } + fmt.Println("Set 'think' mode.") + case "nothink": + think := false + opts.Think = &think + thinkExplicitlySet = true + if client, err := api.ClientFromEnvironment(); err == nil { + ensureThinkingSupport(cmd.Context(), client, opts.Model) + } + fmt.Println("Set 'nothink' mode.") case "format": if len(args) < 3 || args[2] != "json" { fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'") @@ -448,6 +475,11 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { assistant, err := chat(cmd, opts) if err != nil { + if strings.Contains(err.Error(), "does not support thinking") { + fmt.Printf("error: %v\n", err) + sb.Reset() + continue + } return err } if assistant != nil { diff --git a/cmd/warn_thinking_test.go b/cmd/warn_thinking_test.go new file mode 100644 index 0000000000..31dc4156ba --- /dev/null +++ b/cmd/warn_thinking_test.go @@ -0,0 +1,63 @@ +package cmd + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/types/model" +) + +// Test that a warning is printed when thinking is requested but not supported. +func TestWarnMissingThinking(t *testing.T) { + cases := []struct { + capabilities []model.Capability + expectWarn bool + }{ + {capabilities: []model.Capability{model.CapabilityThinking}, expectWarn: false}, + {capabilities: []model.Capability{}, expectWarn: true}, + } + + for _, tc := range cases { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/show" || r.Method != http.MethodPost { + t.Fatalf("unexpected request to %s %s", r.URL.Path, r.Method) + } + var req api.ShowRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + resp := api.ShowResponse{Capabilities: tc.capabilities} + if err := json.NewEncoder(w).Encode(resp); err != nil { + t.Fatalf("encode response: %v", err) + } + })) + defer srv.Close() + + t.Setenv("OLLAMA_HOST", srv.URL) + client, err := api.ClientFromEnvironment() + if err != nil { + t.Fatal(err) + } + oldStderr := os.Stderr + r, w, _ := os.Pipe() + os.Stderr = w + ensureThinkingSupport(t.Context(), client, "m") + w.Close() + os.Stderr = oldStderr + out, _ := io.ReadAll(r) + + warned := strings.Contains(string(out), "warning:") + if tc.expectWarn && !warned { + t.Errorf("expected warning, got none") + } + if !tc.expectWarn && warned { + t.Errorf("did not expect warning, got: %s", string(out)) + } + } +} diff --git a/docs/api.md b/docs/api.md index abd276150c..11eaf73ab8 100644 --- a/docs/api.md +++ b/docs/api.md @@ -43,6 +43,7 @@ Generate a response for a given prompt with a provided model. This is a streamin - `prompt`: the prompt to generate a response for - `suffix`: the text after the model response - `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`) +- `think`: (for thinking models) should the model think before responding? Advanced parameters (optional): @@ -490,11 +491,13 @@ Generate the next message in a chat with a provided model. 
This is a streaming e - `model`: (required) the [model name](#model-names) - `messages`: the messages of the chat, this can be used to keep a chat memory - `tools`: list of tools in JSON for the model to use if supported +- `think`: (for thinking models) should the model think before responding? The `message` object has the following fields: - `role`: the role of the message, either `system`, `user`, `assistant`, or `tool` - `content`: the content of the message +- `thinking`: (for thinking models) the model's thinking process - `images` (optional): a list of images to include in the message (for multimodal models such as `llava`) - `tool_calls` (optional): a list of tools in JSON that the model wants to use diff --git a/model/bytepairencoding.go b/model/bytepairencoding.go index 6bb9a003eb..246d2ba3e8 100644 --- a/model/bytepairencoding.go +++ b/model/bytepairencoding.go @@ -3,6 +3,7 @@ package model import ( "cmp" "context" + "fmt" "iter" "log/slog" "strings" @@ -210,6 +211,14 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { return ids, nil } +type lazyIdsString struct { + ids []int32 +} + +func (l lazyIdsString) LogValue() slog.Value { + return slog.AnyValue(fmt.Sprint(l.ids)) +} + func (bpe BytePairEncoding) Decode(ids []int32) (string, error) { var sb strings.Builder for _, id := range ids { @@ -234,6 +243,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) { } } - slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String()) + slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String(), "from", lazyIdsString{ids: ids}) return sb.String(), nil } diff --git a/readline/types.go b/readline/types.go index e136d9962b..f4efa8d92c 100644 --- a/readline/types.go +++ b/readline/types.go @@ -61,6 +61,8 @@ const ( ColorGrey = Esc + "[38;5;245m" ColorDefault = Esc + "[0m" + ColorBold = Esc + "[1m" + StartBracketedPaste = Esc + "[?2004h" EndBracketedPaste = Esc + "[?2004l" ) diff --git a/server/images.go b/server/images.go index a69e2a9f25..58fb87dccb 100644 --- a/server/images.go +++ b/server/images.go @@ -37,6 +37,7 @@ var ( errCapabilityInsert = errors.New("insert") errCapabilityVision = errors.New("vision") errCapabilityEmbedding = errors.New("embedding") + errCapabilityThinking = errors.New("thinking") errInsecureProtocol = errors.New("insecure protocol http") ) @@ -111,6 +112,12 @@ func (m *Model) Capabilities() []model.Capability { capabilities = append(capabilities, model.CapabilityVision) } + // Check for thinking capability + openingTag, closingTag := inferThinkingTags(m.Template.Template) + if openingTag != "" && closingTag != "" { + capabilities = append(capabilities, model.CapabilityThinking) + } + return capabilities } @@ -127,6 +134,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error { model.CapabilityInsert: errCapabilityInsert, model.CapabilityVision: errCapabilityVision, model.CapabilityEmbedding: errCapabilityEmbedding, + model.CapabilityThinking: errCapabilityThinking, } for _, cap := range want { @@ -141,11 +149,19 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error { } } + var err error if len(errs) > 0 { - return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...)) + err = fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...)) } - return nil + if slices.Contains(errs, errCapabilityThinking) { + if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" { + // append a message to the existing error + return 
fmt.Errorf("%w. Pull the model again to get the latest version with full thinking support", err) + } + } + + return err } func (m *Model) String() string { diff --git a/server/prompt.go b/server/prompt.go index 147a02b69c..f8c895d71c 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -19,7 +19,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error) // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn. // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the // latest message and 2) system messages -func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) { +func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *bool) (prompt string, images []llm.ImageData, _ error) { var system []api.Message // TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent @@ -41,8 +41,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. } } + thinkVal := false + if think != nil { + thinkVal = *think + } var b bytes.Buffer - if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil { + if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil { return "", nil, err } @@ -96,7 +100,11 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. // truncate any messages that do not fit into the context window var b bytes.Buffer - if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil { + thinkVal := false + if think != nil { + thinkVal = *think + } + if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil { return "", nil, err } diff --git a/server/prompt_test.go b/server/prompt_test.go index fb6c96c0ce..0043b9a479 100644 --- a/server/prompt_test.go +++ b/server/prompt_test.go @@ -208,7 +208,8 @@ func TestChatPrompt(t *testing.T) { t.Run(tt.name, func(t *testing.T) { model := tt.model opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}} - prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil) + think := false + prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think) if tt.error == nil && err != nil { t.Fatal(err) } else if tt.error != nil && err != tt.error { diff --git a/server/routes.go b/server/routes.go index 42e8cdd1de..236f92e225 100644 --- a/server/routes.go +++ b/server/routes.go @@ -17,7 +17,6 @@ import ( "net/netip" "os" "os/signal" - "regexp" "slices" "strings" "syscall" @@ -186,6 +185,13 @@ func (s *Server) GenerateHandler(c *gin.Context) { if req.Suffix != "" { caps = append(caps, model.CapabilityInsert) } + if req.Think != nil && *req.Think { + caps = append(caps, model.CapabilityThinking) + // TODO(drifkin): consider adding a warning if it's false and the model + // doesn't support thinking. 
diff --git a/readline/types.go b/readline/types.go
index e136d9962b..f4efa8d92c 100644
--- a/readline/types.go
+++ b/readline/types.go
@@ -61,6 +61,8 @@ const (
 	ColorGrey    = Esc + "[38;5;245m"
 	ColorDefault = Esc + "[0m"
 
+	ColorBold = Esc + "[1m"
+
 	StartBracketedPaste = Esc + "[?2004h"
 	EndBracketedPaste   = Esc + "[?2004l"
 )
diff --git a/server/images.go b/server/images.go
index a69e2a9f25..58fb87dccb 100644
--- a/server/images.go
+++ b/server/images.go
@@ -37,6 +37,7 @@ var (
 	errCapabilityInsert    = errors.New("insert")
 	errCapabilityVision    = errors.New("vision")
 	errCapabilityEmbedding = errors.New("embedding")
+	errCapabilityThinking  = errors.New("thinking")
 	errInsecureProtocol    = errors.New("insecure protocol http")
 )
 
@@ -111,6 +112,12 @@ func (m *Model) Capabilities() []model.Capability {
 		capabilities = append(capabilities, model.CapabilityVision)
 	}
 
+	// Check for thinking capability
+	openingTag, closingTag := inferThinkingTags(m.Template.Template)
+	if openingTag != "" && closingTag != "" {
+		capabilities = append(capabilities, model.CapabilityThinking)
+	}
+
 	return capabilities
 }
 
@@ -127,6 +134,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
 		model.CapabilityInsert:    errCapabilityInsert,
 		model.CapabilityVision:    errCapabilityVision,
 		model.CapabilityEmbedding: errCapabilityEmbedding,
+		model.CapabilityThinking:  errCapabilityThinking,
 	}
 
 	for _, cap := range want {
@@ -141,11 +149,19 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
 		}
 	}
 
+	var err error
 	if len(errs) > 0 {
-		return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
+		err = fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
 	}
 
-	return nil
+	if slices.Contains(errs, errCapabilityThinking) {
+		if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
+			// append a message to the existing error
+			return fmt.Errorf("%w. Pull the model again to get the latest version with full thinking support", err)
+		}
+	}
+
+	return err
 }
 
 func (m *Model) String() string {
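
For intuition about the detection above: a minimal, hypothetical chat template like the one below (not any real model's template) would cause Capabilities to report thinking, because a <think>...</think> text pair surrounds {{ .Thinking }} inside a range over Messages — exactly what inferThinkingTags (defined later in server/thinking.go) looks for:

package main

import (
	"fmt"
	"text/template"
)

func main() {
	tmpl := template.Must(template.New("chat").Parse(
		`{{- range $i, $_ := .Messages }}{{ if .Thinking }}<think>{{ .Thinking }}</think>{{ end }}{{ .Content }}{{ end }}`))
	// inferThinkingTags(tmpl) would return "<think>", "</think>" for this template
	fmt.Println(tmpl.Name())
}
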
diff --git a/server/prompt.go b/server/prompt.go
index 147a02b69c..f8c895d71c 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -19,7 +19,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
-func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *bool) (prompt string, images []llm.ImageData, _ error) {
 	var system []api.Message
 
 	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@@ -41,8 +41,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 			}
 		}
 
+		thinkVal := false
+		if think != nil {
+			thinkVal = *think
+		}
 		var b bytes.Buffer
-		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
+		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
 			return "", nil, err
 		}
 
@@ -96,7 +100,11 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 
 	// truncate any messages that do not fit into the context window
 	var b bytes.Buffer
-	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
+	thinkVal := false
+	if think != nil {
+		thinkVal = *think
+	}
+	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
 		return "", nil, err
 	}
diff --git a/server/prompt_test.go b/server/prompt_test.go
index fb6c96c0ce..0043b9a479 100644
--- a/server/prompt_test.go
+++ b/server/prompt_test.go
@@ -208,7 +208,8 @@ func TestChatPrompt(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			model := tt.model
 			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
-			prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
+			think := false
+			prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think)
 			if tt.error == nil && err != nil {
 				t.Fatal(err)
 			} else if tt.error != nil && err != tt.error {
diff --git a/server/routes.go b/server/routes.go
index 42e8cdd1de..236f92e225 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -17,7 +17,6 @@ import (
 	"net/netip"
 	"os"
 	"os/signal"
-	"regexp"
 	"slices"
 	"strings"
 	"syscall"
@@ -186,6 +185,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	if req.Suffix != "" {
 		caps = append(caps, model.CapabilityInsert)
 	}
+	if req.Think != nil && *req.Think {
+		caps = append(caps, model.CapabilityThinking)
+		// TODO(drifkin): consider adding a warning if it's false and the model
+		// doesn't support thinking. It's not strictly required, but it can be a
+		// hint that the user is on an older qwen3/r1 model that doesn't have an
+		// updated template supporting thinking
+	}
 
 	r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
 	if errors.Is(err, errCapabilityCompletion) {
@@ -254,6 +260,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
 		}
 
+		values.Think = req.Think != nil && *req.Think
+		values.IsThinkSet = req.Think != nil
+
 		var b bytes.Buffer
 		if req.Context != nil {
 			slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
@@ -273,6 +282,15 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		prompt = b.String()
 	}
 
+	var thinkingState *thinkingParser
+	openingTag, closingTag := inferThinkingTags(m.Template.Template)
+	if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
+		thinkingState = &thinkingParser{
+			openingTag: openingTag,
+			closingTag: closingTag,
+		}
+	}
+
 	ch := make(chan any)
 	go func() {
 		// TODO (jmorganca): avoid building the response twice both here and below
@@ -297,6 +315,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 				},
 			}
 
+			if thinkingState != nil {
+				thinking, content := thinkingState.addContent(cr.Content)
+				res.Thinking = thinking
+				res.Response = content
+			}
+
 			if _, err := sb.WriteString(cr.Content); err != nil {
 				ch <- gin.H{"error": err.Error()}
 			}
@@ -324,11 +348,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 	if req.Stream != nil && !*req.Stream {
 		var r api.GenerateResponse
-		var sb strings.Builder
+		var sbThinking strings.Builder
+		var sbContent strings.Builder
 		for rr := range ch {
 			switch t := rr.(type) {
 			case api.GenerateResponse:
-				sb.WriteString(t.Response)
+				sbThinking.WriteString(t.Thinking)
+				sbContent.WriteString(t.Response)
 				r = t
 			case gin.H:
 				msg, ok := t["error"].(string)
@@ -344,7 +370,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			}
 		}
 
-		r.Response = sb.String()
+		r.Thinking = sbThinking.String()
+		r.Response = sbContent.String()
+
 		c.JSON(http.StatusOK, r)
 		return
 	}
@@ -1436,6 +1464,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	if len(req.Tools) > 0 {
 		caps = append(caps, model.CapabilityTools)
 	}
+	if req.Think != nil && *req.Think {
+		caps = append(caps, model.CapabilityThinking)
+	}
 
 	name := model.ParseName(req.Model)
 	if !name.IsValid() {
@@ -1476,13 +1507,22 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 
 	msgs = filterThinkTags(msgs, m)
 
-	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Think)
 	if err != nil {
 		slog.Error("chat prompt error", "error", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
 
+	var thinkingState *thinkingParser
+	openingTag, closingTag := inferThinkingTags(m.Template.Template)
+	if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
+		thinkingState = &thinkingParser{
+			openingTag: openingTag,
+			closingTag: closingTag,
+		}
+	}
+
 	var toolParser *tools.Parser
 	if len(req.Tools) > 0 {
 		toolParser, err = tools.NewParser(m.Template.Template)
@@ -1516,6 +1556,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				},
 			}
 
+			if thinkingState != nil {
+				thinkingContent, remainingContent := thinkingState.addContent(res.Message.Content)
+				if thinkingContent == "" && remainingContent == "" && !r.Done {
+					// need to accumulate more to decide what to send
+					return
+				}
+				res.Message.Content = remainingContent
+				res.Message.Thinking = thinkingContent
+			}
+
 			if r.Done {
 				res.DoneReason = r.DoneReason.String()
 				res.TotalDuration = time.Since(checkpointStart)
@@ -1523,12 +1573,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			}
 
 			if len(req.Tools) > 0 {
-				toolCalls, content := toolParser.Add(r.Content)
+				toolCalls, content := toolParser.Add(res.Message.Content)
 				if len(content) > 0 {
 					res.Message.Content = content
 				} else if len(toolCalls) > 0 {
 					res.Message.ToolCalls = toolCalls
 					res.Message.Content = ""
+				} else if res.Message.Thinking != "" {
+					// don't return
 				} else {
 					if r.Done {
 						ch <- res
@@ -1536,6 +1588,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 					return
 				}
 			}
+
 			ch <- res
 		}); err != nil {
 			ch <- gin.H{"error": err.Error()}
@@ -1544,12 +1597,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
 
 	if req.Stream != nil && !*req.Stream {
 		var resp api.ChatResponse
-		var sb strings.Builder
 		var toolCalls []api.ToolCall
+		var sbThinking strings.Builder
+		var sbContent strings.Builder
 		for rr := range ch {
 			switch t := rr.(type) {
 			case api.ChatResponse:
-				sb.WriteString(t.Message.Content)
+				sbThinking.WriteString(t.Message.Thinking)
+				sbContent.WriteString(t.Message.Content)
 				resp = t
 				if len(req.Tools) > 0 {
 					toolCalls = append(toolCalls, t.Message.ToolCalls...)
@@ -1568,7 +1623,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			}
 		}
 
-		resp.Message.Content = sbContent.String()
+		resp.Message.Content = sbContent.String()
+		resp.Message.Thinking = sbThinking.String()
+
 		if len(toolCalls) > 0 {
 			resp.Message.ToolCalls = toolCalls
 		}
@@ -1595,8 +1652,6 @@ func handleScheduleError(c *gin.Context, name string, err error) {
 	}
 }
 
-var thinkTagRegexp = regexp.MustCompile(`(?s)<think>.*?</think>(\n)*`)
-
 func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
 	if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
 		finalUserIndex := -1
@@ -1608,7 +1663,17 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
 
 		for i, msg := range msgs {
 			if msg.Role == "assistant" && i < finalUserIndex {
-				msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "")
+				// TODO(drifkin): this is from before we added proper thinking support.
+				// However, even if thinking is not enabled (and therefore we shouldn't
+				// change the user output), we should probably perform this filtering
+				// for all thinking models (not just qwen3 & deepseek-r1) since it tends
+				// to save tokens and improve quality.
+				thinkingState := &thinkingParser{
+					openingTag: "<think>",
+					closingTag: "</think>",
+				}
+				_, content := thinkingState.addContent(msg.Content)
+				msgs[i].Content = content
 			}
 		}
 	}
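
With this wiring, a streamed chat with think enabled emits chunks whose message carries thinking-only deltas while the model is inside the thinking tags, then content-only deltas afterwards — roughly like this (illustrative, abridged JSON; real chunks also include model and created_at):

{"message":{"role":"assistant","thinking":"The user wants a count...","content":""},"done":false}
{"message":{"role":"assistant","thinking":"","content":"The answer"},"done":false}
{"message":{"role":"assistant","thinking":"","content":" is 3."},"done":true,"done_reason":"stop"}
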
diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go
index 6bbf5b1121..75a246fc61 100644
--- a/server/routes_generate_test.go
+++ b/server/routes_generate_test.go
@@ -143,6 +143,25 @@ func TestGenerateChat(t *testing.T) {
 		}
 	})
 
+	t.Run("missing thinking capability", func(t *testing.T) {
+		think := true
+		w := createRequest(t, s.ChatHandler, api.ChatRequest{
+			Model: "test",
+			Messages: []api.Message{
+				{Role: "user", Content: "Hello!"},
+			},
+			Think: &think,
+		})
+
+		if w.Code != http.StatusBadRequest {
+			t.Errorf("expected status 400, got %d", w.Code)
+		}
+
+		if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support thinking"}`); diff != "" {
+			t.Errorf("mismatch (-got +want):\n%s", diff)
+		}
+	})
+
 	t.Run("missing model", func(t *testing.T) {
 		w := createRequest(t, s.ChatHandler, api.ChatRequest{})
 		if w.Code != http.StatusBadRequest {
diff --git a/server/thinking.go b/server/thinking.go
new file mode 100644
index 0000000000..2213b6b6e9
--- /dev/null
+++ b/server/thinking.go
@@ -0,0 +1,300 @@
+package server
+
+import (
+	"strings"
+	"text/template"
+	"text/template/parse"
+	"unicode"
+)
+
+type thinkingState int
+
+const (
+	// We're looking for the opening tag, but we haven't seen any non-whitespace
+	// characters yet
+	thinkingState_LookingForOpening thinkingState = iota
+	// We've seen the opening tag, but we haven't seen any non-whitespace
+	// characters yet (we want to eat any whitespace between the opening tag and
+	// the thinking content)
+	thinkingState_ThinkingStartedEatingWhitespace
+	// We've seen non-whitespace characters after the opening tag, but we haven't
+	// seen the closing tag yet
+	thinkingState_Thinking
+	// We've seen the closing tag, but we haven't seen any non-whitespace
+	// characters after the closing tag yet (we want to eat any whitespace between
+	// the closing tag and the content)
+	thinkingState_ThinkingDoneEatingWhitespace
+	// We've seen the closing tag and seen at least one non-whitespace character
+	// after it
+	thinkingState_ThinkingDone
+)
+
+func (s thinkingState) String() string {
+	switch s {
+	case thinkingState_LookingForOpening:
+		return "LookingForOpening"
+	case thinkingState_ThinkingStartedEatingWhitespace:
+		return "ThinkingStartedEatingWhitespace"
+	case thinkingState_Thinking:
+		return "Thinking"
+	case thinkingState_ThinkingDoneEatingWhitespace:
+		return "ThinkingDoneEatingWhitespace"
+	case thinkingState_ThinkingDone:
+		return "ThinkingDone"
+	default:
+		return "Unknown"
+	}
+}
+
+type thinkingParser struct {
+	state      thinkingState
+	openingTag string
+	closingTag string
+	acc        strings.Builder
+}
+
+// addContent returns the thinking content and the non-thinking content that
+// should be immediately sent to the user. It will internally buffer if it needs
+// to see more raw content to disambiguate
+func (s *thinkingParser) addContent(content string) (string, string) {
+	s.acc.WriteString(content)
+
+	var thinkingSb, remainingSb strings.Builder
+
+	var thinking, remaining string
+	keepLooping := true
+	// we loop because we might pass through multiple parsing states in a single
+	// call to addContent, and we want to make sure callers don't have to wait for
+	// data that's already unambiguous
+	for keepLooping {
+		thinking, remaining, keepLooping = eat(s)
+		thinkingSb.WriteString(thinking)
+		remainingSb.WriteString(remaining)
+	}
+
+	return thinkingSb.String(), remainingSb.String()
+}
+
+// the additional bool return is true iff we should continue eating
+func eat(s *thinkingParser) (string, string, bool) {
+	switch s.state {
+	case thinkingState_LookingForOpening:
+		trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
+		if strings.HasPrefix(trimmed, s.openingTag) {
+			after := strings.Join(strings.Split(trimmed, s.openingTag)[1:], s.openingTag)
+			after = strings.TrimLeftFunc(after, unicode.IsSpace)
+			// after might contain more than just thinking tokens, so we continue
+			// parsing instead of returning it as thinking tokens here
+			s.acc.Reset()
+			s.acc.WriteString(after)
+			if after == "" {
+				s.state = thinkingState_ThinkingStartedEatingWhitespace
+			} else {
+				s.state = thinkingState_Thinking
+			}
+			return "", "", true
+		} else if strings.HasPrefix(s.openingTag, trimmed) {
+			// partial opening seen, so let's keep accumulating
+			return "", "", false
+		} else if trimmed == "" {
+			// saw whitespace only, so let's keep accumulating
+			return "", "", false
+		} else {
+			// didn't see an opening tag, but we have content, so thinking was skipped
+			s.state = thinkingState_ThinkingDone
+			// note that we use the original content, not the trimmed one because we
+			// don't want to eat any whitespace in the real content if there were no
+			// thinking tags
+			return "", s.acc.String(), false
+		}
+	case thinkingState_ThinkingStartedEatingWhitespace:
+		trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
+		s.acc.Reset()
+		if trimmed == "" {
+			return "", "", false
+		} else {
+			s.state = thinkingState_Thinking
+			s.acc.WriteString(trimmed)
+			return "", "", true
+		}
+	case thinkingState_Thinking:
+		acc := s.acc.String()
+		if strings.Contains(acc, s.closingTag) {
+			split := strings.Split(acc, s.closingTag)
+			thinking := split[0]
+			remaining := strings.Join(split[1:], s.closingTag)
+			remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
+			s.acc.Reset()
+			if remaining == "" {
+				s.state = thinkingState_ThinkingDoneEatingWhitespace
+			} else {
+				s.state = thinkingState_ThinkingDone
+			}
+			return thinking, remaining, false
+		} else if overlapLen := overlap(acc, s.closingTag); overlapLen > 0 {
+			thinking := acc[:len(acc)-overlapLen]
+			remaining := acc[len(acc)-overlapLen:]
+			s.acc.Reset()
+			// keep track of the candidate closing tag. We have to buffer it until it
+			// becomes disambiguated
+			s.acc.WriteString(remaining)
+			return thinking, "", false
+		} else {
+			// purely just thinking tokens, so we can return them
+			s.acc.Reset()
+			return acc, "", false
+		}
+	case thinkingState_ThinkingDoneEatingWhitespace:
+		trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
+		s.acc.Reset()
+		// if we see non-whitespace, we're done eating the leading whitespace of the content
+		if trimmed != "" {
+			s.state = thinkingState_ThinkingDone
+		}
+		return "", trimmed, false
+	case thinkingState_ThinkingDone:
+		acc := s.acc.String()
+		s.acc.Reset()
+		return "", acc, false
+	default:
+		panic("unknown state")
+	}
+}
+
+// longest overlap between suffix of s and prefix of delim
+func overlap(s, delim string) int {
+	max := min(len(delim), len(s))
+	for i := max; i > 0; i-- {
+		if strings.HasSuffix(s, delim[:i]) {
+			return i
+		}
+	}
+	return 0
+}
+
+func templateVisit(n parse.Node, enterFn func(parse.Node) bool, exitFn func(parse.Node)) {
+	if n == nil {
+		return
+	}
+	shouldContinue := enterFn(n)
+	if !shouldContinue {
+		return
+	}
+	switch x := n.(type) {
+	case *parse.ListNode:
+		for _, c := range x.Nodes {
+			templateVisit(c, enterFn, exitFn)
+		}
+	case *parse.BranchNode:
+		if x.Pipe != nil {
+			templateVisit(x.Pipe, enterFn, exitFn)
+		}
+		if x.List != nil {
+			templateVisit(x.List, enterFn, exitFn)
+		}
+		if x.ElseList != nil {
+			templateVisit(x.ElseList, enterFn, exitFn)
+		}
+	case *parse.ActionNode:
+		templateVisit(x.Pipe, enterFn, exitFn)
+	case *parse.WithNode:
+		templateVisit(&x.BranchNode, enterFn, exitFn)
+	case *parse.RangeNode:
+		templateVisit(&x.BranchNode, enterFn, exitFn)
+	case *parse.IfNode:
+		templateVisit(&x.BranchNode, enterFn, exitFn)
+	case *parse.TemplateNode:
+		templateVisit(x.Pipe, enterFn, exitFn)
+	case *parse.PipeNode:
+		for _, c := range x.Cmds {
+			templateVisit(c, enterFn, exitFn)
+		}
+	case *parse.CommandNode:
+		for _, a := range x.Args {
+			templateVisit(a, enterFn, exitFn)
+		}
+		// text, field, number, etc. are leaves – nothing to recurse into
+	}
+	if exitFn != nil {
+		exitFn(n)
+	}
+}
We do match on the nearest list + // that starts and ends with text nodes, which makes this not strictly + // necessary for our heuristic + + // go up to the nearest ancestor that is a *parse.ListNode + for i := len(ancestors) - 1; i >= 0; i-- { + if l, ok := ancestors[i].(*parse.ListNode); ok { + firstNode := l.Nodes[0] + if t, ok := firstNode.(*parse.TextNode); ok { + openingTag = strings.TrimSpace(t.String()) + } + lastNode := l.Nodes[len(l.Nodes)-1] + if t, ok := lastNode.(*parse.TextNode); ok { + closingTag = strings.TrimSpace(t.String()) + } + + break + } + } + } + } + + return true + } + + exitFn := func(n parse.Node) { + ancestors = ancestors[:len(ancestors)-1] + } + + templateVisit(t.Root, enterFn, exitFn) + + return openingTag, closingTag +} + +// checks to see if the given field name is present in the pipeline of the given range node +func rangeUsesField(rangeNode *parse.RangeNode, field string) bool { + found := false + enterFn := func(n parse.Node) bool { + switch x := n.(type) { + case *parse.FieldNode: + if x.Ident[0] == field { + found = true + } + } + return true + } + templateVisit(rangeNode.BranchNode.Pipe, enterFn, nil) + return found +} diff --git a/server/thinking_test.go b/server/thinking_test.go new file mode 100644 index 0000000000..a2055635ee --- /dev/null +++ b/server/thinking_test.go @@ -0,0 +1,403 @@ +package server + +import ( + "testing" + "text/template" +) + +func TestExtractThinking(t *testing.T) { + tests := []struct { + in, wantContent, wantThink string + }{ + { + in: " internal world", + wantThink: "internal ", + wantContent: "world", + }, + { + in: "abc", + wantThink: "a", + wantContent: "bc", + }, + { + in: "no think", + wantThink: "", + wantContent: "no think", + }, + } + for i, tt := range tests { + parser := thinkingParser{ + openingTag: "", + closingTag: "", + } + gotThinking, gotContent := parser.addContent(tt.in) + if gotContent != tt.wantContent || gotThinking != tt.wantThink { + t.Errorf("case %d: got (%q,%q), want (%q,%q)", i, gotThinking, gotContent, tt.wantThink, tt.wantContent) + } + } +} + +func TestThinkingStreaming(t *testing.T) { + type step struct { + input string + wantThinking string + wantContent string + wantStateAfter thinkingState + } + + cases := []struct { + desc string + skip bool + steps []step + }{ + { + desc: "content without a thinking tag", + steps: []step{ + { + input: " abc", + wantThinking: "", + wantContent: " abc", + wantStateAfter: thinkingState_ThinkingDone, + }, + }, + }, + { + desc: "content before a thinking tag nerfs the thinking tag", + steps: []step{ + { + input: " abc def ghi", + wantThinking: "", + wantContent: " abc def ghi", + wantStateAfter: thinkingState_ThinkingDone, + }, + }, + }, + { + desc: "building up a thinking tag partially", + steps: []step{ + { + input: " a", + wantThinking: "a", + wantContent: "", + wantStateAfter: thinkingState_Thinking, + }, + }, + }, + { + desc: "partial closing tag", + steps: []step{ + { + input: "abcdef", + wantThinking: "", + wantContent: "def", + wantStateAfter: thinkingState_ThinkingDone, + }, + }, + }, + { + desc: "partial closing tag fakeout", + steps: []step{ + { + input: "abcdef", + wantThinking: "def", + wantContent: "", + wantStateAfter: thinkingState_Thinking, + }, + { + input: "ghijkl", + wantThinking: "", + wantContent: "jkl", + wantStateAfter: thinkingState_ThinkingDone, + }, + }, + }, + { + desc: "whitespace after thinking tag", + steps: []step{ + { + input: " abc\n\ndef", + wantThinking: "abc", + wantContent: "def", + wantStateAfter: 
diff --git a/server/thinking_test.go b/server/thinking_test.go
new file mode 100644
index 0000000000..a2055635ee
--- /dev/null
+++ b/server/thinking_test.go
@@ -0,0 +1,403 @@
+package server
+
+import (
+	"testing"
+	"text/template"
+)
+
+func TestExtractThinking(t *testing.T) {
+	tests := []struct {
+		in, wantContent, wantThink string
+	}{
+		{
+			in:          " <think>internal </think>world",
+			wantThink:   "internal ",
+			wantContent: "world",
+		},
+		{
+			in:          "<think>a</think>bc",
+			wantThink:   "a",
+			wantContent: "bc",
+		},
+		{
+			in:          "no think",
+			wantThink:   "",
+			wantContent: "no think",
+		},
+	}
+	for i, tt := range tests {
+		parser := thinkingParser{
+			openingTag: "<think>",
+			closingTag: "</think>",
+		}
+		gotThinking, gotContent := parser.addContent(tt.in)
+		if gotContent != tt.wantContent || gotThinking != tt.wantThink {
+			t.Errorf("case %d: got (%q,%q), want (%q,%q)", i, gotThinking, gotContent, tt.wantThink, tt.wantContent)
+		}
+	}
+}
+
+func TestThinkingStreaming(t *testing.T) {
+	type step struct {
+		input          string
+		wantThinking   string
+		wantContent    string
+		wantStateAfter thinkingState
+	}
+
+	cases := []struct {
+		desc  string
+		skip  bool
+		steps []step
+	}{
+		{
+			desc: "content without a thinking tag",
+			steps: []step{
+				{
+					input:          " abc",
+					wantThinking:   "",
+					wantContent:    " abc",
+					wantStateAfter: thinkingState_ThinkingDone,
+				},
+			},
+		},
+		{
+			desc: "content before a thinking tag nerfs the thinking tag",
+			steps: []step{
+				{
+					input:          " abc <think>def</think> ghi",
+					wantThinking:   "",
+					wantContent:    " abc <think>def</think> ghi",
+					wantStateAfter: thinkingState_ThinkingDone,
+				},
+			},
+		},
+		{
+			desc: "building up a thinking tag partially",
+			steps: []step{
+				{
+					input:          " <th",
+					wantThinking:   "",
+					wantContent:    "",
+					wantStateAfter: thinkingState_LookingForOpening,
+				},
+				{
+					input:          "ink>a",
+					wantThinking:   "a",
+					wantContent:    "",
+					wantStateAfter: thinkingState_Thinking,
+				},
+			},
+		},
+		{
+			desc: "partial closing tag",
+			steps: []step{
+				{
+					input:          "<think>abc</th",
+					wantThinking:   "abc",
+					wantContent:    "",
+					wantStateAfter: thinkingState_Thinking,
+				},
+				{
+					input:          "ink>def",
+					wantThinking:   "",
+					wantContent:    "def",
+					wantStateAfter: thinkingState_ThinkingDone,
+				},
+			},
+		},
+		{
+			desc: "partial closing tag fakeout",
+			steps: []step{
+				{
+					input:          "<think>abc</th",
+					wantThinking:   "abc",
+					wantContent:    "",
+					wantStateAfter: thinkingState_Thinking,
+				},
+				{
+					input:          "inker>def",
+					wantThinking:   "</thinker>def",
+					wantContent:    "",
+					wantStateAfter: thinkingState_Thinking,
+				},
+				{
+					input:          "ghi</th",
+					wantThinking:   "ghi",
+					wantContent:    "",
+					wantStateAfter: thinkingState_Thinking,
+				},
+				{
+					input:          "ink>jkl",
+					wantThinking:   "",
+					wantContent:    "jkl",
+					wantStateAfter: thinkingState_ThinkingDone,
+				},
+			},
+		},
+		{
+			desc: "whitespace after thinking tag",
+			steps: []step{
+				{
+					input:          " <think>abc</think>\n\ndef",
+					wantThinking:   "abc",
+					wantContent:    "def",
+					wantStateAfter: thinkingState_ThinkingDone,
+				},
+			},
+		},
+		{
+			desc: "whitespace after thinking tag (incremental)",
+			steps: []step{
+				{
+					input:          " <think>abc</think>",
+					wantThinking:   "abc",
+					wantContent:    "",
+					wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
+				},
+				{
+					input:          "\n\ndef",
+					wantThinking:   "",
+					wantContent:    "def",
+					wantStateAfter: thinkingState_ThinkingDone,
+				},
+			},
+		},
+		{
+			desc: "whitespace after thinking tag with content and more whitespace",
+			steps: []step{
+				{
+					input:          " <think>abc</think>\n\ndef ",
+					wantThinking:   "abc",
+					wantContent:    "def ",
+					wantStateAfter: thinkingState_ThinkingDone,
+				},
+				{
+					input:          " ghi",
+					wantThinking:   "",
+					wantContent:    " ghi",
+					wantStateAfter: thinkingState_ThinkingDone,
+				},
+			},
+		},
+		{
+			desc: "token by token",
+			steps: []step{
+				{
+					input:          "<think>",
+					wantThinking:   "",
+					wantContent:    "",
+					wantStateAfter: thinkingState_ThinkingStartedEatingWhitespace,
+				},
+				{
+					input:          "\n",
+					wantThinking:   "",
+					wantContent:    "",
+					wantStateAfter: thinkingState_ThinkingStartedEatingWhitespace,
+				},
+				{
+					input:          "</think>",
+					wantThinking:   "",
+					wantContent:    "",
+					wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
+				},
+				{
+					input:          "\n\n",
+					wantThinking:   "",
+					wantContent:    "",
+					wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
+				},
+				{
+					input:          "Hi",
+					wantThinking:   "",
+					wantContent:    "Hi",
+					wantStateAfter: thinkingState_ThinkingDone,
+				},
+				{
+					input:          " there",
+					wantThinking:   "",
+					wantContent:    " there",
+					wantStateAfter: thinkingState_ThinkingDone,
+				},
+			},
+		},
+		{
+			desc: "leading thinking whitespace",
+			steps: []step{
+				{
+					input:          "<think> \t ",
+					wantThinking:   "",
+					wantContent:    "",
+					wantStateAfter: thinkingState_ThinkingStartedEatingWhitespace,
+				},
+				{
+					input:          " these are some ",
+					wantThinking:   "these are some ",
+					wantContent:    "",
+					wantStateAfter: thinkingState_Thinking,
+				},
+				{
+					input:          "thoughts </think>",
+					wantThinking:   "thoughts ",
+					wantContent:    "",
+					wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
+				},
+				{
+					input:          " more content",
+					wantThinking:   "",
+					wantContent:    "more content",
+					wantStateAfter: thinkingState_ThinkingDone,
+				},
+			},
+		},
+	}
+
+	for _, c := range cases {
+		parser := thinkingParser{
+			openingTag: "<think>",
+			closingTag: "</think>",
+		}
+		if c.skip {
+			continue
+		}
+		for i, step := range c.steps {
+			thinking, content := parser.addContent(step.input)
+			if content != step.wantContent || thinking != step.wantThinking {
+				t.Errorf("case %q (step %d): got (%q,%q), want (%q,%q)", c.desc, i, content, thinking, step.wantContent, step.wantThinking)
+			}
+			if parser.state != step.wantStateAfter {
+				t.Errorf("case %q (step %d): got state %s, want %s", c.desc, i, parser.state, step.wantStateAfter)
+			}
+		}
+	}
+}
+
+func TestInferThinkingTags(t *testing.T) {
+	cases := []struct {
+		desc           string
+		tmplString     string
+		wantOpeningTag string
+		wantClosingTag string
+	}{
+		{
+			desc: "basic",
+			tmplString: `
+			{{ if .Thinking}}
+			/think
+			{{ end }}
+			{{- range $i, $_ := .Messages }}
+			{{- $last := eq (len (slice $.Messages $i)) 1 -}}
+			{{ if and $last .Thinking }}
+			<think>{{ .Thinking }}</think>
+			{{ end }}
+			{{ end }}
+			`,
+			wantOpeningTag: "<think>",
+			wantClosingTag: "</think>",
+		},
+		{
+			desc: "doubly nested range",
+			tmplString: `
+			{{ if .Thinking}}
+			/think
+			{{ end }}
+			{{- range $i, $_ := .Messages }}
+			{{- range $j, $_ := .NotMessages }}
+			{{- $last := eq (len (slice $.Messages $i)) 1 -}}
+			{{ if and $last .Thinking }}
+			<think>{{ .Thinking }}</think>
+			{{ end }}
+			{{ end }}
+			{{ end }}
+			`,
+			wantOpeningTag: "",
+			wantClosingTag: "",
+		},
+		{
+			desc: "whitespace is trimmed",
+			tmplString: `
+			{{ if .Thinking}}
+			/think
+			{{ end }}
+			{{- range $i, $_ := .Messages }}
+			{{- $last := eq (len (slice $.Messages $i)) 1 -}}
+			{{ if and $last .Thinking }}
+			Some text before {{ .Thinking }} Some text after
+			{{ end }}
+			{{ end }}
+			`,
+			wantOpeningTag: "Some text before",
+			wantClosingTag: "Some text after",
+		},
+		{
+			desc: "qwen3",
+			tmplString: `
+{{- if or .System .Tools .Thinking }}<|im_start|>system
+{{- if .System }}
+{{ .System }}
+{{- end }}
+{{- if .Tools }}
+
+# Tools
+
+You may call one or more functions to assist with the user query.
+
+You are provided with function signatures within <tools></tools> XML tags:
+<tools>
+{{- range .Tools }}
+{"type": "function", "function": {{ .Function }}}
+{{- end }}
+</tools>
+
+For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
+<tool_call>
+{"name": <function-name>, "arguments": <args-json-object>}
+</tool_call>
+{{- end }}
+{{- if .Thinking }}
+/think
+{{- else }}
+/no_think
+{{- end }}<|im_end|>
+{{ end }}
+{{- range $i, $_ := .Messages }}
+{{- $last := eq (len (slice $.Messages $i)) 1 -}}
+{{- if eq .Role "user" }}<|im_start|>user
+{{ .Content }}<|im_end|>
+{{ else if eq .Role "assistant" }}<|im_start|>assistant
+{{ if and $last .Thinking }}
+<think>{{ .Thinking }}</think>
+{{ end }}
+{{ if .Content }}{{ .Content }}
+{{- else if .ToolCalls }}
+{{ range .ToolCalls }}<tool_call>{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}</tool_call>
+{{ end }}
+{{- end }}{{ if not $last }}<|im_end|>
+{{ end }}
+{{- else if eq .Role "tool" }}<|im_start|>user
+<tool_response>
+{{ .Content }}
+</tool_response>
+<|im_end|>
+{{ end }}
+{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
+{{ end }}
+{{- end }}
+			`,
+			wantOpeningTag: "<think>",
+			wantClosingTag: "</think>",
+		},
+	}
+	for _, c := range cases {
+		tmpl := template.Must(template.New("test").Parse(c.tmplString))
+		openingTag, closingTag := inferThinkingTags(tmpl)
+		if openingTag != c.wantOpeningTag || closingTag != c.wantClosingTag {
+			t.Errorf("case %q: got (%q,%q), want (%q,%q)", c.desc, openingTag, closingTag, c.wantOpeningTag, c.wantClosingTag)
+		}
+	}
+}
diff --git a/template/template.go b/template/template.go
index 5c886cac4a..da910afbd5 100644
--- a/template/template.go
+++ b/template/template.go
@@ -167,6 +167,10 @@ type Values struct {
 	api.Tools
 	Prompt string
 	Suffix string
+	Think  bool
+	// whether or not the user explicitly set the thinking flag (vs. it being
+	// implicitly false). Templates can't see whether `Think` is nil
+	IsThinkSet bool
 
 	// forceLegacy is a flag used to test compatibility with legacy templates
 	forceLegacy bool
@@ -222,16 +226,20 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 	system, messages := collate(v.Messages)
 	if v.Prompt != "" && v.Suffix != "" {
 		return t.Template.Execute(w, map[string]any{
-			"Prompt":   v.Prompt,
-			"Suffix":   v.Suffix,
-			"Response": "",
+			"Prompt":     v.Prompt,
+			"Suffix":     v.Suffix,
+			"Response":   "",
+			"Think":      v.Think,
+			"IsThinkSet": v.IsThinkSet,
 		})
 	} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
 		return t.Template.Execute(w, map[string]any{
-			"System":   system,
-			"Messages": messages,
-			"Tools":    v.Tools,
-			"Response": "",
+			"System":     system,
+			"Messages":   messages,
+			"Tools":      v.Tools,
+			"Response":   "",
+			"Think":      v.Think,
+			"IsThinkSet": v.IsThinkSet,
 		})
 	}
 
@@ -241,9 +249,11 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 	for _, m := range messages {
 		execute := func() error {
 			if err := t.Template.Execute(&b, map[string]any{
-				"System":   system,
-				"Prompt":   prompt,
-				"Response": response,
+				"System":     system,
+				"Prompt":     prompt,
+				"Response":   response,
+				"Think":      v.Think,
+				"IsThinkSet": v.IsThinkSet,
 			}); err != nil {
 				return err
 			}
@@ -286,9 +296,11 @@ func (t *Template) Execute(w io.Writer, v Values) error {
 
 	tree := parse.Tree{Root: nodes.(*parse.ListNode)}
 	if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
-		"System":   system,
-		"Prompt":   prompt,
-		"Response": response,
+		"System":     system,
+		"Prompt":     prompt,
+		"Response":   response,
+		"Think":      v.Think,
+		"IsThinkSet": v.IsThinkSet,
 	}); err != nil {
 		return err
 	}
diff --git a/types/model/capability.go b/types/model/capability.go
index fb86894031..cde23cee7a 100644
--- a/types/model/capability.go
+++ b/types/model/capability.go
@@ -8,6 +8,7 @@ const (
 	CapabilityInsert    = Capability("insert")
 	CapabilityVision    = Capability("vision")
 	CapabilityEmbedding = Capability("embedding")
+	CapabilityThinking  = Capability("thinking")
 )
 
 func (c Capability) String() string {
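
Since templates now receive both values, a model template can distinguish an explicit user choice from the default. A hypothetical fragment in the style of the qwen3 template above (not from any shipped model) might use them like this:

{{- if .IsThinkSet }}
{{- if .Think }}/think{{- else }}/no_think{{- end }}
{{- end }}

When Think was never set on the request, IsThinkSet is false and the template emits neither directive, preserving the pre-thinking behavior.
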