mirror of
https://github.com/ollama/ollama.git
synced 2025-08-25 09:51:25 +02:00
harmony: move harmony parsing into a package (#12016)
This commit is contained in:
464
harmony/harmonyparser.go
Normal file
464
harmony/harmonyparser.go
Normal file
@@ -0,0 +1,464 @@
|
||||
package harmony
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
// harmonyParserState identifies which part of a harmony message the parser is
// currently consuming.
type harmonyParserState int

const (
	// waiting for the message start tag
	harmonyParserState_LookingForMessageStart harmonyParserState = iota
	// reading the header, up to the header end tag
	harmonyParserState_ParsingHeader
	// reading the message body, up to the message end tag
	harmonyParserState_ParsingContent
)

// String returns the state's name without the type prefix, or "Unknown" for
// any value that is not one of the declared states.
func (s harmonyParserState) String() string {
	names := [...]string{
		harmonyParserState_LookingForMessageStart: "LookingForMessageStart",
		harmonyParserState_ParsingHeader:          "ParsingHeader",
		harmonyParserState_ParsingContent:         "ParsingContent",
	}
	if s < 0 || int(s) >= len(names) {
		return "Unknown"
	}
	return names[s]
}
|
||||
|
||||
type HarmonyParser struct {
|
||||
state harmonyParserState
|
||||
MessageStartTag string
|
||||
MessageEndTag string
|
||||
HeaderEndTag string
|
||||
acc strings.Builder
|
||||
lifetimeAcc strings.Builder
|
||||
}
|
||||
|
||||
// HarmonyEvent is implemented by every event the parser emits. The marker
// method is unexported, so only types in this package can satisfy it.
type HarmonyEvent interface {
	isHarmonyEvent()
}

type (
	// HarmonyEventMessageStart is emitted when a message start tag is found.
	HarmonyEventMessageStart struct{}

	// HarmonyEventHeaderComplete is emitted once a full header has been read
	// and parsed.
	HarmonyEventHeaderComplete struct {
		Header HarmonyHeader
	}

	// HarmonyEventContentEmitted carries a piece of message content.
	HarmonyEventContentEmitted struct {
		Content string
	}

	// HarmonyEventMessageEnd is emitted when a message end tag is found.
	HarmonyEventMessageEnd struct{}

	// HarmonyHeader is the parsed form of a message header.
	HarmonyHeader struct {
		Role      string
		Channel   string
		Recipient string
	}
)

func (HarmonyEventMessageStart) isHarmonyEvent()   {}
func (HarmonyEventHeaderComplete) isHarmonyEvent() {}
func (HarmonyEventContentEmitted) isHarmonyEvent() {}
func (HarmonyEventMessageEnd) isHarmonyEvent()     {}
|
||||
|
||||
func (s *HarmonyParser) AddImplicitStart() {
|
||||
s.acc.WriteString("<|start|>assistant")
|
||||
}
|
||||
|
||||
func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
|
||||
if lastMessage != nil && lastMessage.Role == "assistant" {
|
||||
// handle prefilling conditions
|
||||
if lastMessage.Content != "" {
|
||||
s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
|
||||
return
|
||||
} else if lastMessage.Thinking != "" {
|
||||
s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
|
||||
return
|
||||
}
|
||||
}
|
||||
s.AddImplicitStart()
|
||||
}
|
||||
|
||||
func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
|
||||
s.lifetimeAcc.WriteString(content)
|
||||
s.acc.WriteString(content)
|
||||
|
||||
var events []HarmonyEvent
|
||||
|
||||
keepLooping := true
|
||||
// we loop because we might pass through multiple parsing states in a single
|
||||
// call to addContent, and we want to make sure callers don't have to wait for
|
||||
// data that's already unambiguous
|
||||
for keepLooping {
|
||||
var newEvents []HarmonyEvent
|
||||
newEvents, keepLooping = eat(s)
|
||||
events = append(events, newEvents...)
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// the additional bool return is true iff we should continue eating
|
||||
func eat(s *HarmonyParser) ([]HarmonyEvent, bool) {
|
||||
switch s.state {
|
||||
case harmonyParserState_LookingForMessageStart:
|
||||
// does the acc contain the message start tag?
|
||||
if strings.Contains(s.acc.String(), s.MessageStartTag) {
|
||||
// split the acc into the message start tag and the rest
|
||||
split := strings.SplitN(s.acc.String(), s.MessageStartTag, 2)
|
||||
before := split[0]
|
||||
if before != "" {
|
||||
slog.Warn("harmony parser: found message start tag in the middle of the content", "content", s.acc.String())
|
||||
}
|
||||
after := split[1]
|
||||
s.acc.Reset()
|
||||
s.acc.WriteString(after)
|
||||
s.state = harmonyParserState_ParsingHeader
|
||||
return []HarmonyEvent{HarmonyEventMessageStart{}}, true
|
||||
}
|
||||
|
||||
// no match, so we keep accumulating
|
||||
return nil, false
|
||||
case harmonyParserState_ParsingHeader:
|
||||
if strings.Contains(s.acc.String(), s.HeaderEndTag) {
|
||||
split := strings.SplitN(s.acc.String(), s.HeaderEndTag, 2)
|
||||
header := split[0]
|
||||
after := split[1]
|
||||
s.acc.Reset()
|
||||
s.acc.WriteString(after)
|
||||
s.state = harmonyParserState_ParsingContent
|
||||
return []HarmonyEvent{HarmonyEventHeaderComplete{Header: s.parseHeader(header)}}, true
|
||||
}
|
||||
return nil, false
|
||||
case harmonyParserState_ParsingContent:
|
||||
if strings.Contains(s.acc.String(), s.MessageEndTag) {
|
||||
// if we already have the message end tag, we can emit the content up to it
|
||||
split := strings.SplitN(s.acc.String(), s.MessageEndTag, 2)
|
||||
content := split[0]
|
||||
after := split[1]
|
||||
s.acc.Reset()
|
||||
s.acc.WriteString(after)
|
||||
s.state = harmonyParserState_LookingForMessageStart
|
||||
events := []HarmonyEvent{}
|
||||
if content != "" {
|
||||
events = append(events, HarmonyEventContentEmitted{Content: content})
|
||||
}
|
||||
events = append(events, HarmonyEventMessageEnd{})
|
||||
return events, true
|
||||
} else if overlapLen := overlap(s.acc.String(), s.MessageEndTag); overlapLen > 0 {
|
||||
// if our suffix contains the start of the message end tag, we can emit
|
||||
// the content up to the start of the message end tag
|
||||
content := s.acc.String()[:len(s.acc.String())-overlapLen]
|
||||
remaining := s.acc.String()[len(s.acc.String())-overlapLen:]
|
||||
s.acc.Reset()
|
||||
s.acc.WriteString(remaining)
|
||||
// emit the content we know isn't part of the message end tag, and keep
|
||||
// accumulating to disambiguate the rest
|
||||
if content == "" {
|
||||
return nil, false
|
||||
}
|
||||
return []HarmonyEvent{HarmonyEventContentEmitted{Content: content}}, false
|
||||
} else {
|
||||
// no end tag, so it's still normal content that we can immediately emit
|
||||
content := s.acc.String()
|
||||
if content == "" {
|
||||
return nil, false
|
||||
}
|
||||
s.acc.Reset()
|
||||
return []HarmonyEvent{HarmonyEventContentEmitted{Content: content}}, false
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (s *HarmonyParser) parseHeader(raw string) HarmonyHeader {
|
||||
harmonyHeader := HarmonyHeader{}
|
||||
|
||||
// if `<|constrain|>` is present, ensure it has a space before it so it gets
|
||||
// parsed as a separate token, even if the model didn't include the space
|
||||
if strings.Contains(raw, "<|constrain|>") {
|
||||
raw = strings.Replace(raw, "<|constrain|>", " <|constrain|>", 1)
|
||||
raw = strings.TrimSpace(raw)
|
||||
}
|
||||
|
||||
// look for the optional channel tag, which is `<|channel|>` followed by the
|
||||
// channel name, all without any whitespace
|
||||
channelIndex := strings.Index(raw, "<|channel|>")
|
||||
if channelIndex != -1 {
|
||||
before := raw[:channelIndex]
|
||||
after := raw[channelIndex+len("<|channel|>"):]
|
||||
// the channel name is `after` all the way up to the first (if any) whitespace character
|
||||
idx := strings.IndexFunc(after, func(r rune) bool {
|
||||
return unicode.IsSpace(r)
|
||||
})
|
||||
if idx == -1 {
|
||||
idx = len(after)
|
||||
}
|
||||
harmonyHeader.Channel = after[:idx]
|
||||
after = after[idx:]
|
||||
// now we remove the channel tag from the raw string to further process
|
||||
raw = before + after
|
||||
raw = strings.TrimSpace(raw)
|
||||
}
|
||||
|
||||
// split the header into whitespace-separated tokens
|
||||
tokens := strings.Fields(raw)
|
||||
|
||||
// the first token is treated as the role
|
||||
if len(tokens) == 0 {
|
||||
slog.Error("harmony parser: missing role in header", "header", raw)
|
||||
return harmonyHeader
|
||||
}
|
||||
role := tokens[0]
|
||||
tokens = tokens[1:]
|
||||
// special case: if role starts with to= then it's a tool call
|
||||
if strings.HasPrefix(role, "to=") {
|
||||
harmonyHeader.Recipient = role[3:]
|
||||
harmonyHeader.Role = "tool"
|
||||
} else {
|
||||
harmonyHeader.Role = role
|
||||
}
|
||||
|
||||
// the recipient (if any) can be specified before or after the channel tag, so
|
||||
// we check it at the end once we've already parsed the channel and role
|
||||
if harmonyHeader.Recipient == "" && len(tokens) > 0 && strings.HasPrefix(tokens[0], "to=") {
|
||||
harmonyHeader.Recipient = tokens[0][3:]
|
||||
}
|
||||
|
||||
return harmonyHeader
|
||||
}
|
||||
|
||||
// overlap returns the length of the longest suffix of s that is also a prefix
// of delim, or 0 when there is none.
func overlap(s, delim string) int {
	longest := len(delim)
	if len(s) < longest {
		longest = len(s)
	}
	// try the longest candidate first so the first hit is the answer
	for n := longest; n >= 1; n-- {
		if strings.HasSuffix(s, delim[:n]) {
			return n
		}
	}
	return 0
}
|
||||
|
||||
// harmonyMessageState represents the current state of message processing
|
||||
type harmonyMessageState int
|
||||
|
||||
const (
|
||||
harmonyMessageState_Normal harmonyMessageState = iota
|
||||
harmonyMessageState_Thinking
|
||||
harmonyMessageState_ToolCalling
|
||||
)
|
||||
|
||||
// HarmonyMessageHandler processes harmony events and accumulates content appropriately.
|
||||
// This is a higher level interface that maps harmony concepts into ollama concepts
|
||||
type HarmonyMessageHandler struct {
|
||||
state harmonyMessageState
|
||||
HarmonyParser *HarmonyParser
|
||||
FunctionNameMap *FunctionNameMap
|
||||
}
|
||||
|
||||
// NewHarmonyMessageHandler creates a new message handler
|
||||
func NewHarmonyMessageHandler() *HarmonyMessageHandler {
|
||||
return &HarmonyMessageHandler{
|
||||
state: harmonyMessageState_Normal,
|
||||
HarmonyParser: &HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
},
|
||||
FunctionNameMap: NewFunctionNameMap(),
|
||||
}
|
||||
}
|
||||
|
||||
// AddContent processes the content and returns the content, thinking, and tool content.
|
||||
// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser
|
||||
func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) {
|
||||
contentSb := strings.Builder{}
|
||||
thinkingSb := strings.Builder{}
|
||||
toolContentSb := strings.Builder{}
|
||||
|
||||
events := h.HarmonyParser.AddContent(content)
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case HarmonyEventHeaderComplete:
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "harmony event header complete", "header", event.Header)
|
||||
switch event.Header.Channel {
|
||||
case "analysis":
|
||||
if event.Header.Recipient != "" {
|
||||
h.state = harmonyMessageState_ToolCalling
|
||||
// event.Header.Recipient is the tool name, something like
|
||||
// "browser.search" for a built-in, or "functions.calc" for a
|
||||
// custom one
|
||||
toolParser.SetToolName(event.Header.Recipient)
|
||||
} else {
|
||||
h.state = harmonyMessageState_Thinking
|
||||
}
|
||||
case "commentary":
|
||||
if event.Header.Recipient != "" {
|
||||
h.state = harmonyMessageState_ToolCalling
|
||||
toolParser.SetToolName(event.Header.Recipient)
|
||||
} else {
|
||||
h.state = harmonyMessageState_Normal
|
||||
}
|
||||
case "final":
|
||||
h.state = harmonyMessageState_Normal
|
||||
}
|
||||
case HarmonyEventContentEmitted:
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "harmony event content", "content", event.Content, "state", h.state)
|
||||
if h.state == harmonyMessageState_Normal {
|
||||
contentSb.WriteString(event.Content)
|
||||
} else if h.state == harmonyMessageState_Thinking {
|
||||
thinkingSb.WriteString(event.Content)
|
||||
} else if h.state == harmonyMessageState_ToolCalling {
|
||||
toolContentSb.WriteString(event.Content)
|
||||
}
|
||||
case HarmonyEventMessageEnd:
|
||||
h.state = harmonyMessageState_Normal
|
||||
}
|
||||
}
|
||||
return contentSb.String(), thinkingSb.String(), toolContentSb.String()
|
||||
}
|
||||
|
||||
func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator {
|
||||
return &HarmonyToolCallAccumulator{
|
||||
state: harmonyToolCallState_Normal,
|
||||
currentToolName: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// harmonyToolCallState is the accumulator's tool-call state.
type harmonyToolCallState int

const (
	harmonyToolCallState_Normal harmonyToolCallState = iota
	harmonyToolCallState_ToolCalling
)

// HarmonyToolCallAccumulator collects the raw text of a tool call together
// with the name of the tool being called.
type HarmonyToolCallAccumulator struct {
	state           harmonyToolCallState
	acc             strings.Builder
	currentToolName *string
}

// SetToolName records toolName as the tool the accumulated content is for.
func (a *HarmonyToolCallAccumulator) SetToolName(toolName string) {
	name := toolName
	a.currentToolName = &name
}

// Add appends content to the accumulated tool call text.
func (a *HarmonyToolCallAccumulator) Add(content string) {
	a.acc.WriteString(content)
}

// Drain returns the current tool name (nil if none was ever set) and all
// accumulated text, then clears the text and returns to the normal state. The
// tool name itself is left in place.
func (a *HarmonyToolCallAccumulator) Drain() (*string, string) {
	drained := a.acc.String()
	a.acc.Reset()
	a.state = harmonyToolCallState_Normal
	return a.currentToolName, drained
}

// Content returns the text accumulated so far without clearing it.
func (a *HarmonyToolCallAccumulator) Content() string {
	return a.acc.String()
}
|
||||
|
||||
// FunctionNameMap maps a user-specified function name to a valid function
// name for harmony (which look like TypeScript identifiers). This is needed to
// transform user-specified function names, which might contain characters that
// are not allowed in TypeScript identifiers.
type FunctionNameMap struct {
	userToHarmony map[string]string
	harmonyToUser map[string]string
}

// NewFunctionNameMap returns an empty, ready-to-use FunctionNameMap.
func NewFunctionNameMap() *FunctionNameMap {
	return &FunctionNameMap{
		userToHarmony: make(map[string]string),
		harmonyToUser: make(map[string]string),
	}
}

// ConvertAndAdd derives a harmony-safe name for userFunctionName that is not
// already in use, records the mapping in both directions and returns the
// derived name. Each call registers a new harmony name, even when the same
// user name is passed again.
func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string {
	harmonyFunctionName := m.deriveName(userFunctionName)
	m.userToHarmony[userFunctionName] = harmonyFunctionName
	m.harmonyToUser[harmonyFunctionName] = userFunctionName
	return harmonyFunctionName
}

// OriginalFromConverted looks up the reverse-mapping of a previously-converted
// user->harmony function name. To unmap reliably, the mapping must exist, as
// the conversion process is not reversible without the appropriate state.
// When no mapping exists a warning is logged and harmonyFunctionName is
// returned unchanged.
func (m *FunctionNameMap) OriginalFromConverted(harmonyFunctionName string) string {
	if userFunctionName, ok := m.harmonyToUser[harmonyFunctionName]; ok {
		return userFunctionName
	}
	slog.Warn("harmony parser: no reverse mapping found for function name", "harmonyFunctionName", harmonyFunctionName)
	// fall back to the name we were given if we can't find a mapping
	return harmonyFunctionName
}

// convertToValidChars converts a user-specified function name to a valid
// TypeScript identifier: spaces, hyphens and periods become underscores, any
// other character that is not a letter, digit, `_` or `$` is removed, an empty
// result becomes "unnamed", and a result starting with a digit is prefixed
// with an underscore.
//
// Limitations:
//
//   - This doesn't restrict reserved TypeScript keywords.
//   - We don't perform a real ID_Start/ID_Continue check, and instead use the more
//     restrictive unicode.IsLetter/unicode.IsDigit check. Unclear what kind of
//     identifiers these models were trained on, so in the end we might want to
//     convert unicode-heavy identifiers to their closest ASCII equivalents.
func (m *FunctionNameMap) convertToValidChars(userFunctionName string) string {
	mapper := func(r rune) rune {
		// first, replace certain characters with underscores
		if r == ' ' || r == '-' || r == '.' {
			return '_'
		}

		if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
			return r
		}

		// finally, remove any other characters
		return -1
	}
	candidate := strings.Map(mapper, userFunctionName)

	// set a default name if we end up with nothing left
	if candidate == "" {
		return "unnamed"
	}

	// If the candidate starts with a digit, prepend an underscore to make it a
	// valid identifier. The check is on the first rune rather than the first
	// byte, so multi-byte digits (which the mapper keeps) are caught as well.
	if strings.IndexFunc(candidate, unicode.IsDigit) == 0 {
		candidate = "_" + candidate
	}

	return candidate
}

// deriveName returns the converted form of userFunctionName, made unique
// against the harmony names already registered by appending a numeric suffix
// when needed. It does not modify the map.
func (m *FunctionNameMap) deriveName(userFunctionName string) string {
	originalCandidate := m.convertToValidChars(userFunctionName)
	candidate := originalCandidate

	// Check for dupes, and if so, add a number to the end.
	// We start at 2 because if we have dupes and the first is never renamed, it
	// makes sense for them to be named, say, `f`, `f_2`, `f_3`
	count := 2
	for {
		if _, exists := m.harmonyToUser[candidate]; !exists {
			break
		}
		candidate = fmt.Sprintf("%s_%d", originalCandidate, count)
		count++
	}

	return candidate
}
|
537
harmony/harmonyparser_test.go
Normal file
537
harmony/harmonyparser_test.go
Normal file
@@ -0,0 +1,537 @@
|
||||
package harmony
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHeaderParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
in, wantRole, wantChannel, wantRecipient string
|
||||
}{
|
||||
{
|
||||
in: "assistant<|channel|>analysis",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "",
|
||||
},
|
||||
{
|
||||
in: "assistant<|channel|>analysis to=functions.get_weather",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
{
|
||||
in: "assistant to=functions.get_weather<|channel|>analysis",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// special case where the role is replaced by the recipient (matches reference code)
|
||||
{
|
||||
in: "to=functions.get_weather<|channel|>analysis",
|
||||
wantRole: "tool",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// extra token after the recipient is ignored
|
||||
{
|
||||
in: "assistant to=functions.get_weather abc<|channel|>analysis",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// with constrain tag, recipient after channel tag
|
||||
{
|
||||
in: "assistant<|channel|>commentary to=functions.get_weather <|constrain|>json",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "commentary",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// with constrain tag, recipient before channel tag
|
||||
{
|
||||
in: "assistant to=functions.get_weather<|channel|>commentary <|constrain|>json",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "commentary",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// constrain tag without space
|
||||
{
|
||||
in: "assistant<|channel|>commentary to=functions.get_weather<|constrain|>json",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "commentary",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
// constrain tag without space, different order
|
||||
{
|
||||
in: "assistant to=functions.get_weather<|channel|>commentary<|constrain|>json",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "commentary",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
parser := HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
}
|
||||
header := parser.parseHeader(tt.in)
|
||||
|
||||
if header.Role != tt.wantRole {
|
||||
t.Errorf("case %d: got role \"%s\", want \"%s\"", i, header.Role, tt.wantRole)
|
||||
}
|
||||
if header.Channel != tt.wantChannel {
|
||||
t.Errorf("case %d: got channel \"%s\", want \"%s\"", i, header.Channel, tt.wantChannel)
|
||||
}
|
||||
if header.Recipient != tt.wantRecipient {
|
||||
t.Errorf("case %d: got recipient \"%s\", want \"%s\"", i, header.Recipient, tt.wantRecipient)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarmonyParserHeaderEvent(t *testing.T) {
|
||||
tests := []struct {
|
||||
in, wantRole, wantChannel, wantRecipient string
|
||||
implicitStart bool
|
||||
}{
|
||||
{
|
||||
in: "<|start|>user<|message|>What is 2 + 2?<|end|>",
|
||||
wantRole: "user",
|
||||
wantChannel: "",
|
||||
wantRecipient: "",
|
||||
},
|
||||
{
|
||||
in: "<|start|>assistant<|channel|>analysis<|message|>What is 2 + 2?<|end|>",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "",
|
||||
},
|
||||
{
|
||||
in: "<|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{\"location\":\"San Francisco\"}<|call|><|start|>functions.get_weather to=assistant<|message|>{\"sunny\": true, \"temperature\": 20}<|end|>",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "commentary",
|
||||
wantRecipient: "functions.get_weather",
|
||||
},
|
||||
{
|
||||
in: "<|channel|>analysis<|message|>User asks weather in SF. We need location. Use get_current_weather with location \"San Francisco, CA\".<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{\"location\":\"San Francisco, CA\"}<|call|>",
|
||||
wantRole: "assistant",
|
||||
wantChannel: "analysis",
|
||||
wantRecipient: "",
|
||||
implicitStart: true,
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
parser := HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
}
|
||||
if tt.implicitStart {
|
||||
parser.AddImplicitStart()
|
||||
}
|
||||
gotEvents := parser.AddContent(tt.in)
|
||||
if len(gotEvents) == 0 {
|
||||
t.Errorf("case %d: got no events, want at least one", i)
|
||||
}
|
||||
|
||||
var firstHeaderEvent *HarmonyEventHeaderComplete
|
||||
// print events
|
||||
for _, event := range gotEvents {
|
||||
fmt.Printf("event: %+v\n", event)
|
||||
}
|
||||
for _, event := range gotEvents {
|
||||
if event, ok := event.(HarmonyEventHeaderComplete); ok {
|
||||
firstHeaderEvent = &event
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if firstHeaderEvent == nil {
|
||||
t.Errorf("case %d: got no header complete event, want one", i)
|
||||
continue
|
||||
}
|
||||
gotHeader := firstHeaderEvent.Header
|
||||
if gotHeader.Role != tt.wantRole || gotHeader.Channel != tt.wantChannel || gotHeader.Recipient != tt.wantRecipient {
|
||||
t.Errorf("case %d: got header %+v, want role=%s channel=%s recipient=%s", i, gotHeader, tt.wantRole, tt.wantChannel, tt.wantRecipient)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarmonyParserNonStreaming(t *testing.T) {
|
||||
tests := []struct {
|
||||
in string
|
||||
implicitStart bool
|
||||
wantEvents []HarmonyEvent
|
||||
}{
|
||||
{
|
||||
in: "<|start|>user<|message|>What is 2 + 2?<|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "What is 2 + 2?"},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
in: "<|start|>assistant<|channel|>analysis<|message|>The answer is 4<|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "analysis", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "The answer is 4"},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
in: "<|start|>assistant<|channel|>commentary to=functions.calc<|message|>Computing...<|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}},
|
||||
HarmonyEventContentEmitted{Content: "Computing..."},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
in: "<|start|>user<|message|><|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
in: "<|start|>user<|message|>Hello<|end|><|start|>assistant<|message|>Hi!<|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "Hello"},
|
||||
HarmonyEventMessageEnd{},
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "Hi!"},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
in: "<|channel|>analysis<|message|>Thinking about the request<|end|>",
|
||||
implicitStart: true,
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}, HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "analysis", Recipient: ""}}, HarmonyEventContentEmitted{Content: "Thinking about the request"}, HarmonyEventMessageEnd{}},
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
parser := HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
}
|
||||
if tt.implicitStart {
|
||||
parser.AddImplicitStart()
|
||||
}
|
||||
gotEvents := parser.AddContent(tt.in)
|
||||
if !reflect.DeepEqual(gotEvents, tt.wantEvents) {
|
||||
t.Errorf("case %d: got events %#v, want %#v", i, gotEvents, tt.wantEvents)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarmonyParserStreaming(t *testing.T) {
|
||||
type step struct {
|
||||
input string
|
||||
wantEvents []HarmonyEvent
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
implicitStart bool
|
||||
steps []step
|
||||
}{
|
||||
{
|
||||
desc: "simple message streamed character by character",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "|",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "start|>u",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
|
||||
},
|
||||
{
|
||||
input: "ser<|mess",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "age|>Hi",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "Hi"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: " there",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: " there"}},
|
||||
},
|
||||
{
|
||||
input: "<|e",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "nd|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "message with channel streamed",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>assistant",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
|
||||
},
|
||||
{
|
||||
input: "<|chan",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "nel|>analysis",
|
||||
wantEvents: nil,
|
||||
},
|
||||
{
|
||||
input: "<|message|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "analysis", Recipient: ""}}},
|
||||
},
|
||||
{
|
||||
input: "Thinking",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Thinking"}},
|
||||
},
|
||||
{
|
||||
input: "...",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "..."}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "message with channel and recipient",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>assistant<|channel|>commentary to=functions.calc<|message|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "{\"x\": 5}",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "{\"x\": 5}"}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "message with channel and recipient (receipient before channel)",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>assistant to=functions.calc<|channel|>commentary<|message|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "{\"x\": 5}",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "{\"x\": 5}"}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "implicit start with channel",
|
||||
implicitStart: true,
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|channel|>thinking",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
|
||||
},
|
||||
{
|
||||
input: "<|message|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "thinking", Recipient: ""}}},
|
||||
},
|
||||
{
|
||||
input: "Processing request",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Processing request"}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple messages streamed",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>user<|message|>Hello<|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "Hello"},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "<|start|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
|
||||
},
|
||||
{
|
||||
input: "assistant<|message|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "", Recipient: ""}}},
|
||||
},
|
||||
{
|
||||
input: "Hi!",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Hi!"}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "empty message",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>system<|message|><|end|>",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "system", Channel: "", Recipient: ""}},
|
||||
HarmonyEventMessageEnd{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial tag that looks like end but isn't",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<|start|>user<|message|>test<|e",
|
||||
wantEvents: []HarmonyEvent{
|
||||
HarmonyEventMessageStart{},
|
||||
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||
HarmonyEventContentEmitted{Content: "test"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: "xample|>more",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "<|example|>more"}},
|
||||
},
|
||||
{
|
||||
input: "<|end|>",
|
||||
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
parser := HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
}
|
||||
if tc.implicitStart {
|
||||
parser.AddImplicitStart()
|
||||
}
|
||||
|
||||
for i, step := range tc.steps {
|
||||
gotEvents := parser.AddContent(step.input)
|
||||
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
|
||||
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFunctionConvertToValidChars tests only FunctionNameMap.convert(), which doesn't
|
||||
// handle any saving (and therefore no dupe handling)
|
||||
func TestFunctionConvertToValidChars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{name: "replace spaces with underscores", in: "get weather", want: "get_weather"},
|
||||
{name: "replace hyphens with underscores", in: "get-weather", want: "get_weather"},
|
||||
{name: "replace periods with underscores", in: "get.weather", want: "get_weather"},
|
||||
{name: "disallow non-word characters", in: "get weather!", want: "get_weather"},
|
||||
{name: "strip out invalid non-alphanumeric unicode characters", in: "a🫠bc", want: "abc"},
|
||||
{name: "names that only contain invalid characters", in: "🫠", want: "unnamed"},
|
||||
{name: "leading number", in: "123", want: "_123"},
|
||||
{name: "$ allowed", in: "$", want: "$"},
|
||||
// show that we allow weird unicode letter characters, though we might want
|
||||
// to convert them to their closest ASCII equivalents in the future
|
||||
{name: "allow weird unicode letter characters", in: "𝓸𝓵𝓵𝓪𝓶𝓪", want: "𝓸𝓵𝓵𝓪𝓶𝓪"},
|
||||
// names that look like words but are invalid (i.e., not ID_Start/ID_Continue)
|
||||
{name: "disallow non-word characters that look like words", in: "ⓞⓛⓛⓐⓜⓐ123", want: "_123"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser := NewFunctionNameMap()
|
||||
got := parser.convertToValidChars(tt.in)
|
||||
if got != tt.want {
|
||||
t.Errorf("case %d: got %q, want %q", i, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionConvertAndAdd(t *testing.T) {
|
||||
// make a fresh map for each test, but within a test use the same map so we can test for dupe handling
|
||||
tests := []struct {
|
||||
name string
|
||||
in []string
|
||||
want []string
|
||||
}{
|
||||
{name: "basic dupe handling", in: []string{"get weather", "get weather"}, want: []string{"get_weather", "get_weather_2"}},
|
||||
{name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}},
|
||||
{name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}},
|
||||
{name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
parser := NewFunctionNameMap()
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
for j, in := range tt.in {
|
||||
got := parser.ConvertAndAdd(in)
|
||||
want := tt.want[j]
|
||||
if got != want {
|
||||
t.Errorf("case %d: got %q, want %q", i, got, want)
|
||||
}
|
||||
// check that the maps are correct
|
||||
if parser.userToHarmony[in] != want {
|
||||
t.Errorf("case %d: userToHarmony[%q] = %q, want %q", i, in, parser.userToHarmony[in], want)
|
||||
}
|
||||
if parser.harmonyToUser[want] != in {
|
||||
t.Errorf("case %d: harmonyToUser[%q] = %q, want %q", i, want, parser.harmonyToUser[want], in)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user