diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go
index 573138592..b9f9cc178 100644
--- a/fs/ggml/ggml.go
+++ b/fs/ggml/ggml.go
@@ -100,6 +100,10 @@ func (kv KV) Float(key string, defaultValue ...float32) float32 {
 	return keyValue(kv, key, append(defaultValue, 0)...)
 }
 
+func (kv KV) Bool(key string, defaultValue ...bool) bool { // Bool reads a boolean entry for key through keyValue.
+	return keyValue(kv, key, append(defaultValue, false)...) // false is appended so keyValue always receives at least one default
+}
+
 func (kv KV) Strings(key string, defaultValue ...[]string) []string {
 	r := keyValue(kv, key, &array{})
 	s := make([]string, r.size)
@@ -120,7 +124,7 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
 	return s
 }
 
-func keyValue[T string | uint32 | uint64 | float32 | *array](kv KV, key string, defaultValue ...T) T {
+func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T { // bool joins the constraint so KV.Bool can call keyValue
 	if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
 		key = kv.Architecture() + "." + key
 	}
diff --git a/ml/backend.go b/ml/backend.go
index 6e3f0516f..a742ee5c0 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -14,6 +14,7 @@ type Config interface {
 	String(string, ...string) string
 	Uint(string, ...uint32) uint32
 	Float(string, ...float32) float32
+	Bool(string, ...bool) bool // boolean counterpart of String/Uint/Float: key plus optional defaults
 
 	Strings(string, ...[]string) []string
 	Uints(string, ...[]uint32) []uint32
diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index 4fe029993..6106af867 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -37,7 +37,9 @@ func New(c ml.Config) (model.Model, error) {
 				Types:  c.Uints("tokenizer.ggml.token_type"),
 				Merges: c.Strings("tokenizer.ggml.merges"),
 				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
+				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), // true is passed as the default value
 				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
+				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), // false is passed as the default value
 			},
 		),
 		Layers: make([]Layer, c.Uint("block_count")),
diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go
index f5521ce5c..9b35a2628 100644
--- a/model/models/mllama/model.go
+++ b/model/models/mllama/model.go
@@ -33,7 +33,9 @@ func New(c ml.Config) (model.Model, error) {
 				Types:  c.Uints("tokenizer.ggml.token_type"),
 				Merges: c.Strings("tokenizer.ggml.merges"),
 				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
+				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), // true is passed as the default value
 				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
+				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), // false is passed as the default value
 			},
 		),
 		ImageProcessor: newImageProcessor(c),
diff --git a/model/process_text.go b/model/process_text.go
index df1e68f4c..7083f36fd 100644
--- a/model/process_text.go
+++ b/model/process_text.go
@@ -30,7 +30,8 @@ type Vocabulary struct {
 	Scores []uint32
 	Merges []string
 
-	BOS, EOS int32
+	BOS, EOS       int32
+	AddBOS, AddEOS bool // consulted by BytePairEncoding.Encode to decide whether BOS is prepended / EOS is appended
 
 	specialOnce sync.Once
 	special     []string
@@ -281,6 +282,26 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
 		}
 	}
 
+	if len(ids) > 0 { // no BOS/EOS is added when encoding produced no ids
+		if bpe.vocab.AddBOS {
+			if ids[0] == bpe.vocab.BOS { // warn only: the BOS token is still prepended below
+				slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
+			}
+
+			slog.Debug("adding bos token to prompt", "id", bpe.vocab.BOS)
+			ids = append([]int32{bpe.vocab.BOS}, ids...) // builds a new slice with BOS first, then the existing ids
+		}
+
+		if bpe.vocab.AddEOS {
+			if ids[len(ids)-1] == bpe.vocab.EOS { // warn only: the EOS token is still appended below
+				slog.Warn("adding eos token to prompt which already has it", "id", bpe.vocab.EOS)
+			}
+
+			slog.Debug("adding eos token to prompt", "id", bpe.vocab.EOS)
+			ids = append(ids, bpe.vocab.EOS)
+		}
+	}
+
 	return ids, nil
 }
 