cleanup passing in harmony flag and add generate support

This commit is contained in:
ParthSareen
2025-09-02 09:58:36 -07:00
parent 5bc783b58e
commit 40d3436cd1
3 changed files with 39 additions and 51 deletions

View File

@@ -31,7 +31,6 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/harmony"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
@@ -1349,7 +1348,7 @@ type CompletionRequest struct {
Options *api.Options
Grammar string // set before sending the request to the subprocess
FunctionNameMap *harmony.FunctionNameMap
UseHarmony bool
PrefillContent *bool
}

View File

@@ -776,9 +776,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var harmonyMessageHandler *harmony.HarmonyMessageHandler
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
if req.FunctionNameMap != nil {
if req.UseHarmony {
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
harmonyMessageHandler.FunctionNameMap = req.FunctionNameMap
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(req.PrefillContent)
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
}
@@ -893,10 +892,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
} else {
var toolCalls []api.ToolCall
if harmonyMessageHandler != nil {
// these tools still need to be transformed to the original function name
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
http.Error(w, fmt.Sprintf("failed to unmarshal tool call function arguments: %v", err), http.StatusInternalServerError)

View File

@@ -196,12 +196,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw
var harmonyMessageHandler *harmony.HarmonyMessageHandler
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
var functionNameMap *harmony.FunctionNameMap
if useHarmony {
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
harmonyMessageHandler.HarmonyParser.AddImplicitStart()
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
functionNameMap = harmony.NewFunctionNameMap()
}
// Validate Think value: string values currently only allowed for gptoss models
@@ -349,12 +347,15 @@ func (s *Server) GenerateHandler(c *gin.Context) {
Images: images,
Format: req.Format,
Options: opts,
UseHarmony: useHarmony,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Response: cr.Content,
Done: cr.Done,
Thinking: cr.Thinking,
ToolCalls: cr.ToolCalls,
Metrics: api.Metrics{
PromptEvalCount: cr.PromptEvalCount,
PromptEvalDuration: cr.PromptEvalDuration,
@@ -363,12 +364,22 @@ func (s *Server) GenerateHandler(c *gin.Context) {
},
}
if res.Done {
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
if useHarmony {
content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
res.Response = content
res.Thinking = thinking
harmonyToolParser.Add(toolContent)
} else if thinkingState != nil {
for i, tool := range res.ToolCalls {
res.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
}
if res.Response != "" || res.Thinking != "" || len(res.ToolCalls) > 0 || res.Done {
ch <- res
}
return
}
if thinkingState != nil {
thinking, content := thinkingState.AddContent(cr.Content)
res.Thinking = thinking
res.Response = content
@@ -379,30 +390,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
if cr.Done {
if useHarmony {
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
ch <- gin.H{"error": errStr}
return
}
res.ToolCalls = append(res.ToolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: *toolName,
Arguments: args,
},
})
}
}
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
if err != nil {
@@ -1690,7 +1677,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
Images: images,
Format: req.Format,
Options: opts,
FunctionNameMap: functionNameMap,
UseHarmony: useHarmony,
PrefillContent: prefillContentOrThinking,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
@@ -1713,6 +1700,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
if useHarmony {
// only send messages with meaningful content (empty messages confuse clients)
for i, tool := range res.Message.ToolCalls {
res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
}
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
ch <- res
}