From 333e360422744e92275af2c1c2d5bc039ad97e8f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 16 May 2025 13:40:23 -0700 Subject: [PATCH] model: handle multiple eos tokens (#10577) * get eos_token_id from generation_config.json * refactor * include both ids and strings in trace * comments * remove special case for gemma3 special vocab (#10743) --- convert/convert.go | 5 +- convert/tokenizer.go | 32 +++++ convert/tokenizer_test.go | 61 +++++++++ llama/llama.go | 4 +- .../{process_text.go => bytepairencoding.go} | 126 +----------------- ..._text_test.go => bytepairencoding_test.go} | 0 model/models/gemma2/model.go | 11 +- model/models/gemma3/model.go | 12 +- model/models/llama/model.go | 10 +- model/models/llama4/model.go | 10 +- model/models/mistral3/model.go | 18 +-- model/models/mllama/model.go | 10 +- model/models/qwen25vl/model.go | 11 +- .../{process_text_spm.go => sentencepiece.go} | 23 +--- ...text_spm_test.go => sentencepiece_test.go} | 0 model/textprocessor.go | 17 +++ model/vocabulary.go | 112 ++++++++++++++++ sample/samplers.go | 2 +- 18 files changed, 282 insertions(+), 182 deletions(-) rename model/{process_text.go => bytepairencoding.go} (66%) rename model/{process_text_test.go => bytepairencoding_test.go} (100%) rename model/{process_text_spm.go => sentencepiece.go} (89%) rename model/{process_text_spm_test.go => sentencepiece_test.go} (100%) create mode 100644 model/textprocessor.go create mode 100644 model/vocabulary.go diff --git a/convert/convert.go b/convert/convert.go index 309b0ce19..4a6df66c7 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -53,8 +53,11 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV { } for _, sv := range t.SpecialVocabulary { - kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID) kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken + kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID) + if len(sv.IDs) > 0 { + 
kv[fmt.Sprintf("tokenizer.ggml.%s_token_ids", sv.Key())] = sv.IDs + } } return kv diff --git a/convert/tokenizer.go b/convert/tokenizer.go index 74e2efed0..bedcd4f80 100644 --- a/convert/tokenizer.go +++ b/convert/tokenizer.go @@ -110,6 +110,7 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) } if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) { + // noop } else if err != nil { return nil, err } else { @@ -171,6 +172,34 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) } } + if f, err := fsys.Open("generation_config.json"); errors.Is(err, os.ErrNotExist) { + } else if err != nil { + return nil, err + } else { + defer f.Close() + + var p map[string]json.RawMessage + if err := json.NewDecoder(f).Decode(&p); err != nil { + return nil, err + } + + for _, st := range specialTokenTypes { + if bts, ok := p[fmt.Sprintf("%s_token_id", st)]; ok { + var ids []int32 + if err := json.Unmarshal(bts, &ids); err != nil { + // value is not a list so the existing ID is used + continue + } + + if i := slices.IndexFunc(t.SpecialVocabulary, func(sv *SpecialVocabulary) bool { + return sv.Type == st + }); i >= 0 { + t.SpecialVocabulary[i].IDs = ids + } + } + } + } + return t, nil } @@ -280,6 +309,9 @@ type SpecialVocabulary struct { ID int Content string AddToken bool + + // IDs is populated by generation_config.json + IDs []int32 } func (sv SpecialVocabulary) Key() string { diff --git a/convert/tokenizer_test.go b/convert/tokenizer_test.go index c6ef9732f..813096fd9 100644 --- a/convert/tokenizer_test.go +++ b/convert/tokenizer_test.go @@ -247,6 +247,67 @@ func TestParseTokenizer(t *testing.T) { Pre: "default", }, }, + { + name: "generation config eos token ids", + fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{ + "tokenizer.json": strings.NewReader(`{ + "added_tokens": [ + { + "id": 0, + "content": "", + "special": true + }, + { + "id": 1, + "content": "", + "special": true + 
}, + { + "id": 2, + "content": "", + "special": true + }, + { + "id": 3, + "content": "", + "special": true + } + ], + "model": { + "vocab": { + "": 0, + "": 1, + "": 2, + "": 3 + } + } + }`), + "tokenizer_config.json": strings.NewReader(`{ + "add_bos_token": true, + "add_eos_token": false, + "bos_token": "", + "eos_token": "" + }`), + "generation_config.json": strings.NewReader(`{ + "bos_token_id": 0, + "eos_token_id": [1, 2, 3] + }`), + }), + specialTokenTypes: []string{"pad", "eos", "bos", "unk"}, + want: &Tokenizer{ + Vocabulary: &Vocabulary{ + Model: "gpt2", + Tokens: []string{"", "", "", ""}, + Scores: []float32{0, 1, 2, 3}, + Types: []int32{3, 3, 3, 3}, + }, + SpecialVocabulary: []*SpecialVocabulary{ + {Type: "eos", Content: "", ID: 1, IDs: []int32{1, 2, 3}, AddToken: false}, + {Type: "bos", Content: "", ID: 0, AddToken: true}, + }, + Pre: "default", + }, + }, } for _, tt := range cases { diff --git a/llama/llama.go b/llama/llama.go index 1251be3a5..626ea13a3 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -602,7 +602,7 @@ type Grammar struct { mu sync.Mutex } -func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogTokens []uint32) *Grammar { +func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogTokens []int32) *Grammar { cGrammar := C.CString(grammar) defer C.free(unsafe.Pointer(cGrammar)) @@ -622,7 +622,7 @@ func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogToke cEogTokens[i] = C.uint32_t(token) } - g := C.grammar_init(cGrammar, (*C.uint32_t)(unsafe.Pointer(&cTokens[0])), C.size_t(len(cTokens)), (**C.char)(unsafe.Pointer(&cPieces[0])), (*C.uint32_t)(unsafe.Pointer(&cEogTokens[0])), C.size_t(len(cEogTokens))) + g := C.grammar_init(cGrammar, unsafe.SliceData(cTokens), C.size_t(len(cTokens)), unsafe.SliceData(cPieces), unsafe.SliceData(cEogTokens), C.size_t(len(cEogTokens))) if g == nil { return nil } diff --git a/model/process_text.go b/model/bytepairencoding.go similarity index 66% 
rename from model/process_text.go rename to model/bytepairencoding.go index 7b87ccc33..6bb9a003e 100644 --- a/model/process_text.go +++ b/model/bytepairencoding.go @@ -5,116 +5,13 @@ import ( "context" "iter" "log/slog" - "slices" "strings" - "sync" "github.com/dlclark/regexp2" heap "github.com/emirpasic/gods/v2/trees/binaryheap" "github.com/ollama/ollama/logutil" ) -type Special int32 - -const ( - SpecialBOS Special = iota - SpecialEOS -) - -const ( - TOKEN_TYPE_NORMAL = iota + 1 - TOKEN_TYPE_UNKNOWN - TOKEN_TYPE_CONTROL - TOKEN_TYPE_USER_DEFINED - TOKEN_TYPE_UNUSED - TOKEN_TYPE_BYTE -) - -type TextProcessor interface { - Encode(s string, addSpecial bool) ([]int32, error) - Decode([]int32) (string, error) - Is(int32, Special) bool - Vocabulary() *Vocabulary -} - -type Vocabulary struct { - Values []string - Types []int32 - Scores []float32 - Merges []string - - BOS, EOS, EOT int32 - AddBOS, AddEOS, AddEOT bool - - specialOnce sync.Once - special []string - - valuesOnce sync.Once - values map[string]int32 - - mergeOnce sync.Once - merge map[string]int32 -} - -func (v *Vocabulary) Is(id int32, special Special) bool { - switch special { - case SpecialBOS: - return id == v.BOS - case SpecialEOS: - return id == v.EOS || id == v.EOT - default: - return false - } -} - -func (v *Vocabulary) Encode(s string) int32 { - v.valuesOnce.Do(func() { - v.values = make(map[string]int32, len(v.Values)) - for i, value := range v.Values { - v.values[value] = int32(i) - } - }) - - if id, ok := v.values[s]; ok { - return id - } - - return -1 -} - -func (v *Vocabulary) Decode(id int32) string { - return v.Values[id] -} - -func (v *Vocabulary) SpecialVocabulary() []string { - v.specialOnce.Do(func() { - for i := range v.Values { - if slices.Contains([]int{105, 106}, i) { - v.special = append(v.special, v.Values[i]) - } else if v.Types[i] == TOKEN_TYPE_CONTROL { - v.special = append(v.special, v.Values[i]) - } - } - }) - - return v.special -} - -func (v *Vocabulary) Merge(left, right 
string) int { - v.mergeOnce.Do(func() { - v.merge = make(map[string]int32, len(v.Merges)) - for i, merge := range v.Merges { - v.merge[merge] = int32(i) - } - }) - - if id, ok := v.merge[left+" "+right]; ok { - return int(id) - } - - return -1 -} - type BytePairEncoding struct { pre *regexp2.Regexp vocab *Vocabulary @@ -304,27 +201,12 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { } } + slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids) + if addSpecial && len(ids) > 0 { - if bpe.vocab.AddBOS { - if ids[0] == bpe.vocab.BOS { - slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS) - } - - slog.Debug("adding bos token to prompt", "id", bpe.vocab.BOS) - ids = append([]int32{bpe.vocab.BOS}, ids...) - } - - if bpe.vocab.AddEOS { - if ids[len(ids)-1] == bpe.vocab.EOS { - slog.Warn("adding eos token to prompt which already has it", "id", bpe.vocab.EOS) - } - - slog.Debug("adding eos token to prompt", "id", bpe.vocab.EOS) - ids = append(ids, bpe.vocab.EOS) - } + ids = bpe.vocab.addSpecials(ids) } - slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "ids", ids) return ids, nil } @@ -352,6 +234,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) { } } - slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String()) + slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String()) return sb.String(), nil } diff --git a/model/process_text_test.go b/model/bytepairencoding_test.go similarity index 100% rename from model/process_text_test.go rename to model/bytepairencoding_test.go diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 3156b0068..a87534c54 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -43,10 +43,13 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), Types: 
c.Ints("tokenizer.ggml.token_type"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), - // TODO: set EOT to EOS otherwise 0 will stop generation - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), Layers: make([]Layer, c.Uint("block_count")), diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index d53eb6ccc..89d1788ef 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -60,12 +60,16 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), Types: c.Ints("tokenizer.ggml.token_type"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - EOS: int32(1), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - EOT: int32(106), - AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false), + EOS: append( + []int32{ + int32(c.Uint("tokenizer.ggml.eos_token_id")), + int32(c.Uint("tokenizer.ggml.eot_token_id", 106)), + }, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), ImageProcessor: newImageProcessor(c), diff --git a/model/models/llama/model.go b/model/models/llama/model.go index c75d7eb2f..6e214f0fb 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -43,13 +43,13 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - 
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - // TODO: set EOT to EOS otherwise 0 will stop generation - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), - AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), Layers: make([]Layer, c.Uint("block_count")), diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go index c94aa72f6..af5173a16 100644 --- a/model/models/llama4/model.go +++ b/model/models/llama4/model.go @@ -40,13 +40,13 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - // TODO: set EOT to EOS otherwise 0 will stop generation - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), - AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), ImageProcessor: newImageProcessor(c), diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index b93882a9e..0d384b94c 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -37,25 +37,25 @@ func New(c fs.Config) (model.Model, error) { } m := &Model{ - TextModel: textModel, - VisionModel: newVisionModel(c), - ImageProcessor: newImageProcessor(c), - MultiModalProjector: newMultiModalProjector(c), BytePairEncoding: model.NewBytePairEncoding( c.String("tokenizer.ggml.pretokenizer", 
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id", 2)), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - // TODO: set EOT to EOS otherwise 0 will stop generation - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), - AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), + TextModel: textModel, + VisionModel: newVisionModel(c), + ImageProcessor: newImageProcessor(c), + MultiModalProjector: newMultiModalProjector(c), } m.Cache = kvcache.NewCausalCache(m.TextModel.Shift) diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 15571d9c2..547e2cb32 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -38,13 +38,13 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - // TODO: set EOT to EOS otherwise 0 will stop generation - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), - AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + 
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), ImageProcessor: newImageProcessor(c), diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index 486554501..7de9b6eb1 100644 --- a/model/models/qwen25vl/model.go +++ b/model/models/qwen25vl/model.go @@ -34,12 +34,13 @@ func New(c fs.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), - AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false), - EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), - EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")), - AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), }, ), TextModel: NewTextModel(c), diff --git a/model/process_text_spm.go b/model/sentencepiece.go similarity index 89% rename from model/process_text_spm.go rename to model/sentencepiece.go index b1cff7d27..7d725f04f 100644 --- a/model/process_text_spm.go +++ b/model/sentencepiece.go @@ -182,27 +182,12 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) } } + slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids) + if addSpecial && len(ids) > 0 { - if spm.vocab.AddBOS { - if ids[0] == spm.vocab.BOS { - slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS) - } - - slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS) - ids = append([]int32{spm.vocab.BOS}, ids...) 
- } - - if spm.vocab.AddEOS { - if ids[len(ids)-1] == spm.vocab.EOS { - slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS) - } - - slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS) - ids = append(ids, spm.vocab.EOS) - } + ids = spm.vocab.addSpecials(ids) } - slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "ids", ids) return ids, nil } @@ -261,6 +246,6 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) { } } - slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String()) + slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String()) return sb.String(), nil } diff --git a/model/process_text_spm_test.go b/model/sentencepiece_test.go similarity index 100% rename from model/process_text_spm_test.go rename to model/sentencepiece_test.go diff --git a/model/textprocessor.go b/model/textprocessor.go new file mode 100644 index 000000000..4a36f2352 --- /dev/null +++ b/model/textprocessor.go @@ -0,0 +1,17 @@ +package model + +const ( + TOKEN_TYPE_NORMAL = iota + 1 + TOKEN_TYPE_UNKNOWN + TOKEN_TYPE_CONTROL + TOKEN_TYPE_USER_DEFINED + TOKEN_TYPE_UNUSED + TOKEN_TYPE_BYTE +) + +type TextProcessor interface { + Encode(s string, addSpecial bool) ([]int32, error) + Decode([]int32) (string, error) + Is(int32, Special) bool + Vocabulary() *Vocabulary +} diff --git a/model/vocabulary.go b/model/vocabulary.go new file mode 100644 index 000000000..24adbaca3 --- /dev/null +++ b/model/vocabulary.go @@ -0,0 +1,112 @@ +package model + +import ( + "log/slog" + "slices" + "sync" +) + +type Special int32 + +const ( + SpecialBOS Special = iota + SpecialEOS +) + +type Vocabulary struct { + Values []string + Types []int32 + Scores []float32 + Merges []string + + BOS, EOS []int32 + AddBOS, AddEOS bool + + specialOnce sync.Once + special []string + + valuesOnce sync.Once + values map[string]int32 + + mergeOnce sync.Once + merge map[string]int32 +} + +func (v *Vocabulary) Is(id int32, 
special Special) bool { + switch special { + case SpecialBOS: + return slices.Contains(v.BOS, id) + case SpecialEOS: + return slices.Contains(v.EOS, id) + default: + return false + } +} + +func (v *Vocabulary) addSpecials(ids []int32) []int32 { + if v.AddBOS && len(v.BOS) > 0 { + if slices.Contains(v.BOS, ids[0]) { + slog.Warn("adding bos token to prompt which already has it", "id", v.BOS) + } + + slog.Debug("adding bos token to prompt", "id", v.BOS) + ids = append([]int32{v.BOS[0]}, ids...) + } + + if v.AddEOS && len(v.EOS) > 0 { + if slices.Contains(v.EOS, ids[len(ids)-1]) { + slog.Warn("adding eos token to prompt which already has it", "id", v.EOS) + } + + slog.Debug("adding eos token to prompt", "id", v.EOS) + ids = append(ids, v.EOS[0]) + } + + return ids +} + +func (v *Vocabulary) Encode(s string) int32 { + v.valuesOnce.Do(func() { + v.values = make(map[string]int32, len(v.Values)) + for i, value := range v.Values { + v.values[value] = int32(i) + } + }) + + if id, ok := v.values[s]; ok { + return id + } + + return -1 +} + +func (v *Vocabulary) Decode(id int32) string { + return v.Values[id] +} + +func (v *Vocabulary) SpecialVocabulary() []string { + v.specialOnce.Do(func() { + for i := range v.Values { + if v.Types[i] == TOKEN_TYPE_CONTROL { + v.special = append(v.special, v.Values[i]) + } + } + }) + + return v.special +} + +func (v *Vocabulary) Merge(left, right string) int { + v.mergeOnce.Do(func() { + v.merge = make(map[string]int32, len(v.Merges)) + for i, merge := range v.Merges { + v.merge[merge] = int32(i) + } + }) + + if id, ok := v.merge[left+" "+right]; ok { + return int(id) + } + + return -1 +} diff --git a/sample/samplers.go b/sample/samplers.go index f0846c8dd..d395650d9 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -176,7 +176,7 @@ func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSa vocabIds[i] = uint32(i) } - grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, 
[]uint32{uint32(model.Vocabulary().EOS), uint32(model.Vocabulary().EOT)}) + grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS) if grammar == nil { return nil, errors.New("sample: failed to initialize grammar") }