From a40d427bcea52ad5c7e93780564fc15e5ef80473 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 23 Sep 2025 13:21:47 -0700 Subject: [PATCH] multi-regexp pretokenizer (#12325) --- model/bytepairencoding.go | 54 ++++++++++++++++++++++++++++------ model/bytepairencoding_test.go | 40 ++++++++++++++++++++++++- model/models/gptoss/model.go | 20 ++++++------- model/models/llama/model.go | 28 +++++++++++++++--- model/models/llama4/model.go | 3 +- model/models/mistral3/model.go | 2 +- model/models/mllama/model.go | 2 +- model/models/qwen2/model.go | 2 +- model/models/qwen25vl/model.go | 2 +- model/models/qwen3/embed.go | 2 +- model/models/qwen3/model.go | 2 +- sample/samplers_test.go | 1 - 12 files changed, 124 insertions(+), 34 deletions(-) diff --git a/model/bytepairencoding.go b/model/bytepairencoding.go index e21564aa53..3d51f70e81 100644 --- a/model/bytepairencoding.go +++ b/model/bytepairencoding.go @@ -5,6 +5,7 @@ import ( "fmt" "iter" "log/slog" + "slices" "strings" "github.com/dlclark/regexp2" @@ -13,16 +14,28 @@ import ( ) type BytePairEncoding struct { - pre *regexp2.Regexp - vocab *Vocabulary + vocab *Vocabulary + regexps []*regexp2.Regexp } var _ TextProcessor = (*BytePairEncoding)(nil) -func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding { +func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding { + if len(pretokenizers) == 0 { + // set default byte-level pretokenizer if none provided, e.g. + // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44 + pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`} + } + return BytePairEncoding{ - pre: regexp2.MustCompile(pre, regexp2.None), vocab: vocab, + regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) { + for _, p := range pretokenizers { + if !yield(regexp2.MustCompile(p, regexp2.RE2)) { + return + } + } + }), } } @@ -35,13 +48,36 @@ func (bpe BytePairEncoding) Is(id int32, special Special) bool { } func (bpe *BytePairEncoding) split(s string) iter.Seq[string] { - return func(yield func(string) bool) { - for m, _ := bpe.pre.FindStringMatch(s); m != nil; m, _ = bpe.pre.FindNextMatch(m) { - if !yield(m.String()) { - break + parts := []string{s} + for _, re := range bpe.regexps { + parts = slices.Collect(func(yield func(string) bool) { + for _, part := range parts { + r := []rune(part) + var offset int + for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) { + if offset-m.Index != 0 { + if !yield(string(r[:m.Index])) { + return + } + } + + if !yield(m.String()) { + return + } + + offset = m.Index + m.Length + } + + if offset < len(r) { + if !yield(string(r[offset:])) { + return + } + } } - } + }) } + + return slices.Values(parts) } // fragment is a string fragment and their corresponding token IDs diff --git a/model/bytepairencoding_test.go b/model/bytepairencoding_test.go index 71947be993..39e5ab452c 100644 --- a/model/bytepairencoding_test.go +++ b/model/bytepairencoding_test.go @@ -59,12 +59,12 @@ func llama(t testing.TB) BytePairEncoding { } return NewBytePairEncoding( - `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, &Vocabulary{ Values: tokens, Types: types, Merges: merges, }, + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", ) } @@ -282,3 +282,41 @@ func BenchmarkBytePairEncoding(b *testing.B) { }) } } + +func TestSplit(t *testing.T) { + cases := []struct { + name string + patterns, + want []string + }{ + { + name: "default", + want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"}, + }, + { + name: "unicode", + patterns: []string{ + "\\p{N}{1,3}", + `[一-龥぀-ゟ゠-ヿ]+`, + "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", + }, + want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"}, + }, + { + name: "individual digits", + patterns: []string{ + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }, + want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tokenizer := NewBytePairEncoding(nil, tt.patterns...) + if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" { + t.Errorf("no match (-theirs +ours):\n%s", diff) + } + }) + } +} diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go index 8456ea5f71..6a3270651b 100644 --- a/model/models/gptoss/model.go +++ b/model/models/gptoss/model.go @@ -227,17 +227,6 @@ func New(c fs.Config) (model.Model, error) { m := Transformer{ TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")), BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", - strings.Join([]string{ - `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, - `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, - `\p{N}{1,3}`, - ` ?[^\s\p{L}\p{N}]+[\r\n/]*`, - `\s*[\r\n]+`, - `\s+(?!\S)`, - `\s+`, - }, "|"), - ), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -250,6 +239,15 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + strings.Join([]string{ + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, + `\p{N}{1,3}`, + ` ?[^\s\p{L}\p{N}]+[\r\n/]*`, + `\s*[\r\n]+`, + `\s+(?!\S)`, + `\s+`, + }, "|"), ), Options: Options{ hiddenSize: int(c.Uint("embedding_length")), diff --git a/model/models/llama/model.go b/model/models/llama/model.go index f6ec022738..c03f04a0d8 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -54,10 +54,30 @@ func New(c fs.Config) (model.Model, error) { } switch c.String("tokenizer.ggml.model") { case "gpt2": - processor = model.NewBytePairEncoding( - `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, - &vocabulary, - ) + var pretokenizers []string + switch c.String("tokenizer.ggml.pre") { + case "default": + // no-op use the default bpe pretokenizer + case "qwen2": + pretokenizers = []string{ + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + } + case "refact": + pretokenizers = []string{ + `\p{N}`, + `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`, + } + case "tekken": + pretokenizers = []string{ + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + } + default: + // use a llama-style pretokenizer + pretokenizers = []string{ + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + } + } + processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...) case "llama": processor = model.NewSentencePiece(&vocabulary) default: diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go index 9cb2efc87a..e80fbaed63 100644 --- a/model/models/llama4/model.go +++ b/model/models/llama4/model.go @@ -34,8 +34,6 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor { func New(c fs.Config) (model.Model, error) { m := Model{ BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", - `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -48,6 +46,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), ImageProcessor: newImageProcessor(c), VisionModel: newVisionModel(c), diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index 435b1a304d..5c46615e92 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -33,7 +33,6 @@ var _ model.TextProcessor = (*Model)(nil) func New(c fs.Config) (model.Model, error) { m := &Model{ BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -46,6 +45,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), TextModel: newTextModel(c), VisionModel: newVisionModel(c), diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 239d999d50..769743694c 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -33,7 +33,6 @@ const ( func New(c fs.Config) (model.Model, error) { m := Model{ BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -46,6 +45,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), ImageProcessor: newImageProcessor(c), VisionModel: newVisionModel(c), diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go index 5a3458378e..2e2347102e 100644 --- a/model/models/qwen2/model.go +++ b/model/models/qwen2/model.go @@ -139,7 +139,6 @@ func New(c fs.Config) (model.Model, error) { m := Model{ Layers: make([]DecoderLayer, c.Uint("block_count")), BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -152,6 +151,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), Options: Options{ hiddenSize: int(c.Uint("embedding_length")), diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index 6c76305db8..6898e38cac 100644 --- a/model/models/qwen25vl/model.go +++ b/model/models/qwen25vl/model.go @@ -29,7 +29,6 @@ var _ model.MultimodalProcessor = (*Model)(nil) func New(c fs.Config) (model.Model, error) { m := &Model{ BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -42,6 +41,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), TextModel: NewTextModel(c), VisionModel: newVisionModel(c), diff --git a/model/models/qwen3/embed.go b/model/models/qwen3/embed.go index 9a77efea99..c03888d45c 100644 --- a/model/models/qwen3/embed.go +++ b/model/models/qwen3/embed.go @@ -35,7 +35,6 @@ func newEmbed(c fs.Config) (model.Model, error) { } m := embedModel{ BytePairEncoding: model.NewBytePairEncoding( - `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -48,6 +47,7 @@ func newEmbed(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), Model: &Model{ Layers: layers, diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index 352268347c..cc58e4a289 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -200,7 +200,6 @@ func New(c fs.Config) (model.Model, error) { m := Model{ BytePairEncoding: model.NewBytePairEncoding( - `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -213,6 +212,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), Layers: layers, Options: &Options{ diff --git a/sample/samplers_test.go b/sample/samplers_test.go index b720f027c3..eb10295d45 100644 --- a/sample/samplers_test.go +++ b/sample/samplers_test.go @@ -82,7 +82,6 @@ func modelHelper(t testing.TB) model.BytePairEncoding { merges := make([]string, 0, 1) // Only need vocab for Grammar Test return model.NewBytePairEncoding( - ``, &model.Vocabulary{ Values: tokens, Types: make([]int32, len(vocab)),