diff --git a/server/harmonyparser.go b/server/harmonyparser.go index 86dcf66e34..4405cea440 100644 --- a/server/harmonyparser.go +++ b/server/harmonyparser.go @@ -2,6 +2,7 @@ package server import ( "context" + "fmt" "log/slog" "slices" "strings" @@ -275,8 +276,9 @@ const ( // HarmonyMessageHandler processes harmony events and accumulates content appropriately. // This is a higher level interface that maps harmony concepts into ollama concepts type HarmonyMessageHandler struct { - state harmonyMessageState - harmonyParser *HarmonyParser + state harmonyMessageState + harmonyParser *HarmonyParser + functionNameMap *FunctionNameMap } // NewHarmonyMessageHandler creates a new message handler @@ -288,6 +290,7 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler { MessageEndTag: "<|end|>", HeaderEndTag: "<|message|>", }, + functionNameMap: NewFunctionNameMap(), } } @@ -378,3 +381,97 @@ func (a *HarmonyToolCallAccumulator) Drain() (*string, string) { func (a *HarmonyToolCallAccumulator) Content() string { return a.acc.String() } + +// FunctionNameMap maps a user-specified function name to a valid function +// name for harmony (which look like TypeScript identifiers). This is needed to +// transform user-specified function names, which might contain characters that +// are not allowed in TypeScript identifiers +type FunctionNameMap struct { + userToHarmony map[string]string + harmonyToUser map[string]string +} + +func NewFunctionNameMap() *FunctionNameMap { + return &FunctionNameMap{ + userToHarmony: make(map[string]string), + harmonyToUser: make(map[string]string), + } +} + +func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string { + harmonyFunctionName := m.deriveName(userFunctionName) + m.userToHarmony[userFunctionName] = harmonyFunctionName + m.harmonyToUser[harmonyFunctionName] = userFunctionName + return harmonyFunctionName +} + +// OriginalFromConverted looks up the reverse-mapping of a previously-converted +// user->harmony function name. To unmap reliably, the mapping must exist, as +// the conversion process is not reversible without the appropriate state +func (m *FunctionNameMap) OriginalFromConverted(harmonyFunctionName string) string { + if userFunctionName, ok := m.harmonyToUser[harmonyFunctionName]; ok { + return userFunctionName + } + slog.Warn("harmony parser: no reverse mapping found for function name", "harmonyFunctionName", harmonyFunctionName) + // fallback to the original function name if we can't find a mapping + return harmonyFunctionName +} + +// convertToValidChars converts a user-specified function name to a valid +// TypeScript identifier. +// +// Limitations: +// +// - This doesn't restrict reserved TypeScript keywords. +// - We don't perform a real ID_Start/ID_Continue check, and instead use the more +// restrictive unicode.IsLetter/unicode.IsDigit check. Unclear what kind of +// identifiers these models were trained on, so in the end we might want to +// convert unicode-heavy identifiers to their closest ASCII equivalents. +func (m *FunctionNameMap) convertToValidChars(userFunctionName string) string { + mapper := func(r rune) rune { + // first, replace certain characters with underscores + if r == ' ' || r == '-' || r == '.' { + return '_' + } + + if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' { + return r + } + + // finally, remove any other characters + return -1 + } + candidate := strings.Map(mapper, userFunctionName) + + // set a default name if we end up with nothing left + if candidate == "" { + return "unnamed" + } + + // if the candidate starts with a number, prepend an underscore to make it a + // valid identifier + if unicode.IsDigit(rune(candidate[0])) { + candidate = "_" + candidate + } + + return candidate +} + +func (m *FunctionNameMap) deriveName(userFunctionName string) string { + originalCandidate := m.convertToValidChars(userFunctionName) + candidate := originalCandidate + + // Check for dupes, and if so, add a number to the end. + // We start at 2 because if we have dupes and the first is never renamed, it + // makes sense for them to be named, say, `f`, `f_2`, `f_3` + count := 2 + for { + if _, exists := m.harmonyToUser[candidate]; !exists { + break + } + candidate = fmt.Sprintf("%s_%d", originalCandidate, count) + count++ + } + + return candidate +} diff --git a/server/harmonyparser_test.go b/server/harmonyparser_test.go index cd1743e1c6..8a22f34041 100644 --- a/server/harmonyparser_test.go +++ b/server/harmonyparser_test.go @@ -467,3 +467,71 @@ func TestHarmonyParserStreaming(t *testing.T) { }) } } + +// TestFunctionConvertToValidChars tests only FunctionNameMap.convert(), which doesn't +// handle any saving (and therefore no dupe handling) +func TestFunctionConvertToValidChars(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {name: "replace spaces with underscores", in: "get weather", want: "get_weather"}, + {name: "replace hyphens with underscores", in: "get-weather", want: "get_weather"}, + {name: "replace periods with underscores", in: "get.weather", want: "get_weather"}, + {name: "disallow non-word characters", in: "get weather!", want: "get_weather"}, + {name: "strip out invalid non-alphanumeric unicode characters", in: "a🫠bc", want: "abc"}, + {name: "names that only contain invalid characters", in: "🫠", want: "unnamed"}, + {name: "leading number", in: "123", want: "_123"}, + {name: "$ allowed", in: "$", want: "$"}, + // show that we allow weird unicode letter characters, though we might want + // to convert them to their closest ASCII equivalents in the future + {name: "allow weird unicode letter characters", in: "𝓸𝓡𝓡π“ͺ𝓢π“ͺ", want: "𝓸𝓡𝓡π“ͺ𝓢π“ͺ"}, + // names that look like words but are invalid (i.e., not ID_Start/ID_Continue) + {name: "disallow non-word characters that look like words", in: "β“žβ“›β“›β“β“œβ“123", want: "_123"}, + } + + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := NewFunctionNameMap() + got := parser.convertToValidChars(tt.in) + if got != tt.want { + t.Errorf("case %d: got %q, want %q", i, got, tt.want) + } + }) + } +} + +func TestFunctionConvertAndAdd(t *testing.T) { + // make a fresh map for each test, but within a test use the same map so we can test for dupe handling + tests := []struct { + name string + in []string + want []string + }{ + {name: "basic dupe handling", in: []string{"get weather", "get weather"}, want: []string{"get_weather", "get_weather_2"}}, + {name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}}, + {name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}}, + {name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}}, + } + + for i, tt := range tests { + parser := NewFunctionNameMap() + t.Run(tt.name, func(t *testing.T) { + for j, in := range tt.in { + got := parser.ConvertAndAdd(in) + want := tt.want[j] + if got != want { + t.Errorf("case %d: got %q, want %q", i, got, want) + } + // check that the maps are correct + if parser.userToHarmony[in] != want { + t.Errorf("case %d: userToHarmony[%q] = %q, want %q", i, in, parser.userToHarmony[in], want) + } + if parser.harmonyToUser[want] != in { + t.Errorf("case %d: harmonyToUser[%q] = %q, want %q", i, want, parser.harmonyToUser[want], in) + } + } + }) + } +} diff --git a/server/routes.go b/server/routes.go index 3b94daad08..60b7e3e841 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1603,7 +1603,31 @@ func (s *Server) ChatHandler(c *gin.Context) { } msgs = filterThinkTags(msgs, m) - prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Think) + var harmonyMessageHandler *HarmonyMessageHandler + var harmonyToolParser *HarmonyToolCallAccumulator + + useHarmony := shouldUseHarmony(*m) + + processedTools := req.Tools + if useHarmony { + harmonyMessageHandler = NewHarmonyMessageHandler() + var lastMessage *api.Message + if len(msgs) > 0 { + lastMessage = &msgs[len(msgs)-1] + } + harmonyMessageHandler.harmonyParser.AddImplicitStartOrPrefill(lastMessage) + harmonyToolParser = harmonyMessageHandler.CreateToolParser() + + // make a copy of tools to pass to the chat prompt. Function names may be + // renamed to be valid Harmony function names. + processedTools = make([]api.Tool, len(req.Tools)) + copy(processedTools, req.Tools) + for i, tool := range processedTools { + processedTools[i].Function.Name = harmonyMessageHandler.functionNameMap.ConvertAndAdd(tool.Function.Name) + } + } + + prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think) if err != nil { slog.Error("chat prompt error", "error", err) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -1623,27 +1647,12 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - useHarmony := shouldUseHarmony(*m) - // Validate Think value: string values currently only allowed for gptoss models if req.Think != nil && req.Think.IsString() && !useHarmony { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())}) return } - var harmonyMessageHandler *HarmonyMessageHandler - var harmonyToolParser *HarmonyToolCallAccumulator - - if useHarmony { - harmonyMessageHandler = NewHarmonyMessageHandler() - var lastMessage *api.Message - if len(msgs) > 0 { - lastMessage = &msgs[len(msgs)-1] - } - harmonyMessageHandler.harmonyParser.AddImplicitStartOrPrefill(lastMessage) - harmonyToolParser = harmonyMessageHandler.CreateToolParser() - } - var thinkingState *thinking.Parser openingTag, closingTag := thinking.InferTags(m.Template.Template) if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" { @@ -1696,6 +1705,7 @@ func (s *Server) ChatHandler(c *gin.Context) { toolName, toolContent := harmonyToolParser.Drain() if toolName != nil { *toolName = strings.TrimPrefix(*toolName, "functions.") + *toolName = harmonyMessageHandler.functionNameMap.OriginalFromConverted(*toolName) var args api.ToolCallFunctionArguments if err := json.Unmarshal([]byte(toolContent), &args); err != nil { errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())