diff --git a/harmony/harmonyparser.go b/harmony/harmonyparser.go index addce4c945..3ec2c21f19 100644 --- a/harmony/harmonyparser.go +++ b/harmony/harmonyparser.go @@ -289,6 +289,7 @@ type HarmonyMessageHandler struct { state harmonyMessageState HarmonyParser *HarmonyParser FunctionNameMap *FunctionNameMap + ToolParser *HarmonyToolCallAccumulator } // NewHarmonyMessageHandler creates a new message handler @@ -301,12 +302,16 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler { HeaderEndTag: "<|message|>", }, FunctionNameMap: NewFunctionNameMap(), + ToolParser: &HarmonyToolCallAccumulator{ + state: harmonyToolCallState_Normal, + currentToolName: nil, + }, } } // AddContent processes the content and returns the content, thinking, and tool content. // content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser -func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) { +func (h *HarmonyMessageHandler) AddContent(content string) (string, string, string) { contentSb := strings.Builder{} thinkingSb := strings.Builder{} toolContentSb := strings.Builder{} @@ -323,14 +328,14 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo // event.Header.Recipient is the tool name, something like // "browser.search" for a built-in, or "functions.calc" for a // custom one - toolParser.SetToolName(event.Header.Recipient) + h.ToolParser.SetToolName(event.Header.Recipient) } else { h.state = harmonyMessageState_Thinking } case "commentary": if event.Header.Recipient != "" { h.state = harmonyMessageState_ToolCalling - toolParser.SetToolName(event.Header.Recipient) + h.ToolParser.SetToolName(event.Header.Recipient) } else { h.state = harmonyMessageState_Normal } @@ -353,13 +358,6 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo return contentSb.String(), thinkingSb.String(), toolContentSb.String() } -func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator { - return &HarmonyToolCallAccumulator{ - state: harmonyToolCallState_Normal, - currentToolName: nil, - } -} - type harmonyToolCallState int const ( diff --git a/harmony/harmonyparser_test.go b/harmony/harmonyparser_test.go index dcf1af4e83..82bf5b2de8 100644 --- a/harmony/harmonyparser_test.go +++ b/harmony/harmonyparser_test.go @@ -541,7 +541,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("thinking_then_content_streams", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() + tp := handler.ToolParser type step struct { in string wantContent string @@ -554,7 +554,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { {in: "<|end|>", wantContent: ""}, } for i, s := range steps { - content, thinking, tool := handler.AddContent(s.in, tp) + content, thinking, tool := handler.AddContent(s.in) if tool != "" { tp.Add(tool) } @@ -567,7 +567,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("content_streams_as_it_arrives", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() + tp := handler.ToolParser inputs := []string{ "<|start|>assistant<|message|>Hello", ", world", @@ -575,7 +575,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { } var got []string for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) + content, thinking, tool := handler.AddContent(in) if tool != "" { tp.Add(tool) } @@ -595,7 +595,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("thinking_streams_separately_from_content", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() + tp := handler.ToolParser inputs := []string{ "<|channel|>analysis<|message|>Thinking...", "<|end|>", @@ -604,7 +604,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { } var got []string for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) + content, thinking, tool := handler.AddContent(in) if tool != "" { tp.Add(tool) } @@ -624,7 +624,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("partial_tags_buffer_until_complete", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() + tp := handler.ToolParser inputs := []string{ "<|chan", "nel|>analysis<|mess", @@ -637,7 +637,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { var thinkingPieces []string var contentPieces []string for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) + content, thinking, tool := handler.AddContent(in) if tool != "" { tp.Add(tool) } @@ -659,7 +659,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("simple_assistant_after_analysis", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() + tp := handler.ToolParser inputs := []string{ "<|channel|>analysis<|message|>Think", "<|end|>", @@ -668,7 +668,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { } var contentSb, thinkingSb strings.Builder for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) + content, thinking, tool := handler.AddContent(in) if tool != "" { tp.Add(tool) } @@ -686,12 +686,12 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() + tp := handler.ToolParser inputs := []string{ "<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>", } for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) + content, thinking, tool := handler.AddContent(in) if content != "" || thinking != "" { continue } @@ -711,14 +711,14 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("tool_call_across_chunks", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.CreateToolParser() + tp := handler.ToolParser inputs := []string{ "<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+", "2\"}", "<|end|>", } for _, in := range inputs { - content, thinking, tool := handler.AddContent(in, tp) + content, thinking, tool := handler.AddContent(in) if content != "" || thinking != "" { continue } diff --git a/llm/server.go b/llm/server.go index 7bc2ca13df..4740a1fd40 100644 --- a/llm/server.go +++ b/llm/server.go @@ -35,6 +35,7 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" + "github.com/ollama/ollama/parser" ) type filteredEnv []string @@ -1350,7 +1351,7 @@ type CompletionRequest struct { Options *api.Options Grammar string // set before sending the request to the subprocess - UseHarmony bool + ParserType parser.TokenParserType PrefillString string } @@ -1364,8 +1365,6 @@ const ( DoneReasonLength // DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed DoneReasonConnectionClosed - // DoneReasonTokenRepeatLimit indicates the completion stopped due to a token repeat limit - DoneReasonTokenRepeatLimit ) func (d DoneReason) String() string { @@ -1374,8 +1373,6 @@ func (d DoneReason) String() string { return "length" case DoneReasonStop: return "stop" - case DoneReasonTokenRepeatLimit: - return "token_repeat_limit" default: return "" // closed } diff --git a/parser/token_parser.go b/parser/token_parser.go new file mode 100644 index 0000000000..8124582999 --- /dev/null +++ b/parser/token_parser.go @@ -0,0 +1,126 @@ +package parser + +import ( + "encoding/json" + "errors" + "strings" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/harmony" +) + +type TokenParserType int + +const ( + TokenParserTypeDefault TokenParserType = iota + TokenParserTypeHarmony +) + +type TokenParser struct { + messageHandler MessageHandler + parserEngine ParserInternals + toolParser ToolParser + lastToken string + tokenRepeat int + repeatLimit int +} + +const defaultTokenRepeatLimit = 30 + +type MessageHandler interface { + AddContent(token string) (content, thinking string, toolContent string) +} + +type ParserInternals interface { + AddImplicitStartOrPrefill(prefillString string) +} + +type ToolParser interface { + Add(token string) + Drain() (toolName *string, toolContent string) +} + +// Default implementation for the TokenParser interface as a no-op passthrough +type defaultMessageHandler struct{} + +func (defaultMessageHandler) AddContent(token string) (string, string, string) { + return token, "", "" +} + +type defaultEngine struct{} + +func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {} + +type defaultToolParser struct{} + +func (defaultToolParser) Add(token string) {} + +func (defaultToolParser) Drain() (*string, string) { return nil, "" } + +func NewTokenParser(parserType TokenParserType, prefillString string) TokenParser { + switch parserType { + case TokenParserTypeHarmony: + harmonyMessageHandler := harmony.NewHarmonyMessageHandler() + harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(prefillString) + return TokenParser{ + messageHandler: harmonyMessageHandler, + parserEngine: harmonyMessageHandler.HarmonyParser, + toolParser: harmonyMessageHandler.ToolParser, + repeatLimit: defaultTokenRepeatLimit, + } + + default: + return TokenParser{ + messageHandler: defaultMessageHandler{}, + parserEngine: defaultEngine{}, + toolParser: defaultToolParser{}, + repeatLimit: 30, + } + } +} + +func (p *TokenParser) AddContent(token string) (string, string, error) { + if p.repeatLimitReached(token) { + return "", "", errors.New("token repeat limit reached") + } + content, thinking, toolContent := p.messageHandler.AddContent(token) + p.toolParser.Add(toolContent) + return content, thinking, nil +} + +// repeatLimitReached updates repeat counters and returns true if the repeat limit is reached. +func (p *TokenParser) repeatLimitReached(token string) bool { + if p == nil { + return false + } + trimmed := strings.TrimSpace(token) + if trimmed == p.lastToken { + p.tokenRepeat++ + } else { + p.tokenRepeat = 0 + } + p.lastToken = trimmed + + return p.tokenRepeat >= p.repeatLimit +} + +// TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level +func (p *TokenParser) Drain() []api.ToolCall { + toolName, toolContent := p.toolParser.Drain() + if toolName != nil { + *toolName = strings.TrimPrefix(*toolName, "functions.") + var args api.ToolCallFunctionArguments + if err := json.Unmarshal([]byte(toolContent), &args); err != nil { + return nil + } + return []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: *toolName, + Arguments: args, + }, + }, + } + } + return nil +} diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index a40643ef2b..201d55a166 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -30,12 +30,12 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/harmony" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" + "github.com/ollama/ollama/parser" "github.com/ollama/ollama/runner/common" "github.com/ollama/ollama/sample" @@ -782,13 +782,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } - var harmonyMessageHandler *harmony.HarmonyMessageHandler - var harmonyToolParser *harmony.HarmonyToolCallAccumulator - if req.UseHarmony { - harmonyMessageHandler = harmony.NewHarmonyMessageHandler() - harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(req.PrefillString) - harmonyToolParser = harmonyMessageHandler.CreateToolParser() - } + tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString) if req.Options == nil { opts := api.DefaultOptions() @@ -872,9 +866,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { http.Error(w, "could not find an available sequence", http.StatusInternalServerError) return } - var lastToken string - tokenRepeat := 0 - const tokenRepeatLimit = 30 for { select { @@ -883,23 +874,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return case content, ok := <-seq.responses: if ok { - if strings.TrimSpace(content) == lastToken { - tokenRepeat++ - } - if tokenRepeat == tokenRepeatLimit { - http.Error(w, "token repeat limit reached", http.StatusInternalServerError) - seq.doneReason = llm.DoneReasonTokenRepeatLimit + var thinking string + var err error + content, thinking, err = tokenParser.AddContent(content) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) close(seq.quit) return } - lastToken = strings.TrimSpace(content) - - var thinking string - if harmonyMessageHandler != nil { - var toolContent string - content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser) - harmonyToolParser.Add(toolContent) - } if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ Content: content, @@ -912,27 +894,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { flusher.Flush() } else { - var toolCalls []api.ToolCall - if harmonyMessageHandler != nil { - // these tools still need to be transformed to the original function name - toolName, toolContent := harmonyToolParser.Drain() - if toolName != nil { - *toolName = strings.TrimPrefix(*toolName, "functions.") - var args api.ToolCallFunctionArguments - if err := json.Unmarshal([]byte(toolContent), &args); err != nil { - http.Error(w, fmt.Sprintf("failed to unmarshal tool call function arguments: %v", err), http.StatusInternalServerError) - close(seq.quit) - return - } - toolCalls = append(toolCalls, api.ToolCall{ - Function: api.ToolCallFunction{ - Name: *toolName, - Arguments: args, - }, - }) - } - } - + toolCalls := tokenParser.Drain() if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ ToolCalls: toolCalls, Done: true, diff --git a/server/routes.go b/server/routes.go index 73ea5fea47..ac4df4a469 100644 --- a/server/routes.go +++ b/server/routes.go @@ -36,6 +36,7 @@ import ( "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/openai" + "github.com/ollama/ollama/parser" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/template" @@ -196,6 +197,12 @@ func (s *Server) GenerateHandler(c *gin.Context) { } useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw + var parserType parser.TokenParserType + if useHarmony { + parserType = parser.TokenParserTypeHarmony + } else { + parserType = parser.TokenParserTypeDefault + } var functionNameMap *harmony.FunctionNameMap if useHarmony { @@ -347,7 +354,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { Images: images, Format: req.Format, Options: opts, - UseHarmony: useHarmony, + ParserType: parserType, }, func(cr llm.CompletionResponse) { res := api.GenerateResponse{ Model: req.Model, @@ -1592,6 +1599,12 @@ func (s *Server) ChatHandler(c *gin.Context) { msgs = filterThinkTags(msgs, m) useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) + var parserType parser.TokenParserType + if useHarmony { + parserType = parser.TokenParserTypeHarmony + } else { + parserType = parser.TokenParserTypeDefault + } processedTools := req.Tools var functionNameMap *harmony.FunctionNameMap @@ -1662,7 +1675,7 @@ func (s *Server) ChatHandler(c *gin.Context) { Images: images, Format: req.Format, Options: opts, - UseHarmony: useHarmony, + ParserType: parserType, PrefillString: prefillString, }, func(r llm.CompletionResponse) { res := api.ChatResponse{