harmony: simplify prefill, add marshalling for functions, and update harmony check

This commit is contained in:
ParthSareen
2025-08-22 15:45:11 -07:00
parent 1d09e01431
commit 72189c6d6e

View File

@@ -1,17 +1,32 @@
package harmony package harmony
import ( import (
"encoding/json"
"fmt" "fmt"
"log/slog" "log/slog"
"maps"
"slices"
"strings" "strings"
"unicode" "unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/template"
) )
type harmonyParserState int type harmonyParserState int
func ShouldUseHarmony(modelFamily string, template *template.Template) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) {
// heuristic to check whether the template expects to be parsed via harmony:
// search for harmony tags that are nearly always used
if template.Contains("<|start|>") && template.Contains("<|end|>") {
return true
}
}
return false
}
const ( const (
harmonyParserState_LookingForMessageStart harmonyParserState = iota harmonyParserState_LookingForMessageStart harmonyParserState = iota
harmonyParserState_ParsingHeader harmonyParserState_ParsingHeader
@@ -75,17 +90,18 @@ func (s *HarmonyParser) AddImplicitStart() {
s.acc.WriteString("<|start|>assistant") s.acc.WriteString("<|start|>assistant")
} }
func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) { // AddImplicitStartOrPrefill adds content or thinking to the accumulator else adds start tag
if lastMessage != nil && lastMessage.Role == "assistant" { func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillContentOrThinking *bool) {
// handle prefilling conditions if prefillContentOrThinking != nil {
if lastMessage.Content != "" { if *prefillContentOrThinking {
s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>") s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
return return
} else if lastMessage.Thinking != "" { } else {
s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>") s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
return return
} }
} }
s.AddImplicitStart() s.AddImplicitStart()
} }
@@ -377,6 +393,38 @@ type FunctionNameMap struct {
harmonyToUser map[string]string harmonyToUser map[string]string
} }
func (m FunctionNameMap) MarshalJSON() ([]byte, error) {
// necessary to avoid exposing map internals
type alias struct {
UserToHarmony map[string]string `json:"userToHarmony"`
HarmonyToUser map[string]string `json:"harmonyToUser"`
}
return json.Marshal(alias{
UserToHarmony: m.userToHarmony,
HarmonyToUser: m.harmonyToUser,
})
}
func (m *FunctionNameMap) UnmarshalJSON(b []byte) error {
type alias struct {
UserToHarmony map[string]string `json:"userToHarmony"`
HarmonyToUser map[string]string `json:"harmonyToUser"`
}
var a alias
if err := json.Unmarshal(b, &a); err != nil {
return err
}
if m.userToHarmony == nil {
m.userToHarmony = make(map[string]string)
}
if m.harmonyToUser == nil {
m.harmonyToUser = make(map[string]string)
}
maps.Copy(m.userToHarmony, a.UserToHarmony)
maps.Copy(m.harmonyToUser, a.HarmonyToUser)
return nil
}
func NewFunctionNameMap() *FunctionNameMap { func NewFunctionNameMap() *FunctionNameMap {
return &FunctionNameMap{ return &FunctionNameMap{
userToHarmony: make(map[string]string), userToHarmony: make(map[string]string),