diff --git a/harmony/harmonyparser.go b/harmony/harmonyparser.go
index addce4c945..a51819ddab 100644
--- a/harmony/harmonyparser.go
+++ b/harmony/harmonyparser.go
@@ -3,29 +3,15 @@ package harmony
 import (
 	"fmt"
 	"log/slog"
-	"slices"
 	"strings"
 	"unicode"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/logutil"
-	"github.com/ollama/ollama/template"
 )
 
 type harmonyParserState int
 
-func ShouldUseHarmony(modelFamily string, template *template.Template) bool {
-	if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) {
-		// heuristic to check whether the template expects to be parsed via harmony:
-		// search for harmony tags that are nearly always used
-		if template.Contains("<|start|>") && template.Contains("<|end|>") {
-			return true
-		}
-	}
-
-	return false
-}
-
 const (
 	harmonyParserState_LookingForMessageStart harmonyParserState = iota
 	harmonyParserState_ParsingHeader
@@ -89,28 +75,18 @@ func (s *HarmonyParser) AddImplicitStart() {
 	s.acc.WriteString("<|start|>assistant")
 }
 
-func Prefill(lastMessage api.Message) string {
-	if lastMessage.Role != "assistant" {
-		return ""
-	}
-
-	switch {
-	case strings.TrimSpace(lastMessage.Content) != "":
-		return "<|start|>assistant<|channel|>final<|message|>"
-	case strings.TrimSpace(lastMessage.Thinking) != "":
-		return "<|start|>assistant<|channel|>analysis<|message|>"
-	default:
-		return ""
-	}
-}
-
-// AddImplicitStartOrPrefill adds an implicit start tag or prefill string if provided
-func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillString string) {
-	if strings.TrimSpace(prefillString) != "" {
-		s.acc.WriteString(prefillString)
-	} else {
-		s.AddImplicitStart()
+func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
+	if lastMessage != nil && lastMessage.Role == "assistant" {
+		// handle prefilling conditions
+		if lastMessage.Content != "" {
+			s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
+			return
+		} else if lastMessage.Thinking != "" {
+			s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
+			return
+		}
 	}
+	s.AddImplicitStart()
 }
 
 func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
diff --git a/harmony/harmonyparser_test.go b/harmony/harmonyparser_test.go
index dcf1af4e83..b988a018f3 100644
--- a/harmony/harmonyparser_test.go
+++ b/harmony/harmonyparser_test.go
@@ -3,7 +3,6 @@ package harmony
 import (
 	"fmt"
 	"reflect"
-	"strings"
 	"testing"
 )
 
@@ -536,202 +535,3 @@ func TestFunctionConvertAndAdd(t *testing.T) {
 		})
 	}
 }
-
-func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
-	t.Run("thinking_then_content_streams", func(t *testing.T) {
-		handler := NewHarmonyMessageHandler()
-		handler.HarmonyParser.AddImplicitStart()
-		tp := handler.CreateToolParser()
-		type step struct {
-			in           string
-			wantContent  string
-			wantThinking string
-		}
-		steps := []step{
-			{in: "<|channel|>analysis<|message|>Thinking...", wantThinking: "Thinking..."},
-			{in: "<|end|>", wantThinking: ""},
-			{in: "<|start|>assistant<|message|>Answer", wantContent: "Answer"},
-			{in: "<|end|>", wantContent: ""},
-		}
-		for i, s := range steps {
-			content, thinking, tool := handler.AddContent(s.in, tp)
-			if tool != "" {
-				tp.Add(tool)
-			}
-			if content != s.wantContent || thinking != s.wantThinking {
-				t.Fatalf("step %d: got (content=%q thinking=%q), want (content=%q thinking=%q)", i, content, thinking, s.wantContent, s.wantThinking)
-			}
-		}
-	})
-
-	t.Run("content_streams_as_it_arrives", func(t *testing.T) {
-		handler := NewHarmonyMessageHandler()
-		handler.HarmonyParser.AddImplicitStart()
-		tp := handler.CreateToolParser()
-		inputs := []string{
-			"<|start|>assistant<|message|>Hello",
-			", world",
-			"!<|end|>",
-		}
-		var got []string
-		for _, in := range inputs {
-			content, thinking, tool := handler.AddContent(in, tp)
-			if tool != "" {
-				tp.Add(tool)
-			}
-			if thinking != "" {
-				t.Fatalf("unexpected thinking %q", thinking)
-			}
-			if content != "" {
-				got = append(got, content)
-			}
-		}
-		want := []string{"Hello", ", world", "!"}
-		if !reflect.DeepEqual(got, want) {
-			t.Fatalf("content pieces mismatch: got %v want %v", got, want)
-		}
-	})
-
-	t.Run("thinking_streams_separately_from_content", func(t *testing.T) {
-		handler := NewHarmonyMessageHandler()
-		handler.HarmonyParser.AddImplicitStart()
-		tp := handler.CreateToolParser()
-		inputs := []string{
-			"<|channel|>analysis<|message|>Thinking...",
-			"<|end|>",
-			"<|start|>assistant<|message|>Answer",
-			"<|end|>",
-		}
-		var got []string
-		for _, in := range inputs {
-			content, thinking, tool := handler.AddContent(in, tp)
-			if tool != "" {
-				tp.Add(tool)
-			}
-			if thinking != "" {
-				got = append(got, thinking)
-			}
-			if content != "" {
-				got = append(got, content)
-			}
-		}
-		want := []string{"Thinking...", "Answer"}
-		if !reflect.DeepEqual(got, want) {
-			t.Fatalf("content pieces mismatch: got %v want %v", got, want)
-		}
-	})
-
-	t.Run("partial_tags_buffer_until_complete", func(t *testing.T) {
-		handler := NewHarmonyMessageHandler()
-		handler.HarmonyParser.AddImplicitStart()
-		tp := handler.CreateToolParser()
-		inputs := []string{
-			"<|chan",
-			"nel|>analysis<|mess",
-			"age|>Deep ",
-			"thought",
-			"<|end|>",
-			"<|start|>assistant<|message|>Done",
-			"<|end|>",
-		}
-		var thinkingPieces []string
-		var contentPieces []string
-		for _, in := range inputs {
-			content, thinking, tool := handler.AddContent(in, tp)
-			if tool != "" {
-				tp.Add(tool)
-			}
-			if thinking != "" {
-				thinkingPieces = append(thinkingPieces, thinking)
-			}
-			if content != "" {
-				contentPieces = append(contentPieces, content)
-			}
-		}
-		if want := []string{"Deep ", "thought"}; !reflect.DeepEqual(thinkingPieces, want) {
-			t.Fatalf("thinking pieces mismatch: got %v want %v", thinkingPieces, want)
-		}
-		if want := []string{"Done"}; !reflect.DeepEqual(contentPieces, want) {
-			t.Fatalf("content pieces mismatch: got %v want %v", contentPieces, want)
-		}
-	})
-
-	t.Run("simple_assistant_after_analysis", func(t *testing.T) {
-		handler := NewHarmonyMessageHandler()
-		handler.HarmonyParser.AddImplicitStart()
-		tp := handler.CreateToolParser()
-		inputs := []string{
-			"<|channel|>analysis<|message|>Think",
-			"<|end|>",
-			"<|start|>assistant<|message|>Answer",
-			"<|end|>",
-		}
-		var contentSb, thinkingSb strings.Builder
-		for _, in := range inputs {
-			content, thinking, tool := handler.AddContent(in, tp)
-			if tool != "" {
-				tp.Add(tool)
-			}
-			contentSb.WriteString(content)
-			thinkingSb.WriteString(thinking)
-		}
-		if contentSb.String() != "Answer" {
-			t.Fatalf("content mismatch: got %q want %q", contentSb.String(), "Answer")
-		}
-		if thinkingSb.String() != "Think" {
-			t.Fatalf("thinking mismatch: got %q want %q", thinkingSb.String(), "Think")
-		}
-	})
-
-	t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) {
-		handler := NewHarmonyMessageHandler()
-		handler.HarmonyParser.AddImplicitStart()
-		tp := handler.CreateToolParser()
-		inputs := []string{
-			"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>",
-		}
-		for _, in := range inputs {
-			content, thinking, tool := handler.AddContent(in, tp)
-			if content != "" || thinking != "" {
-				continue
-			}
-			if tool != "" {
-				tp.Add(tool)
-			}
-		}
-		name, args := tp.Drain()
-		if name == nil || *name != "functions.calculate" {
-			t.Fatalf("unexpected tool name: %v", name)
-		}
-		if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
-			t.Fatalf("unexpected tool args: got %s want %s", got, want)
-		}
-	})
-
-	t.Run("tool_call_across_chunks", func(t *testing.T) {
-		handler := NewHarmonyMessageHandler()
-		handler.HarmonyParser.AddImplicitStart()
-		tp := handler.CreateToolParser()
-		inputs := []string{
-			"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+",
-			"2\"}",
-			"<|end|>",
-		}
-		for _, in := range inputs {
-			content, thinking, tool := handler.AddContent(in, tp)
-			if content != "" || thinking != "" {
-				continue
-			}
-			if tool != "" {
-				tp.Add(tool)
-			}
-		}
-		name, args := tp.Drain()
-		if name == nil || *name != "functions.calculate" {
-			t.Fatalf("unexpected tool name: %v", name)
-		}
-		if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
-			t.Fatalf("unexpected tool args: got %s want %s", got, want)
-		}
-	})
-}
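The two old harmony entry points collapse into one here: `Prefill` and the string-based `AddImplicitStartOrPrefill` are replaced by a single `AddImplicitStartOrPrefill(*api.Message)` that inspects the trailing assistant message itself. A minimal sketch of the resulting call pattern, using only the handler API visible in this patch; `chunks` and `emit` are hypothetical stand-ins, not code from the repo:

```go
// Sketch only - how a caller drives the reworked prefill entry point.
// Assumes imports of github.com/ollama/ollama/api and github.com/ollama/ollama/harmony.
func streamWithHarmony(msgs []api.Message, chunks <-chan string, emit func(content, thinking string)) {
	handler := harmony.NewHarmonyMessageHandler()

	var last *api.Message
	if len(msgs) > 0 {
		last = &msgs[len(msgs)-1] // may be an assistant message to continue from
	}
	handler.HarmonyParser.AddImplicitStartOrPrefill(last)

	tp := handler.CreateToolParser()
	for chunk := range chunks {
		content, thinking, tool := handler.AddContent(chunk, tp)
		tp.Add(tool) // tool fragments accumulate until Drain() after the final chunk
		emit(content, thinking)
	}
}
```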
diff --git a/llm/server.go b/llm/server.go
index 45a9ad14c9..75f049bc05 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -1348,9 +1348,7 @@ type CompletionRequest struct {
 	Images  []ImageData
 	Options *api.Options
 
-	Grammar       string // set before sending the request to the subprocess
-	UseHarmony    bool
-	PrefillString string
+	Grammar string // set before sending the request to the subprocess
 }
 
 // DoneReason represents the reason why a completion response is done
@@ -1363,8 +1361,6 @@ const (
 	DoneReasonLength
 	// DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed
 	DoneReasonConnectionClosed
-	// DoneReasonTokenRepeatLimit indicates the completion stopped due to a token repeat limit
-	DoneReasonTokenRepeatLimit
 )
 
 func (d DoneReason) String() string {
@@ -1373,23 +1369,19 @@ func (d DoneReason) String() string {
 		return "length"
 	case DoneReasonStop:
 		return "stop"
-	case DoneReasonTokenRepeatLimit:
-		return "token_repeat_limit"
 	default:
 		return "" // closed
 	}
 }
 
 type CompletionResponse struct {
-	Content            string         `json:"content"`
-	Thinking           string         `json:"thinking"`
-	ToolCalls          []api.ToolCall `json:"tool_calls"`
-	DoneReason         DoneReason     `json:"done_reason"`
-	Done               bool           `json:"done"`
-	PromptEvalCount    int            `json:"prompt_eval_count"`
-	PromptEvalDuration time.Duration  `json:"prompt_eval_duration"`
-	EvalCount          int            `json:"eval_count"`
-	EvalDuration       time.Duration  `json:"eval_duration"`
+	Content            string        `json:"content"`
+	DoneReason         DoneReason    `json:"done_reason"`
+	Done               bool          `json:"done"`
+	PromptEvalCount    int           `json:"prompt_eval_count"`
+	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
+	EvalCount          int           `json:"eval_count"`
+	EvalDuration       time.Duration `json:"eval_duration"`
 }
 
 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -1507,8 +1499,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 				return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
 			}
 			switch {
-			// TODO(parthsareen): token repeat limit is now handled in the runner, this currently support legacy model and can be removed in the future
-			case strings.TrimSpace(c.Content) == lastToken && c.Content != "":
+			case strings.TrimSpace(c.Content) == lastToken:
 				tokenRepeat++
 			default:
 				lastToken = strings.TrimSpace(c.Content)
@@ -1521,14 +1512,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 				return ctx.Err()
 			}
 
+			if c.Content != "" {
+				fn(CompletionResponse{
+					Content: c.Content,
+				})
+			}
+
 			if c.Done {
 				fn(c)
 				return nil
 			}
-
-			if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 {
-				fn(c)
-			}
 		}
 	}
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 5da8ca3cbb..1081a1f555 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -29,7 +29,6 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
-	"github.com/ollama/ollama/harmony"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/ml"
@@ -781,14 +780,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	var harmonyMessageHandler *harmony.HarmonyMessageHandler
-	var harmonyToolParser *harmony.HarmonyToolCallAccumulator
-	if req.UseHarmony {
-		harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
-		harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(req.PrefillString)
-		harmonyToolParser = harmonyMessageHandler.CreateToolParser()
-	}
-
 	if req.Options == nil {
 		opts := api.DefaultOptions()
 		req.Options = &opts
@@ -871,9 +862,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
 		return
 	}
-	var lastToken string
-	tokenRepeat := 0
-	const tokenRepeatLimit = 30
 
 	for {
 		select {
@@ -882,27 +870,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if strings.TrimSpace(content) == lastToken {
-					tokenRepeat++
-				}
-				if tokenRepeat == tokenRepeatLimit {
-					http.Error(w, "token repeat limit reached", http.StatusInternalServerError)
-					seq.doneReason = llm.DoneReasonTokenRepeatLimit
-					close(seq.quit)
-					return
-				}
-				lastToken = strings.TrimSpace(content)
-
-				var thinking string
-				if harmonyMessageHandler != nil {
-					var toolContent string
-					content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser)
-					harmonyToolParser.Add(toolContent)
-				}
-
 				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
-					Content:  content,
-					Thinking: thinking,
+					Content: content,
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
@@ -911,29 +880,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 
 				flusher.Flush()
 			} else {
-				var toolCalls []api.ToolCall
-				if harmonyMessageHandler != nil {
-					// these tools still need to be transformed to the original function name
-					toolName, toolContent := harmonyToolParser.Drain()
-					if toolName != nil {
-						*toolName = strings.TrimPrefix(*toolName, "functions.")
-						var args api.ToolCallFunctionArguments
-						if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
-							http.Error(w, fmt.Sprintf("failed to unmarshal tool call function arguments: %v", err), http.StatusInternalServerError)
-							close(seq.quit)
-							return
-						}
-
-						toolCalls = append(toolCalls, api.ToolCall{
-							Function: api.ToolCallFunction{
-								Name:      *toolName,
-								Arguments: args,
-							},
-						})
-					}
-				}
-
 				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
-					ToolCalls:       toolCalls,
 					Done:            true,
 					DoneReason:      seq.doneReason,
 					PromptEvalCount: seq.numPromptInputs,
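With harmony handling pulled out of the runner, `llm.CompletionResponse` goes back to carrying raw content plus metrics, and the `Thinking`/`ToolCalls` fields disappear. A rough sketch of the slimmed streaming contract as exercised by this patch; the runner handle `r`, `prompt`, and `opts` are assumed from surrounding code:

```go
// Sketch only: content arrives verbatim (harmony tags included); any
// channel/tool segmentation is now the caller's job, not the runner's.
var sb strings.Builder
err := r.Completion(ctx, llm.CompletionRequest{Prompt: prompt, Options: opts},
	func(cr llm.CompletionResponse) {
		sb.WriteString(cr.Content)
		if cr.Done {
			fmt.Printf("done=%s evalCount=%d\n", cr.DoneReason, cr.EvalCount)
		}
	})
```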
diff --git a/server/routes.go b/server/routes.go
index da5e22f687..5114cb74fa 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -46,6 +46,18 @@ import (
 	"github.com/ollama/ollama/version"
 )
 
+func shouldUseHarmony(model *Model) bool {
+	if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
+		// heuristic to check whether the template expects to be parsed via harmony:
+		// search for harmony tags that are nearly always used
+		if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
+			return true
+		}
+	}
+
+	return false
+}
+
 func experimentEnabled(name string) bool {
 	return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
 }
@@ -195,11 +207,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw
-	var functionNameMap *harmony.FunctionNameMap
-
+	useHarmony := shouldUseHarmony(m) && !req.Raw
+	var harmonyMessageHandler *harmony.HarmonyMessageHandler
+	var harmonyToolParser *harmony.HarmonyToolCallAccumulator
 	if useHarmony {
-		functionNameMap = harmony.NewFunctionNameMap()
+		harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
+		harmonyMessageHandler.HarmonyParser.AddImplicitStart()
+		harmonyToolParser = harmonyMessageHandler.CreateToolParser()
 	}
 
 	// Validate Think value: string values currently only allowed for gptoss models
@@ -343,19 +357,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		var sb strings.Builder
 		defer close(ch)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:     prompt,
-			Images:     images,
-			Format:     req.Format,
-			Options:    opts,
-			UseHarmony: useHarmony,
+			Prompt:  prompt,
+			Images:  images,
+			Format:  req.Format,
+			Options: opts,
 		}, func(cr llm.CompletionResponse) {
 			res := api.GenerateResponse{
 				Model:     req.Model,
 				CreatedAt: time.Now().UTC(),
 				Response:  cr.Content,
 				Done:      cr.Done,
-				Thinking:  cr.Thinking,
-				ToolCalls: cr.ToolCalls,
 				Metrics: api.Metrics{
 					PromptEvalCount:    cr.PromptEvalCount,
 					PromptEvalDuration: cr.PromptEvalDuration,
@@ -364,22 +375,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 				},
 			}
 
-			if res.Done {
-				res.DoneReason = cr.DoneReason.String()
-				res.TotalDuration = time.Since(checkpointStart)
-				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
-			}
-
 			if useHarmony {
-				for i, tool := range res.ToolCalls {
-					res.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
-				}
-				if res.Response != "" || res.Thinking != "" || len(res.ToolCalls) > 0 || res.Done {
-					ch <- res
-				}
-				return
-			}
-
-			if thinkingState != nil {
+				content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
+				res.Response = content
+				res.Thinking = thinking
+				harmonyToolParser.Add(toolContent)
+			} else if thinkingState != nil {
 				thinking, content := thinkingState.AddContent(cr.Content)
 				res.Thinking = thinking
 				res.Response = content
@@ -390,6 +391,30 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			}
 
 			if cr.Done {
+				if useHarmony {
+					toolName, toolContent := harmonyToolParser.Drain()
+					if toolName != nil {
+						*toolName = strings.TrimPrefix(*toolName, "functions.")
+						var args api.ToolCallFunctionArguments
+						if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
+							errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
+							ch <- gin.H{"error": errStr}
+							return
+						}
+
+						res.ToolCalls = append(res.ToolCalls, api.ToolCall{
+							Function: api.ToolCallFunction{
+								Name:      *toolName,
+								Arguments: args,
+							},
+						})
+					}
+				}
+
+				res.DoneReason = cr.DoneReason.String()
+				res.TotalDuration = time.Since(checkpointStart)
+				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+
 				if !req.Raw {
 					tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
 					if err != nil {
@@ -1592,21 +1617,27 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}
 	msgs = filterThinkTags(msgs, m)
 
-	useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template)
+	var harmonyMessageHandler *harmony.HarmonyMessageHandler
+	var harmonyToolParser *harmony.HarmonyToolCallAccumulator
+
+	useHarmony := shouldUseHarmony(m)
 	processedTools := req.Tools
-	var functionNameMap *harmony.FunctionNameMap
-	var prefillString string
-	// TODO(parthsareen): this can be abstracted to not be model specific and potentially moved to the runner
 	if useHarmony {
-		prefillString = harmony.Prefill(msgs[len(msgs)-1])
-		functionNameMap = harmony.NewFunctionNameMap()
+		harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
+		var lastMessage *api.Message
+		if len(msgs) > 0 {
+			lastMessage = &msgs[len(msgs)-1]
+		}
+		harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
+		harmonyToolParser = harmonyMessageHandler.CreateToolParser()
+
 		// make a copy of tools to pass to the chat prompt. Function names may be
 		// renamed to be valid Harmony function names.
 		processedTools = make([]api.Tool, len(req.Tools))
 		copy(processedTools, req.Tools)
 		for i, tool := range processedTools {
-			processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name)
+			processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
 		}
 	}
 
@@ -1659,17 +1690,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		defer close(ch)
 
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:        prompt,
-			Images:        images,
-			Format:        req.Format,
-			Options:       opts,
-			UseHarmony:    useHarmony,
-			PrefillString: prefillString,
+			Prompt:  prompt,
+			Images:  images,
+			Format:  req.Format,
+			Options: opts,
 		}, func(r llm.CompletionResponse) {
 			res := api.ChatResponse{
 				Model:     req.Model,
 				CreatedAt: time.Now().UTC(),
-				Message:   api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls},
+				Message:   api.Message{Role: "assistant", Content: r.Content},
 				Done:      r.Done,
 				Metrics: api.Metrics{
 					PromptEvalCount: r.PromptEvalCount,
@@ -1685,13 +1714,31 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			}
 
 			if useHarmony {
-				for i, tool := range res.Message.ToolCalls {
-					res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
+				content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
+				res.Message.Content = content
+				res.Message.Thinking = thinking
+				harmonyToolParser.Add(toolContent)
+
+				if r.Done {
+					toolName, toolContent := harmonyToolParser.Drain()
+					if toolName != nil {
+						*toolName = strings.TrimPrefix(*toolName, "functions.")
+						*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
+						var args api.ToolCallFunctionArguments
+						if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
+							errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
+							ch <- gin.H{"error": errStr}
+							return
+						}
+						res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
+					}
 				}
+
+				// only send messages with meaningful content (empty messages confuse clients)
 				if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
 					ch <- res
 				}
+
 				return
 			}
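Tool names are rewritten twice in the chat path above: `ConvertAndAdd` renames user-supplied tools to Harmony-safe identifiers before the prompt is templated, and `OriginalFromConverted` restores them when the model calls back. A condensed sketch of that round trip, assuming `FunctionNameMap` behaves as it is used in this patch (the tool name is a hypothetical example):

```go
// Sketch only - the two halves of the tool-name round trip.
handler := harmony.NewHarmonyMessageHandler()

// Request side: rewrite the client's tool name into a Harmony-safe identifier
// before the prompt is rendered.
safe := handler.FunctionNameMap.ConvertAndAdd("get current weather")

// Response side: the model emits "functions.<safe>"; strip the prefix and map
// the converted name back to the one the client registered.
name := strings.TrimPrefix("functions."+safe, "functions.")
orig := handler.FunctionNameMap.OriginalFromConverted(name) // "get current weather" again
_ = orig
```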
diff --git a/server/routes_harmony_streaming_test.go b/server/routes_harmony_streaming_test.go
index bcb0208865..b1ede4e39e 100644
--- a/server/routes_harmony_streaming_test.go
+++ b/server/routes_harmony_streaming_test.go
@@ -7,6 +7,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"net/http"
 	"strings"
 	"testing"
 	"time"
@@ -117,7 +118,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
 			name: "content streams as it arrives",
 			steps: []step{
 				{
-					input:       llm.CompletionResponse{Content: "Hello", Done: false},
+					input:       llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
 					wantContent: "Hello",
 				},
 				{
@@ -125,7 +126,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
 					wantContent: ", world",
 				},
 				{
-					input:       llm.CompletionResponse{Content: "!", Done: true, DoneReason: llm.DoneReasonStop},
+					input:       llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
 					wantContent: "!",
 				},
 			},
@@ -134,15 +135,20 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
 			name: "thinking streams separately from content",
 			steps: []step{
 				{
-					input:        llm.CompletionResponse{Thinking: "Thinking...", Done: false},
+					input:        llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
 					wantThinking: "Thinking...",
 				},
 				{
-					input:       llm.CompletionResponse{Content: "Answer", Done: false},
-					wantContent: "Answer",
+					input: llm.CompletionResponse{Content: "<|end|>", Done: false},
+					// No output expected - just closes the analysis message and resets state to normal
 				},
 				{
-					input: llm.CompletionResponse{Done: true, DoneReason: llm.DoneReasonStop},
+					input:       llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false},
+					wantContent: "Answer", // After message end, state is reset to normal
+				},
+				{
+					input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
+					// No output expected - just closes the assistant message
 				},
 			},
 		},
@@ -150,16 +156,24 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
 			name: "partial tags buffer until complete",
 			steps: []step{
 				{
-					input:        llm.CompletionResponse{Thinking: "Deep ", Done: false},
+					input: llm.CompletionResponse{Content: "<|chan", Done: false},
+					// No output - partial tag
+				},
+				{
+					input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false},
+					// No output - still building tags
+				},
+				{
+					input:        llm.CompletionResponse{Content: "age|>Deep ", Done: false},
 					wantThinking: "Deep ",
 				},
 				{
-					input:        llm.CompletionResponse{Thinking: "thought", Done: false},
+					input:        llm.CompletionResponse{Content: "thought<|end|>", Done: false},
 					wantThinking: "thought",
 				},
 				{
-					input:       llm.CompletionResponse{Content: "Done", Done: true, DoneReason: llm.DoneReasonStop},
-					wantContent: "Done",
+					input:       llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
+					wantContent: "Done", // After message end, state is reset to normal
 				},
 			},
 		},
@@ -167,7 +181,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
 			name: "simple assistant after analysis",
 			steps: []step{
 				{
-					input:        llm.CompletionResponse{Thinking: "Think", Content: "Answer", Done: true, DoneReason: llm.DoneReasonStop},
+					input:        llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
 					wantContent:  "Answer",
 					wantThinking: "Think",
 				},
@@ -177,7 +191,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
 			name: "tool call parsed and returned correctly",
 			steps: []step{
 				{
sunny", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"location": "San Francisco"}}}}, Done: true, DoneReason: llm.DoneReasonStop}, + input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, wantContent: "The weather is sunny", wantToolCalls: []api.ToolCall{ { @@ -196,10 +210,15 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { name: "tool call with streaming JSON across chunks", steps: []step{ { - input: llm.CompletionResponse{Done: false}, + input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false}, + // No output yet - incomplete JSON }, { - input: llm.CompletionResponse{ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}}, Done: true}, + input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false}, + // Still no output - incomplete JSON + }, + { + input: llm.CompletionResponse{Content: "2\"}", Done: true}, wantToolCalls: []api.ToolCall{ { Function: api.ToolCallFunction{ @@ -381,9 +400,9 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) { gin.SetMode(gin.TestMode) mockResponses := []llm.CompletionResponse{ - {Content: "First ", Done: false}, + {Content: "<|message|>First ", Done: false}, {Content: "chunk ", Done: false}, - {Content: "here", Done: true, DoneReason: llm.DoneReasonStop}, + {Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, } mock := mockRunner{ @@ -488,3 +507,189 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) { t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks) } } + +func TestChatHarmonyParserStreaming(t *testing.T) { + gin.SetMode(gin.TestMode) + + type expectedChunk struct { + afterResponse int // Which mock response this chunk should appear after + content string // Expected content in this chunk + thinking string // Expected thinking in this chunk + } + + testCases := []struct { + name string + mockResponses []llm.CompletionResponse + expectedChunks []expectedChunk + wantContent string + wantThinking string + }{ + { + name: "simple message without thinking", + mockResponses: []llm.CompletionResponse{ + {Content: "<|start|>assistant<|message|>Hello, ", Done: false}, + {Content: "how can I help?", Done: false}, + {Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, + }, + expectedChunks: []expectedChunk{ + {afterResponse: 1, content: "Hello, "}, + {afterResponse: 2, content: "how can I help?"}, + }, + wantContent: "Hello, how can I help?", + }, + { + name: "message with analysis channel for thinking", + mockResponses: []llm.CompletionResponse{ + {Content: "<|channel|>analysis<|message|>", Done: false}, + {Content: "Let me think ", Done: false}, + {Content: "about this problem...", Done: false}, + {Content: "<|end|>", Done: false}, + {Content: "<|start|>assistant<|message|>", Done: false}, + {Content: "The answer ", Done: false}, + {Content: "is 42", Done: false}, + {Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, + }, + expectedChunks: []expectedChunk{ + {afterResponse: 2, thinking: "Let me think "}, + {afterResponse: 3, thinking: "about this problem..."}, + {afterResponse: 6, content: "The answer "}, + {afterResponse: 7, content: "is 42"}, + }, + 
+			wantContent:  "The answer is 42",
+			wantThinking: "Let me think about this problem...",
+		},
+		{
+			name: "streaming with partial tags across boundaries",
+			mockResponses: []llm.CompletionResponse{
+				{Content: "<|chan", Done: false},
+				{Content: "nel|>analy", Done: false},
+				{Content: "sis<|mess", Done: false},
+				{Content: "age|>Think", Done: false},
+				{Content: "ing deeply...<|end|>", Done: false},
+				{Content: "<|start|>assi", Done: false},
+				{Content: "stant<|message|>Result ", Done: false},
+				{Content: "computed<|e", Done: false},
+				{Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop},
+			},
+			expectedChunks: []expectedChunk{
+				{afterResponse: 4, thinking: "Think"},
+				{afterResponse: 5, thinking: "ing deeply..."},
+				{afterResponse: 7, content: "Result "},
+				{afterResponse: 8, content: "computed"},
+			},
+			wantContent:  "Result computed",
+			wantThinking: "Thinking deeply...",
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			// Channel to synchronize mock responses with chunk verification
+			responsesSent := make(chan int, len(tc.mockResponses))
+
+			mock := mockRunner{
+				CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
+					// Send mock responses one at a time, notifying when each is sent
+					for i, resp := range tc.mockResponses {
+						fn(resp)
+						responsesSent <- i + 1
+					}
+					close(responsesSent)
+					return nil
+				},
+			}
+
+			s := Server{
+				sched: &Scheduler{
+					pendingReqCh:  make(chan *LlmRequest, 1),
+					finishedReqCh: make(chan *LlmRequest, 1),
+					expiredCh:     make(chan *runnerRef, 1),
+					unloadedCh:    make(chan any, 1),
+					loaded:        make(map[string]*runnerRef),
+					newServerFn:   newMockServer(&mock),
+					getGpuFn:      discover.GetGPUInfo,
+					getCpuFn:      discover.GetCPUInfo,
+					reschedDelay:  250 * time.Millisecond,
+					loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
+						req.successCh <- &runnerRef{
+							llama: &mock,
+						}
+						return false
+					},
+				},
+			}
+
+			go s.sched.Run(t.Context())
+
+			// Create a minimal model
+			_, digest := createHarmonyTestModel(t)
+
+			// Create model with passthrough template
+			stream := false
+			w := createRequest(t, s.CreateHandler, api.CreateRequest{
+				Model:    "harmony-test",
+				Files:    map[string]string{"file.gguf": digest},
+				Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
+				Stream:   &stream,
+			})
+
+			if w.Code != http.StatusOK {
+				t.Fatalf("failed to create model: %d", w.Code)
+			}
+
+			// Test chat endpoint with streaming
+			streamTrue := true
+			w = createRequest(t, s.ChatHandler, api.ChatRequest{
+				Model:    "harmony-test",
+				Messages: []api.Message{{Role: "user", Content: "Hello"}},
+				Stream:   &streamTrue,
+				Tools:    getTestTools(),
+			})
+
+			if w.Code != http.StatusOK {
+				t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
+			}
+
+			// Parse streaming response
+			var chunks []api.ChatResponse
+			var content, thinking strings.Builder
+
+			decoder := json.NewDecoder(w.Body)
+			for decoder.More() {
+				var chunk api.ChatResponse
+				if err := decoder.Decode(&chunk); err != nil {
+					t.Fatalf("failed to decode chunk: %v", err)
+				}
+				chunks = append(chunks, chunk)
+
+				// Accumulate content and thinking from each chunk
+				content.WriteString(chunk.Message.Content)
+				thinking.WriteString(chunk.Message.Thinking)
+
+				// Debug output
+				t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
+			}
+
+			// Verify we got streaming chunks
+			if len(chunks) == 0 {
+				t.Fatal("expected streaming chunks, got none")
+			}
+
+			gotContent := content.String()
+			gotThinking := thinking.String()
+
+			if gotContent != tc.wantContent {
+				t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
+			}
+			if gotThinking != tc.wantThinking {
+				t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
+			}
+
+			// Verify last chunk has done=true
+			lastChunk := chunks[len(chunks)-1]
+			if !lastChunk.Done {
+				t.Error("expected last chunk to have done=true")
+			}
+		})
+	}
+}
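Taken together, the fixtures above encode the raw Harmony wire format that the server-side parser now consumes: an optional analysis segment, an optional tool-call segment on the commentary channel, then the user-visible reply. For reference, a single turn reconstructed from those fixtures (illustrative only, not an exhaustive grammar):

```go
// Reconstructed from the test fixtures above; not an exhaustive grammar.
const exampleTurn = "<|channel|>analysis<|message|>Let me think...<|end|>" + // thinking
	"<|start|>assistant<|channel|>commentary to=functions.get_weather" + // tool-call header
	"<|message|>{\"location\":\"San Francisco\"}<|end|>" + // tool arguments (JSON)
	"<|start|>assistant<|message|>The weather is sunny<|end|>" // user-visible reply
```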