diff --git a/convert/tokenizer.go b/convert/tokenizer.go
index 14d6ba66c..e7be8e402 100644
--- a/convert/tokenizer.go
+++ b/convert/tokenizer.go
@@ -10,6 +10,7 @@ import (
 	"log/slog"
 	"os"
 	"slices"
+	"strings"
 
 	"golang.org/x/exp/maps"
 )
@@ -60,7 +61,25 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
 		addedTokens[t.Content] = t
 	}
 
-	t.Merges = tt.Model.Merges
+	if len(tt.Model.Merges) == 0 {
+		// noop; merges is empty
+	} else if err := json.Unmarshal(tt.Model.Merges, &t.Merges); err == nil {
+		// noop; merges is []string
+	} else if merges, err := func() ([][]string, error) {
+		var merges [][]string
+		if err := json.Unmarshal(tt.Model.Merges, &merges); err != nil {
+			return nil, err
+		}
+
+		return merges, nil
+	}(); err == nil {
+		t.Merges = make([]string, len(merges))
+		for i := range merges {
+			t.Merges[i] = strings.Join(merges[i], " ")
+		}
+	} else {
+		return nil, fmt.Errorf("could not parse tokenizer merges. expected []string or [][]string: %w", err)
+	}
 
 	sha256sum := sha256.New()
 	for _, pt := range tt.PreTokenizer.PreTokenizers {
@@ -156,9 +175,9 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
 type tokenizer struct {
 	AddedTokens []token `json:"added_tokens"`
 	Model       struct {
-		Type   string         `json:"type"`
-		Vocab  map[string]int `json:"vocab"`
-		Merges []string       `json:"merges"`
+		Type   string          `json:"type"`
+		Vocab  map[string]int  `json:"vocab"`
+		Merges json.RawMessage `json:"merges"`
 	} `json:"model"`
 
 	PreTokenizer struct {
diff --git a/convert/tokenizer_test.go b/convert/tokenizer_test.go
index d9550e095..c6ef9732f 100644
--- a/convert/tokenizer_test.go
+++ b/convert/tokenizer_test.go
@@ -191,6 +191,62 @@ func TestParseTokenizer(t *testing.T) {
 				Pre: "default",
 			},
 		},
+		{
+			name: "list string merges",
+			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
+				"tokenizer.json": strings.NewReader(`{
+					"model": {
+						"merges": [
+							"a b",
+							"c d",
+							"e f"
+						]
+					}
+				}`),
+			}),
+			want: &Tokenizer{
+				Vocabulary: &Vocabulary{
+					Model: "gpt2",
+				},
+				Merges: []string{
+					"a b",
+					"c d",
+					"e f",
+				},
+				Pre: "default",
+			},
+		},
+		{
+			name: "list list string merges",
+			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
+				"tokenizer.json": strings.NewReader(`{
+					"model": {
+						"merges": [
+							[
+								"a", "b"
+							],
+							[
+								"c", "d"
+							],
+							[
+								"e", "f"
+							]
+						]
+					}
+				}`),
+			}),
+			want: &Tokenizer{
+				Vocabulary: &Vocabulary{
+					Model: "gpt2",
+				},
+				Merges: []string{
+					"a b",
+					"c d",
+					"e f",
+				},
+				Pre: "default",
+			},
+		},
 	}
 
 	for _, tt := range cases {
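
A minimal standalone sketch of the decoding strategy the patch applies: hold the "merges" field as json.RawMessage, try the flat []string shape first, then fall back to [][]string and normalize each pair by joining it with a space. The helper name mergesFromRaw and the sample inputs are hypothetical illustrations, not part of the patch.

package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// mergesFromRaw decodes a "merges" value that may be encoded either as
// []string ("a b") or as [][]string (["a", "b"]), normalizing the nested
// form into the space-joined form. (Hypothetical helper for illustration.)
func mergesFromRaw(raw json.RawMessage) ([]string, error) {
	if len(raw) == 0 {
		// Field absent or empty; nothing to decode.
		return nil, nil
	}

	// First attempt: the flat []string encoding.
	var flat []string
	if err := json.Unmarshal(raw, &flat); err == nil {
		return flat, nil
	}

	// Fallback: the nested [][]string encoding.
	var nested [][]string
	if err := json.Unmarshal(raw, &nested); err != nil {
		return nil, fmt.Errorf("could not parse tokenizer merges. expected []string or [][]string: %w", err)
	}

	merges := make([]string, len(nested))
	for i := range nested {
		merges[i] = strings.Join(nested[i], " ")
	}

	return merges, nil
}

func main() {
	// Both encodings normalize to the same result: [a b c d]
	for _, raw := range []json.RawMessage{
		json.RawMessage(`["a b", "c d"]`),
		json.RawMessage(`[["a", "b"], ["c", "d"]]`),
	} {
		merges, err := mergesFromRaw(raw)
		if err != nil {
			panic(err)
		}
		fmt.Println(merges)
	}
}

The patch itself inlines this fallback with chained else-if branches and an immediately invoked function literal rather than a named helper, but the control flow is the same: succeed on []string, retry as [][]string, and only error when both shapes fail.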