From 7916f550099a4967d79b15edb52b9d23a016ab5f Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Mon, 3 Feb 2025 19:12:04 -0800 Subject: [PATCH] vocab: Use int32 for special tokens Special tokens are currently read as uint32 from the model metadata. However, all other parts of the system (including the tokenizer) use int32 to represent tokens so it is impossible to represent the high portion of the unsigned range. For consistency and to avoid casts, we should just use int32 everywhere. --- model/llama/model.go | 4 ++-- model/mllama/model.go | 4 ++-- model/process_text.go | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/model/llama/model.go b/model/llama/model.go index 294661740..6efcc9bb7 100644 --- a/model/llama/model.go +++ b/model/llama/model.go @@ -35,8 +35,8 @@ func New(c ml.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Uints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: c.Uint("tokenizer.ggml.bos_token_id"), - EOS: c.Uint("tokenizer.ggml.eos_token_id"), + BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), + EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), }, ), Layers: make([]Layer, c.Uint("block_count")), diff --git a/model/mllama/model.go b/model/mllama/model.go index d0c59a3e2..e5b275b0b 100644 --- a/model/mllama/model.go +++ b/model/mllama/model.go @@ -26,8 +26,8 @@ func New(c ml.Config) (model.Model, error) { Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Uints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), - BOS: c.Uint("tokenizer.ggml.bos_token_id"), - EOS: c.Uint("tokenizer.ggml.eos_token_id"), + BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), + EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), }, ), ImageProcessor: newImageProcessor(c), diff --git a/model/process_text.go b/model/process_text.go index 1610a884d..df1e68f4c 100644 --- a/model/process_text.go +++ b/model/process_text.go @@ -21,7 +21,7 @@ const ( type TextProcessor interface { Encode(string) ([]int32, error) Decode([]int32) (string, error) - Is(uint32, Special) bool + Is(int32, Special) bool } type Vocabulary struct { @@ -30,7 +30,7 @@ type Vocabulary struct { Scores []uint32 Merges []string - BOS, EOS uint32 + BOS, EOS int32 specialOnce sync.Once special []string @@ -42,7 +42,7 @@ type Vocabulary struct { merge map[string]int32 } -func (v *Vocabulary) Is(id uint32, special Special) bool { +func (v *Vocabulary) Is(id int32, special Special) bool { switch special { case SpecialBOS: return id == v.BOS @@ -111,7 +111,7 @@ func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding { } } -func (bpe BytePairEncoding) Is(id uint32, special Special) bool { +func (bpe BytePairEncoding) Is(id int32, special Special) bool { return bpe.vocab.Is(id, special) }