diff --git a/harmony/harmonyparser.go b/harmony/harmonyparser.go index 3ec2c21f19..addce4c945 100644 --- a/harmony/harmonyparser.go +++ b/harmony/harmonyparser.go @@ -289,7 +289,6 @@ type HarmonyMessageHandler struct { state harmonyMessageState HarmonyParser *HarmonyParser FunctionNameMap *FunctionNameMap - ToolParser *HarmonyToolCallAccumulator } // NewHarmonyMessageHandler creates a new message handler @@ -302,16 +301,12 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler { HeaderEndTag: "<|message|>", }, FunctionNameMap: NewFunctionNameMap(), - ToolParser: &HarmonyToolCallAccumulator{ - state: harmonyToolCallState_Normal, - currentToolName: nil, - }, } } // AddContent processes the content and returns the content, thinking, and tool content. // content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser -func (h *HarmonyMessageHandler) AddContent(content string) (string, string, string) { +func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) { contentSb := strings.Builder{} thinkingSb := strings.Builder{} toolContentSb := strings.Builder{} @@ -328,14 +323,14 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri // event.Header.Recipient is the tool name, something like // "browser.search" for a built-in, or "functions.calc" for a // custom one - h.ToolParser.SetToolName(event.Header.Recipient) + toolParser.SetToolName(event.Header.Recipient) } else { h.state = harmonyMessageState_Thinking } case "commentary": if event.Header.Recipient != "" { h.state = harmonyMessageState_ToolCalling - h.ToolParser.SetToolName(event.Header.Recipient) + toolParser.SetToolName(event.Header.Recipient) } else { h.state = harmonyMessageState_Normal } @@ -358,6 +353,13 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri return contentSb.String(), thinkingSb.String(), toolContentSb.String() } +func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator { + return &HarmonyToolCallAccumulator{ + state: harmonyToolCallState_Normal, + currentToolName: nil, + } +} + type harmonyToolCallState int const ( diff --git a/harmony/harmonyparser_test.go b/harmony/harmonyparser_test.go index 82bf5b2de8..dcf1af4e83 100644 --- a/harmony/harmonyparser_test.go +++ b/harmony/harmonyparser_test.go @@ -541,7 +541,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("thinking_then_content_streams", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser + tp := handler.CreateToolParser() type step struct { in string wantContent string @@ -554,7 +554,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { {in: "<|end|>", wantContent: ""}, } for i, s := range steps { - content, thinking, tool := handler.AddContent(s.in) + content, thinking, tool := handler.AddContent(s.in, tp) if tool != "" { tp.Add(tool) } @@ -567,7 +567,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("content_streams_as_it_arrives", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser + tp := handler.CreateToolParser() inputs := []string{ "<|start|>assistant<|message|>Hello", ", world", @@ -575,7 +575,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { } var got []string for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) + content, thinking, tool := handler.AddContent(in, tp) if tool != "" { tp.Add(tool) } @@ -595,7 +595,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("thinking_streams_separately_from_content", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser + tp := handler.CreateToolParser() inputs := []string{ "<|channel|>analysis<|message|>Thinking...", "<|end|>", @@ -604,7 +604,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { } var got []string for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) + content, thinking, tool := handler.AddContent(in, tp) if tool != "" { tp.Add(tool) } @@ -624,7 +624,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("partial_tags_buffer_until_complete", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser + tp := handler.CreateToolParser() inputs := []string{ "<|chan", "nel|>analysis<|mess", @@ -637,7 +637,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { var thinkingPieces []string var contentPieces []string for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) + content, thinking, tool := handler.AddContent(in, tp) if tool != "" { tp.Add(tool) } @@ -659,7 +659,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("simple_assistant_after_analysis", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser + tp := handler.CreateToolParser() inputs := []string{ "<|channel|>analysis<|message|>Think", "<|end|>", @@ -668,7 +668,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { } var contentSb, thinkingSb strings.Builder for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) + content, thinking, tool := handler.AddContent(in, tp) if tool != "" { tp.Add(tool) } @@ -686,12 +686,12 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser + tp := handler.CreateToolParser() inputs := []string{ "<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>", } for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) + content, thinking, tool := handler.AddContent(in, tp) if content != "" || thinking != "" { continue } @@ -711,14 +711,14 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { t.Run("tool_call_across_chunks", func(t *testing.T) { handler := NewHarmonyMessageHandler() handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser + tp := handler.CreateToolParser() inputs := []string{ "<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+", "2\"}", "<|end|>", } for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) + content, thinking, tool := handler.AddContent(in, tp) if content != "" || thinking != "" { continue } diff --git a/llm/server.go b/llm/server.go index 9100b69788..45a9ad14c9 100644 --- a/llm/server.go +++ b/llm/server.go @@ -35,7 +35,6 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" - "github.com/ollama/ollama/parser" ) type filteredEnv []string @@ -1350,7 +1349,7 @@ type CompletionRequest struct { Options *api.Options Grammar string // set before sending the request to the subprocess - ParserType parser.TokenParserType + UseHarmony bool PrefillString string } @@ -1364,6 +1363,8 @@ const ( DoneReasonLength // DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed DoneReasonConnectionClosed + // DoneReasonTokenRepeatLimit indicates the completion stopped due to a token repeat limit + DoneReasonTokenRepeatLimit ) func (d DoneReason) String() string { @@ -1372,6 +1373,8 @@ func (d DoneReason) String() string { return "length" case DoneReasonStop: return "stop" + case DoneReasonTokenRepeatLimit: + return "token_repeat_limit" default: return "" // closed } diff --git a/parser/token_parser.go b/parser/token_parser.go deleted file mode 100644 index 8124582999..0000000000 --- a/parser/token_parser.go +++ /dev/null @@ -1,126 +0,0 @@ -package parser - -import ( - "encoding/json" - "errors" - "strings" - - "github.com/ollama/ollama/api" - "github.com/ollama/ollama/harmony" -) - -type TokenParserType int - -const ( - TokenParserTypeDefault TokenParserType = iota - TokenParserTypeHarmony -) - -type TokenParser struct { - messageHandler MessageHandler - parserEngine ParserInternals - toolParser ToolParser - lastToken string - tokenRepeat int - repeatLimit int -} - -const defaultTokenRepeatLimit = 30 - -type MessageHandler interface { - AddContent(token string) (content, thinking string, toolContent string) -} - -type ParserInternals interface { - AddImplicitStartOrPrefill(prefillString string) -} - -type ToolParser interface { - Add(token string) - Drain() (toolName *string, toolContent string) -} - -// Default implementation for the TokenParser interface as a no-op passthrough -type defaultMessageHandler struct{} - -func (defaultMessageHandler) AddContent(token string) (string, string, string) { - return token, "", "" -} - -type defaultEngine struct{} - -func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {} - -type defaultToolParser struct{} - -func (defaultToolParser) Add(token string) {} - -func (defaultToolParser) Drain() (*string, string) { return nil, "" } - -func NewTokenParser(parserType TokenParserType, prefillString string) TokenParser { - switch parserType { - case TokenParserTypeHarmony: - harmonyMessageHandler := harmony.NewHarmonyMessageHandler() - harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(prefillString) - return TokenParser{ - messageHandler: harmonyMessageHandler, - parserEngine: harmonyMessageHandler.HarmonyParser, - toolParser: harmonyMessageHandler.ToolParser, - repeatLimit: defaultTokenRepeatLimit, - } - - default: - return TokenParser{ - messageHandler: defaultMessageHandler{}, - parserEngine: defaultEngine{}, - toolParser: defaultToolParser{}, - repeatLimit: 30, - } - } -} - -func (p *TokenParser) AddContent(token string) (string, string, error) { - if p.repeatLimitReached(token) { - return "", "", errors.New("token repeat limit reached") - } - content, thinking, toolContent := p.messageHandler.AddContent(token) - p.toolParser.Add(toolContent) - return content, thinking, nil -} - -// repeatLimitReached updates repeat counters and returns true if the repeat limit is reached. -func (p *TokenParser) repeatLimitReached(token string) bool { - if p == nil { - return false - } - trimmed := strings.TrimSpace(token) - if trimmed == p.lastToken { - p.tokenRepeat++ - } else { - p.tokenRepeat = 0 - } - p.lastToken = trimmed - - return p.tokenRepeat >= p.repeatLimit -} - -// TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level -func (p *TokenParser) Drain() []api.ToolCall { - toolName, toolContent := p.toolParser.Drain() - if toolName != nil { - *toolName = strings.TrimPrefix(*toolName, "functions.") - var args api.ToolCallFunctionArguments - if err := json.Unmarshal([]byte(toolContent), &args); err != nil { - return nil - } - return []api.ToolCall{ - { - Function: api.ToolCallFunction{ - Name: *toolName, - Arguments: args, - }, - }, - } - } - return nil -} diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 676e5186f8..5da8ca3cbb 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -29,12 +29,12 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/harmony" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" - "github.com/ollama/ollama/parser" "github.com/ollama/ollama/runner/common" "github.com/ollama/ollama/sample" @@ -781,7 +781,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } - tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString) + var harmonyMessageHandler *harmony.HarmonyMessageHandler + var harmonyToolParser *harmony.HarmonyToolCallAccumulator + if req.UseHarmony { + harmonyMessageHandler = harmony.NewHarmonyMessageHandler() + harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(req.PrefillString) + harmonyToolParser = harmonyMessageHandler.CreateToolParser() + } if req.Options == nil { opts := api.DefaultOptions() @@ -865,6 +871,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { http.Error(w, "could not find an available sequence", http.StatusInternalServerError) return } + var lastToken string + tokenRepeat := 0 + const tokenRepeatLimit = 30 for { select { @@ -873,14 +882,23 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return case content, ok := <-seq.responses: if ok { - var thinking string - var err error - content, thinking, err = tokenParser.AddContent(content) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + if strings.TrimSpace(content) == lastToken { + tokenRepeat++ + } + if tokenRepeat == tokenRepeatLimit { + http.Error(w, "token repeat limit reached", http.StatusInternalServerError) + seq.doneReason = llm.DoneReasonTokenRepeatLimit close(seq.quit) return } + lastToken = strings.TrimSpace(content) + + var thinking string + if harmonyMessageHandler != nil { + var toolContent string + content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser) + harmonyToolParser.Add(toolContent) + } if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ Content: content, @@ -893,7 +911,27 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { flusher.Flush() } else { - toolCalls := tokenParser.Drain() + var toolCalls []api.ToolCall + if harmonyMessageHandler != nil { + // these tools still need to be transformed to the original function name + toolName, toolContent := harmonyToolParser.Drain() + if toolName != nil { + *toolName = strings.TrimPrefix(*toolName, "functions.") + var args api.ToolCallFunctionArguments + if err := json.Unmarshal([]byte(toolContent), &args); err != nil { + http.Error(w, fmt.Sprintf("failed to unmarshal tool call function arguments: %v", err), http.StatusInternalServerError) + close(seq.quit) + return + } + toolCalls = append(toolCalls, api.ToolCall{ + Function: api.ToolCallFunction{ + Name: *toolName, + Arguments: args, + }, + }) + } + } + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ ToolCalls: toolCalls, Done: true, diff --git a/server/routes.go b/server/routes.go index 8dd1b217ae..da5e22f687 100644 --- a/server/routes.go +++ b/server/routes.go @@ -36,7 +36,6 @@ import ( "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/openai" - "github.com/ollama/ollama/parser" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/template" @@ -197,12 +196,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { } useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw - var parserType parser.TokenParserType - if useHarmony { - parserType = parser.TokenParserTypeHarmony - } else { - parserType = parser.TokenParserTypeDefault - } var functionNameMap *harmony.FunctionNameMap if useHarmony { @@ -354,7 +347,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { Images: images, Format: req.Format, Options: opts, - ParserType: parserType, + UseHarmony: useHarmony, }, func(cr llm.CompletionResponse) { res := api.GenerateResponse{ Model: req.Model, @@ -1600,12 +1593,6 @@ func (s *Server) ChatHandler(c *gin.Context) { msgs = filterThinkTags(msgs, m) useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) - var parserType parser.TokenParserType - if useHarmony { - parserType = parser.TokenParserTypeHarmony - } else { - parserType = parser.TokenParserTypeDefault - } processedTools := req.Tools var functionNameMap *harmony.FunctionNameMap @@ -1676,7 +1663,7 @@ func (s *Server) ChatHandler(c *gin.Context) { Images: images, Format: req.Format, Options: opts, - ParserType: parserType, + UseHarmony: useHarmony, PrefillString: prefillString, }, func(r llm.CompletionResponse) { res := api.ChatResponse{