diff --git a/server/routes.go b/server/routes.go index cc8913537e..a208075d92 100644 --- a/server/routes.go +++ b/server/routes.go @@ -46,18 +46,6 @@ import ( "github.com/ollama/ollama/version" ) -func shouldUseHarmony(model *Model) bool { - if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) { - // heuristic to check whether the template expects to be parsed via harmony: - // search for harmony tags that are nearly always used - if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") { - return true - } - } - - return false -} - func experimentEnabled(name string) bool { return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name) } @@ -207,7 +195,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - useHarmony := shouldUseHarmony(m) && !req.Raw + useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw var harmonyMessageHandler *harmony.HarmonyMessageHandler var harmonyToolParser *harmony.HarmonyToolCallAccumulator if useHarmony { @@ -1616,27 +1604,36 @@ func (s *Server) ChatHandler(c *gin.Context) { } msgs = filterThinkTags(msgs, m) - var harmonyMessageHandler *harmony.HarmonyMessageHandler - var harmonyToolParser *harmony.HarmonyToolCallAccumulator - - useHarmony := shouldUseHarmony(m) + useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) processedTools := req.Tools + var functionNameMap *harmony.FunctionNameMap + var prefillContentOrThinking *bool if useHarmony { - harmonyMessageHandler = harmony.NewHarmonyMessageHandler() + functionNameMap = harmony.NewFunctionNameMap() var lastMessage *api.Message if len(msgs) > 0 { lastMessage = &msgs[len(msgs)-1] } - harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage) - harmonyToolParser = harmonyMessageHandler.CreateToolParser() + // prefill content or thinking flag if the last message is an assistant message + if lastMessage != nil && lastMessage.Role == "assistant" { + if lastMessage.Content != "" { + trueVal := true + // true sets content to be prefilled + prefillContentOrThinking = &trueVal + } else if lastMessage.Thinking != "" { + // false sets thinking to be prefilled + falseVal := false + prefillContentOrThinking = &falseVal + } + } // make a copy of tools to pass to the chat prompt. Function names may be // renamed to be valid Harmony function names. processedTools = make([]api.Tool, len(req.Tools)) copy(processedTools, req.Tools) for i, tool := range processedTools { - processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name) + processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name) } } @@ -1685,15 +1682,17 @@ func (s *Server) ChatHandler(c *gin.Context) { defer close(ch) if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ - Prompt: prompt, - Images: images, - Format: req.Format, - Options: opts, + Prompt: prompt, + Images: images, + Format: req.Format, + Options: opts, + FunctionNameMap: functionNameMap, + PrefillContent: prefillContentOrThinking, }, func(r llm.CompletionResponse) { res := api.ChatResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), - Message: api.Message{Role: "assistant", Content: r.Content}, + Message: api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls}, Done: r.Done, Metrics: api.Metrics{ PromptEvalCount: r.PromptEvalCount, @@ -1709,31 +1708,10 @@ func (s *Server) ChatHandler(c *gin.Context) { } if useHarmony { - content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser) - res.Message.Content = content - res.Message.Thinking = thinking - harmonyToolParser.Add(toolContent) - - if r.Done { - toolName, toolContent := harmonyToolParser.Drain() - if toolName != nil { - *toolName = strings.TrimPrefix(*toolName, "functions.") - *toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName) - var args api.ToolCallFunctionArguments - if err := json.Unmarshal([]byte(toolContent), &args); err != nil { - errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error()) - ch <- gin.H{"error": errStr} - return - } - res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}} - } - } - // only send messages with meaningful content (empty messages confuse clients) if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done { ch <- res } - return }