mirror of
https://github.com/ollama/ollama.git
synced 2025-08-24 15:51:09 +02:00
harmony: convert fn names to be valid ts identifiers
In <https://github.com/ollama/ollama/issues/11704#issuecomment-3177380197> I noticed that hyphens in function names could possibly cause the model to become confused. Later in that issue I found other explanations, but at a minimum tool names with spaces in them are confusing to the model because of the prompt format. In this change I create a mapper that converts arbitrary tool names into valid typescript identifiers. It's a little overly strict in that it doesn't allow all unicode characters that might be valid in ts identifiers, but it's still very permissive. Since mappings aren't reversible, we must temporarily store this mapping in order to unmap it if the model comes back with a call. We also handle the case where multiple mappings collide into the same mapping and append a counter to the end to make them unique
This commit is contained in:
@@ -2,6 +2,7 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -275,8 +276,9 @@ const (
|
|||||||
// HarmonyMessageHandler processes harmony events and accumulates content appropriately.
|
// HarmonyMessageHandler processes harmony events and accumulates content appropriately.
|
||||||
// This is a higher level interface that maps harmony concepts into ollama concepts
|
// This is a higher level interface that maps harmony concepts into ollama concepts
|
||||||
type HarmonyMessageHandler struct {
|
type HarmonyMessageHandler struct {
|
||||||
state harmonyMessageState
|
state harmonyMessageState
|
||||||
harmonyParser *HarmonyParser
|
harmonyParser *HarmonyParser
|
||||||
|
functionNameMap *FunctionNameMap
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHarmonyMessageHandler creates a new message handler
|
// NewHarmonyMessageHandler creates a new message handler
|
||||||
@@ -288,6 +290,7 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler {
|
|||||||
MessageEndTag: "<|end|>",
|
MessageEndTag: "<|end|>",
|
||||||
HeaderEndTag: "<|message|>",
|
HeaderEndTag: "<|message|>",
|
||||||
},
|
},
|
||||||
|
functionNameMap: NewFunctionNameMap(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -378,3 +381,97 @@ func (a *HarmonyToolCallAccumulator) Drain() (*string, string) {
|
|||||||
func (a *HarmonyToolCallAccumulator) Content() string {
|
func (a *HarmonyToolCallAccumulator) Content() string {
|
||||||
return a.acc.String()
|
return a.acc.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FunctionNameMap maps a user-specified function name to a valid function
|
||||||
|
// name for harmony (which look like TypeScript identifiers). This is needed to
|
||||||
|
// transform user-specified function names, which might contain characters that
|
||||||
|
// are not allowed in TypeScript identifiers
|
||||||
|
type FunctionNameMap struct {
|
||||||
|
userToHarmony map[string]string
|
||||||
|
harmonyToUser map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFunctionNameMap() *FunctionNameMap {
|
||||||
|
return &FunctionNameMap{
|
||||||
|
userToHarmony: make(map[string]string),
|
||||||
|
harmonyToUser: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string {
|
||||||
|
harmonyFunctionName := m.deriveName(userFunctionName)
|
||||||
|
m.userToHarmony[userFunctionName] = harmonyFunctionName
|
||||||
|
m.harmonyToUser[harmonyFunctionName] = userFunctionName
|
||||||
|
return harmonyFunctionName
|
||||||
|
}
|
||||||
|
|
||||||
|
// OriginalFromConverted looks up the reverse-mapping of a previously-converted
|
||||||
|
// user->harmony function name. To unmap reliably, the mapping must exist, as
|
||||||
|
// the conversion process is not reversible without the appropriate state
|
||||||
|
func (m *FunctionNameMap) OriginalFromConverted(harmonyFunctionName string) string {
|
||||||
|
if userFunctionName, ok := m.harmonyToUser[harmonyFunctionName]; ok {
|
||||||
|
return userFunctionName
|
||||||
|
}
|
||||||
|
slog.Warn("harmony parser: no reverse mapping found for function name", "harmonyFunctionName", harmonyFunctionName)
|
||||||
|
// fallback to the original function name if we can't find a mapping
|
||||||
|
return harmonyFunctionName
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertToValidChars converts a user-specified function name to a valid
|
||||||
|
// TypeScript identifier.
|
||||||
|
//
|
||||||
|
// Limitations:
|
||||||
|
//
|
||||||
|
// - This doesn't restrict reserved TypeScript keywords.
|
||||||
|
// - We don't perform a real ID_Start/ID_Continue check, and instead use the more
|
||||||
|
// restrictive unicode.IsLetter/unicode.IsDigit check. Unclear what kind of
|
||||||
|
// identifiers these models were trained on, so in the end we might want to
|
||||||
|
// convert unicode-heavy identifiers to their closest ASCII equivalents.
|
||||||
|
func (m *FunctionNameMap) convertToValidChars(userFunctionName string) string {
|
||||||
|
mapper := func(r rune) rune {
|
||||||
|
// first, replace certain characters with underscores
|
||||||
|
if r == ' ' || r == '-' || r == '.' {
|
||||||
|
return '_'
|
||||||
|
}
|
||||||
|
|
||||||
|
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// finally, remove any other characters
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
candidate := strings.Map(mapper, userFunctionName)
|
||||||
|
|
||||||
|
// set a default name if we end up with nothing left
|
||||||
|
if candidate == "" {
|
||||||
|
return "unnamed"
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the candidate starts with a number, prepend an underscore to make it a
|
||||||
|
// valid identifier
|
||||||
|
if unicode.IsDigit(rune(candidate[0])) {
|
||||||
|
candidate = "_" + candidate
|
||||||
|
}
|
||||||
|
|
||||||
|
return candidate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *FunctionNameMap) deriveName(userFunctionName string) string {
|
||||||
|
originalCandidate := m.convertToValidChars(userFunctionName)
|
||||||
|
candidate := originalCandidate
|
||||||
|
|
||||||
|
// Check for dupes, and if so, add a number to the end.
|
||||||
|
// We start at 2 because if we have dupes and the first is never renamed, it
|
||||||
|
// makes sense for them to be named, say, `f`, `f_2`, `f_3`
|
||||||
|
count := 2
|
||||||
|
for {
|
||||||
|
if _, exists := m.harmonyToUser[candidate]; !exists {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
candidate = fmt.Sprintf("%s_%d", originalCandidate, count)
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
|
||||||
|
return candidate
|
||||||
|
}
|
||||||
|
@@ -467,3 +467,71 @@ func TestHarmonyParserStreaming(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestFunctionConvertToValidChars tests only FunctionNameMap.convert(), which doesn't
|
||||||
|
// handle any saving (and therefore no dupe handling)
|
||||||
|
func TestFunctionConvertToValidChars(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{name: "replace spaces with underscores", in: "get weather", want: "get_weather"},
|
||||||
|
{name: "replace hyphens with underscores", in: "get-weather", want: "get_weather"},
|
||||||
|
{name: "replace periods with underscores", in: "get.weather", want: "get_weather"},
|
||||||
|
{name: "disallow non-word characters", in: "get weather!", want: "get_weather"},
|
||||||
|
{name: "strip out invalid non-alphanumeric unicode characters", in: "a🫠bc", want: "abc"},
|
||||||
|
{name: "names that only contain invalid characters", in: "🫠", want: "unnamed"},
|
||||||
|
{name: "leading number", in: "123", want: "_123"},
|
||||||
|
{name: "$ allowed", in: "$", want: "$"},
|
||||||
|
// show that we allow weird unicode letter characters, though we might want
|
||||||
|
// to convert them to their closest ASCII equivalents in the future
|
||||||
|
{name: "allow weird unicode letter characters", in: "𝓸𝓵𝓵𝓪𝓶𝓪", want: "𝓸𝓵𝓵𝓪𝓶𝓪"},
|
||||||
|
// names that look like words but are invalid (i.e., not ID_Start/ID_Continue)
|
||||||
|
{name: "disallow non-word characters that look like words", in: "ⓞⓛⓛⓐⓜⓐ123", want: "_123"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
parser := NewFunctionNameMap()
|
||||||
|
got := parser.convertToValidChars(tt.in)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("case %d: got %q, want %q", i, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFunctionConvertAndAdd(t *testing.T) {
|
||||||
|
// make a fresh map for each test, but within a test use the same map so we can test for dupe handling
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
in []string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{name: "basic dupe handling", in: []string{"get weather", "get weather"}, want: []string{"get_weather", "get_weather_2"}},
|
||||||
|
{name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}},
|
||||||
|
{name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}},
|
||||||
|
{name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tt := range tests {
|
||||||
|
parser := NewFunctionNameMap()
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
for j, in := range tt.in {
|
||||||
|
got := parser.ConvertAndAdd(in)
|
||||||
|
want := tt.want[j]
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("case %d: got %q, want %q", i, got, want)
|
||||||
|
}
|
||||||
|
// check that the maps are correct
|
||||||
|
if parser.userToHarmony[in] != want {
|
||||||
|
t.Errorf("case %d: userToHarmony[%q] = %q, want %q", i, in, parser.userToHarmony[in], want)
|
||||||
|
}
|
||||||
|
if parser.harmonyToUser[want] != in {
|
||||||
|
t.Errorf("case %d: harmonyToUser[%q] = %q, want %q", i, want, parser.harmonyToUser[want], in)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -1603,7 +1603,31 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
msgs = filterThinkTags(msgs, m)
|
msgs = filterThinkTags(msgs, m)
|
||||||
|
|
||||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Think)
|
var harmonyMessageHandler *HarmonyMessageHandler
|
||||||
|
var harmonyToolParser *HarmonyToolCallAccumulator
|
||||||
|
|
||||||
|
useHarmony := shouldUseHarmony(*m)
|
||||||
|
|
||||||
|
processedTools := req.Tools
|
||||||
|
if useHarmony {
|
||||||
|
harmonyMessageHandler = NewHarmonyMessageHandler()
|
||||||
|
var lastMessage *api.Message
|
||||||
|
if len(msgs) > 0 {
|
||||||
|
lastMessage = &msgs[len(msgs)-1]
|
||||||
|
}
|
||||||
|
harmonyMessageHandler.harmonyParser.AddImplicitStartOrPrefill(lastMessage)
|
||||||
|
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
||||||
|
|
||||||
|
// make a copy of tools to pass to the chat prompt. Function names may be
|
||||||
|
// renamed to be valid Harmony function names.
|
||||||
|
processedTools = make([]api.Tool, len(req.Tools))
|
||||||
|
copy(processedTools, req.Tools)
|
||||||
|
for i, tool := range processedTools {
|
||||||
|
processedTools[i].Function.Name = harmonyMessageHandler.functionNameMap.ConvertAndAdd(tool.Function.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("chat prompt error", "error", err)
|
slog.Error("chat prompt error", "error", err)
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -1623,27 +1647,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
useHarmony := shouldUseHarmony(*m)
|
|
||||||
|
|
||||||
// Validate Think value: string values currently only allowed for gptoss models
|
// Validate Think value: string values currently only allowed for gptoss models
|
||||||
if req.Think != nil && req.Think.IsString() && !useHarmony {
|
if req.Think != nil && req.Think.IsString() && !useHarmony {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var harmonyMessageHandler *HarmonyMessageHandler
|
|
||||||
var harmonyToolParser *HarmonyToolCallAccumulator
|
|
||||||
|
|
||||||
if useHarmony {
|
|
||||||
harmonyMessageHandler = NewHarmonyMessageHandler()
|
|
||||||
var lastMessage *api.Message
|
|
||||||
if len(msgs) > 0 {
|
|
||||||
lastMessage = &msgs[len(msgs)-1]
|
|
||||||
}
|
|
||||||
harmonyMessageHandler.harmonyParser.AddImplicitStartOrPrefill(lastMessage)
|
|
||||||
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
|
||||||
}
|
|
||||||
|
|
||||||
var thinkingState *thinking.Parser
|
var thinkingState *thinking.Parser
|
||||||
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||||
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
|
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
|
||||||
@@ -1696,6 +1705,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
toolName, toolContent := harmonyToolParser.Drain()
|
toolName, toolContent := harmonyToolParser.Drain()
|
||||||
if toolName != nil {
|
if toolName != nil {
|
||||||
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
||||||
|
*toolName = harmonyMessageHandler.functionNameMap.OriginalFromConverted(*toolName)
|
||||||
var args api.ToolCallFunctionArguments
|
var args api.ToolCallFunctionArguments
|
||||||
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
||||||
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
|
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
|
||||||
|
Reference in New Issue
Block a user