vocab: Use int32 for special tokens

Special tokens are currently read as uint32 from the model metadata.
However, all other parts of the system (including the tokenizer) use
int32 to represent tokens, so it is impossible to represent the high
portion of the unsigned range. For consistency and to avoid casts,
we should just use int32 everywhere.
This commit is contained in:
Jesse Gross 2025-02-03 19:12:04 -08:00 committed by Jesse Gross
parent d650ad398f
commit 7916f55009
3 changed files with 8 additions and 8 deletions

View File

@@ -35,8 +35,8 @@ func New(c ml.Config) (model.Model, error) {
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: c.Uint("tokenizer.ggml.bos_token_id"),
EOS: c.Uint("tokenizer.ggml.eos_token_id"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
),
Layers: make([]Layer, c.Uint("block_count")),

View File

@@ -26,8 +26,8 @@ func New(c ml.Config) (model.Model, error) {
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: c.Uint("tokenizer.ggml.bos_token_id"),
EOS: c.Uint("tokenizer.ggml.eos_token_id"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
),
ImageProcessor: newImageProcessor(c),

View File

@@ -21,7 +21,7 @@ const (
type TextProcessor interface {
Encode(string) ([]int32, error)
Decode([]int32) (string, error)
Is(uint32, Special) bool
Is(int32, Special) bool
}
type Vocabulary struct {
@@ -30,7 +30,7 @@ type Vocabulary struct {
Scores []uint32
Merges []string
BOS, EOS uint32
BOS, EOS int32
specialOnce sync.Once
special []string
@@ -42,7 +42,7 @@ type Vocabulary struct {
merge map[string]int32
}
func (v *Vocabulary) Is(id uint32, special Special) bool {
func (v *Vocabulary) Is(id int32, special Special) bool {
switch special {
case SpecialBOS:
return id == v.BOS
@@ -111,7 +111,7 @@ func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
}
}
func (bpe BytePairEncoding) Is(id uint32, special Special) bool {
func (bpe BytePairEncoding) Is(id int32, special Special) bool {
return bpe.vocab.Is(id, special)
}