diff --git a/convert/convert.go b/convert/convert.go index 015303e78..eb441715f 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -15,6 +15,11 @@ import ( type ModelParameters struct { Architectures []string `json:"architectures"` VocabSize uint32 `json:"vocab_size"` + TextModel TextParameters `json:"text_config"` +} + +type TextParameters struct { + VocabSize uint32 `json:"vocab_size"` } type AdapterParameters struct { @@ -185,6 +190,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error { conv = &gemmaModel{} case "Gemma2ForCausalLM": conv = &gemma2Model{} + case "Gemma3ForConditionalGeneration": + conv = &gemma3Model{} case "Phi3ForCausalLM": conv = &phi3Model{} case "Qwen2ForCausalLM": @@ -213,6 +220,11 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error { } vocabSize := int(p.VocabSize) + if vocabSize == 0 { + tVocabSize := int(p.TextModel.VocabSize) + vocabSize = tVocabSize + } + switch { case vocabSize > len(t.Vocabulary.Tokens): slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens)) diff --git a/convert/convert_gemma3.go b/convert/convert_gemma3.go new file mode 100644 index 000000000..c2be55707 --- /dev/null +++ b/convert/convert_gemma3.go @@ -0,0 +1,81 @@ +package convert + +import "github.com/ollama/ollama/fs/ggml" + +type gemma3Model struct { + gemmaModel + TextModel gemma3TextModel `json:"text_config"` + VisionModel gemma3VisionModel `json:"vision_config"` +} + +type gemma3TextModel struct { + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + HiddenSize uint32 `json:"hidden_size"` + HiddenLayers uint32 `json:"num_hidden_layers"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumKeyValueHeads uint32 `json:"num_key_value_heads"` + RMSNormEPS float32 `json:"rms_norm_eps"` + HeadDim uint32 `json:"head_dim"` + SlidingWindow uint32 `json:"sliding_window"` + AttentionLogitSoftcap float32 `json:"attn_logit_softcapping"` + FinalLogitSoftcap float32 `json:"final_logit_softcapping"` + RopeLocalTheta float32 `json:"rope_local_base_freq"` + RopeGlobalTheta float32 `json:"rope_global_base_freq"` +} + +type gemma3VisionModel struct { + ImageSize uint32 `json:"image_size"` + NumChannels uint32 `json:"num_channels"` + HiddenLayers uint32 `json:"num_hidden_layers"` +} + +func (p *gemma3Model) KV(t *Tokenizer) ggml.KV { + kv := p.ModelParameters.KV(t) + kv["general.architecture"] = "gemma3" + kv["gemma3.context_length"] = p.TextModel.MaxPositionEmbeddings + kv["gemma3.embedding_length"] = p.TextModel.HiddenSize + kv["gemma3.block_count"] = p.TextModel.HiddenLayers + kv["gemma3.text.feed_forward_length"] = p.TextModel.IntermediateSize + kv["gemma3.attention.head_count"] = p.TextModel.NumAttentionHeads + kv["gemma3.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads + kv["gemma3.text.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS + kv["gemma3.attention.key_length"] = p.TextModel.HeadDim + kv["gemma3.attention.value_length"] = p.TextModel.HeadDim + kv["gemma3.text.attention.sliding_window"] = p.TextModel.SlidingWindow + kv["gemma3.text.final_logit_softcapping"] = p.TextModel.FinalLogitSoftcap + kv["gemma3.text.rope.local.freq_base"] = p.TextModel.RopeLocalTheta + kv["gemma3.text.rope.global.freq_base"] = p.TextModel.RopeGlobalTheta + kv["tokenizer.ggml.bos_token_id"] = uint32(2) + kv["tokenizer.ggml.eot_token_id"] = uint32(1) + kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize + kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels + kv["gemma3.vision.block_count"] = p.VisionModel.HiddenLayers + return kv +} + +func (p *gemma3Model) Replacements() []string { + return []string{ + "lm_head", "output", + "model.embed_tokens", "token_embd", + "model.norm", "output_norm", + "vision_model.vision_model", "v", + "language_model.", "", + "model.layers", "blk", + "encoder.layers", "blk", + "vision_tower.vision_model.embeddings", "v", + "input_layernorm", "attn_norm", + "self_attn.q_proj", "attn_q", + "self_attn.q_norm", "attn_q_norm", + "self_attn.k_proj", "attn_k", + "self_attn.k_norm", "attn_k_norm", + "self_attn.v_proj", "attn_v", + "self_attn.o_proj", "attn_output", + "mlp.gate_proj", "ffn_gate", + "mlp.down_proj", "ffn_down", + "mlp.up_proj", "ffn_up", + "post_attention_layernorm", "post_attention_norm", + "pre_feedforward_layernorm", "ffn_norm", + "post_feedforward_layernorm", "post_ffw_norm", + } +} diff --git a/convert/tokenizer_spm.go b/convert/tokenizer_spm.go index 5e506087c..d8a012c08 100644 --- a/convert/tokenizer_spm.go +++ b/convert/tokenizer_spm.go @@ -6,7 +6,9 @@ import ( "errors" "fmt" "io/fs" + "log/slog" "os" + "reflect" "slices" "google.golang.org/protobuf/proto" @@ -15,6 +17,8 @@ import ( ) func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) { + slog.Debug("using spm vocabulary") + ast, err := parseAdditionalSpecialTokens(fsys) if err != nil { return nil, err @@ -43,8 +47,11 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) { v.Types = append(v.Types, int32(t)) default: tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL) - if slices.Contains(ast, piece.GetPiece()) { - tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL) + for _, t := range ast { + if t.Content == piece.GetPiece() { + tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL) + break + } } v.Types = append(v.Types, tt) @@ -78,10 +85,16 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) { return cmp.Compare(i.id, j.id) }) - n := len(v.Tokens) - for i, t := range ts { - if t.id != i+n { - return nil, fmt.Errorf("invalid token id: %d", t.id) + for _, t := range ts { + if t.id < len(v.Tokens) { + if v.Tokens[t.id] == t.content { + slog.Warn("tokenizer", "duplicate token", t.content, "id", t.id) + continue + } + return nil, fmt.Errorf("token mismatch: %s != %s at pos [%d]", t.content, v.Tokens[t.id], t.id) + } + if t.id != len(v.Tokens) { + return nil, fmt.Errorf("invalid token id: [%d] as pos [%d]", t.id, len(v.Tokens)) } v.Tokens = append(v.Tokens, t.content) @@ -92,7 +105,15 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) { return &v, nil } -func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) { +type specialToken struct { + Content string `json:"content"` + Lstrip bool `json:"lstrip"` + Normalized bool `json:"normalized"` + Rstrip bool `json:"rstrip"` + SingleWord bool `json:"single_word"` +} + +func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) { f, err := fsys.Open("special_tokens_map.json") if errors.Is(err, os.ErrNotExist) { return nil, nil @@ -102,12 +123,43 @@ func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) { defer f.Close() var m struct { - AdditionalSpecialTokens []string `json:"additional_special_tokens"` + AdditionalSpecialTokens any `json:"additional_special_tokens"` } if err := json.NewDecoder(f).Decode(&m); err != nil { return nil, err } - return m.AdditionalSpecialTokens, nil + var ast []specialToken + + switch st := m.AdditionalSpecialTokens.(type) { + case []string: + for _, s := range st { + ast = append(ast, specialToken{Content: s}) + } + case []any: + for _, s := range st { + // marshal and unmarshal the object to get the special token + tMap := s.(map[string]any) + data, err := json.Marshal(tMap) + if err != nil { + return nil, err + } + + var token specialToken + err = json.Unmarshal(data, &token) + if err != nil { + return nil, err + } + + ast = append(ast, token) + } + + default: + slog.Warn("special token", "unknown token", reflect.TypeOf(st)) + } + + slog.Debug("spm tokenizer", "additional tokens", ast) + + return ast, nil } diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 8662c3b01..fe98a71b3 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -124,6 +124,15 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 { return s } +func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 { + r := keyValue(kv, key, &array{}) + s := make([]float32, r.size) + for i := range r.size { + s[i] = float32(r.values[i].(float32)) + } + return s +} + func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T { if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") { key = kv.Architecture() + "." + key @@ -476,7 +485,7 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO // vocab graph 4*batch*(embedding+vocab)+embedding*vocab*105/128, ) - case "gemma", "gemma2": + case "gemma", "gemma2", "gemma3": fullOffload = max( 4*batch*(embedding+vocab), 4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads), diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 22d8efb43..0c9e000ef 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -445,7 +445,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0 panic("not implemented") } -func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor { +func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor { panic("not implemented") } diff --git a/ml/backend.go b/ml/backend.go index 641175f0f..27c2d14d3 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -19,6 +19,7 @@ type Config interface { Strings(string, ...[]string) []string Uints(string, ...[]uint32) []uint32 + Floats(string, ...[]float32) []float32 } type Backend interface { @@ -135,7 +136,7 @@ type Tensor interface { Scale(ctx Context, s float64) Tensor Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor - RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor + RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor Tanh(ctx Context) Tensor GELU(ctx Context) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 74512f337..8843ae7c1 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -893,10 +893,13 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { } const ( - ropeTypeNorm C.int = iota + ropeTypeNorm C.int = 0 + ropeTypeNeox C.int = 2 + ropeTypeMrope C.int = 8 + ropeTypeVision C.int = 24 ) -func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor { +func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor { if ropeFactors == nil { ropeFactors = &Tensor{b: t.b} } @@ -911,8 +914,8 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi t: C.ggml_rope_ext( ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t, C.int(ropeDim), - 131072, // YaRN n_ctx_train - ropeTypeNorm, // ROPE_TYPE_NORM + C.int(ropeType), + 131072, // YaRN n_ctx_train C.float(ropeBase), C.float(ropeScale), 0., // YaRN ext_factor diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go new file mode 100644 index 000000000..2ad9c5681 --- /dev/null +++ b/model/models/gemma2/model.go @@ -0,0 +1,206 @@ +package gemma2 + +import ( + "math" + + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Options struct { + hiddenSize, numHeads, numKVHeads int + attnKeyLen, attnValLen int + eps, ropeBase, ropeScale float32 + attnLogitSoftcap float32 + finalLogitSoftcap float32 + largeModelScaling bool +} + +type Model struct { + model.Base + model.SentencePieceModel + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []Layer `gguf:"blk"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` // just set to token_embd? + + *Options +} + +const ( + gemma27BLayerCount = 46 +) + +func New(c ml.Config) (model.Model, error) { + m := Model{ + SentencePieceModel: model.NewSentencePieceModel( + c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Uints("tokenizer.ggml.token_type"), + BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), + EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + }, + ), + Layers: make([]Layer, c.Uint("block_count")), + Options: &Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + attnKeyLen: int(c.Uint("attention.key_length")), + attnValLen: int(c.Uint("attention.value_length")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base", 10000.0), + ropeScale: c.Float("rope.freq_scale", 1.0), + attnLogitSoftcap: c.Float("attn_logit_softcapping"), + finalLogitSoftcap: c.Float("final_logit_softcapping"), + }, + } + + slidingWindowLen := int32(c.Uint("attention.sliding_window")) + m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift)) + + return &m, nil +} + +type SelfAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` +} + +func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + batchSize := hiddenState.Dim(1) + ropeType := uint32(2) + + q := sa.Query.Forward(ctx, hiddenState) + q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) + q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale) + + if opts.largeModelScaling { + q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize / opts.numHeads))) + } else { + q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen))) + } + + k := sa.Key.Forward(ctx, hiddenState) + k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) + k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale) + + v := sa.Value.Forward(ctx, hiddenState) + v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) + + cache.Put(ctx, k, v) + k, v, mask := cache.Get(ctx) + + q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + + kq := k.Mulmat(ctx, q) + + // logit softcap + kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap)) + kq = kq.Tanh(ctx) + kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap)) + + kq = kq.Add(ctx, mask) + kq = kq.Softmax(ctx) + + kqv := v.Mulmat(ctx, kq) + kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize) + + return sa.Output.Forward(ctx, kqv) +} + +func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil +} + +type MLP struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` + Gate *nn.Linear `gguf:"ffn_gate"` +} + +func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { + hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + return mlp.Down.Forward(ctx, hiddenState) +} + +type Layer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + SelfAttention *SelfAttention + PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"` + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP *MLP + PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"` +} + +func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + residual := hiddenState + + hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts) + hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = hiddenState.Add(ctx, residual) + residual = hiddenState + + hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.MLP.Forward(ctx, hiddenState, opts) + hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps) + return hiddenState.Add(ctx, residual) +} + +func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { + inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) + if err != nil { + return nil, err + } + + positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)) + if err != nil { + return nil, err + } + + hiddenState := m.TokenEmbedding.Forward(ctx, inputs) + hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize))) + + if len(m.Layers) == gemma27BLayerCount { + m.Options.largeModelScaling = true + } + + for i, layer := range m.Layers { + cacheType := i % 2 + m.Cache.SetLayer(i) + wc := m.Cache.(*kvcache.WrapperCache) + wc.SetLayerType(cacheType) + hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options) + } + + hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) + hiddenState = m.Output.Forward(ctx, hiddenState) + + // final logit softcap + hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap)) + hiddenState = hiddenState.Tanh(ctx) + hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)) + + outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs)) + if err != nil { + return nil, err + } + + return hiddenState.Rows(ctx, outputs), nil +} + +func init() { + model.Register("gemma2", New) +} diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go new file mode 100644 index 000000000..0f4944a49 --- /dev/null +++ b/model/models/gemma3/model.go @@ -0,0 +1,74 @@ +package gemma3 + +import ( + "fmt" + + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Model struct { + model.Base + model.SentencePieceModel + + //*VisionModel `gguf:"v,vision"` + *TextModel + + //Projector *nn.Linear `gguf:"mm.0"` + + ImageProcessor +} + +func New(c ml.Config) (model.Model, error) { + // Verify unified config + if c.Uint("vision.block_count") == 0 { + return nil, fmt.Errorf("non-unified vision model not supported") + } + m := Model{ + SentencePieceModel: model.NewSentencePieceModel( + c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Uints("tokenizer.ggml.token_type"), + BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + }, + ), + ImageProcessor: newImageProcessor(c), + //VisionModel: newVisionModel(c), + TextModel: newTextModel(c), + } + + slidingWindowLen := int32(c.Uint("text.attention.sliding_window")) + m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift)) + + return &m, nil +} + +func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { + inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) + if err != nil { + return nil, err + } + + positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)) + if err != nil { + return nil, err + } + + outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs)) + if err != nil { + return nil, err + } + + return m.TextModel.Forward(ctx, inputs, positions, outputs, m.Cache), nil +} + +func init() { + model.Register("gemma3", New) +} diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go new file mode 100644 index 000000000..051e06c56 --- /dev/null +++ b/model/models/gemma3/model_text.go @@ -0,0 +1,193 @@ +package gemma3 + +import ( + "math" + + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/model" +) + +type TextOptions struct { + hiddenSize, numHeads, numKVHeads int + attnKeyLen, attnValLen int + eps, ropeScale float32 + ropeLocalBase, ropeGlobalBase float32 + finalLogitSoftcap float32 + largeModelScaling bool +} + +type TextModel struct { + model.Base + model.SentencePieceModel + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []TextLayer `gguf:"blk"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + *TextOptions +} + +const ( + gemma27BLayerCount = 46 +) + +const ( + cacheTypeSWA = iota + cacheTypeCausal +) + +func newTextModel(c ml.Config) *TextModel { + m := TextModel{ + SentencePieceModel: model.NewSentencePieceModel( + c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Uints("tokenizer.ggml.token_type"), + BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), + EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + }, + ), + Layers: make([]TextLayer, c.Uint("block_count")), + TextOptions: &TextOptions{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + attnKeyLen: int(c.Uint("attention.key_length")), + attnValLen: int(c.Uint("attention.value_length")), + eps: c.Float("text.attention.layer_norm_rms_epsilon"), + ropeLocalBase: c.Float("text.rope.local.freq_base", 10000.0), + ropeGlobalBase: c.Float("text.rope.global.freq_base", 1000000.0), + ropeScale: c.Float("text.rope.freq_scale", 1.0), + finalLogitSoftcap: c.Float("text.final_logit_softcapping"), + }, + } + + slidingWindowLen := int32(c.Uint("text.attention.sliding_window")) + m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift)) + + return &m +} + +type TextSelfAttention struct { + Query *nn.Linear `gguf:"attn_q"` + QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` + Key *nn.Linear `gguf:"attn_k"` + KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` +} + +func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { + batchSize := hiddenState.Dim(1) + ropeType := uint32(2) + + ropeBase := opts.ropeLocalBase + if (layer+1)%6 == 0 { + ropeBase = opts.ropeGlobalBase + } + + q := sa.Query.Forward(ctx, hiddenState) + q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) + q = sa.QueryNorm.Forward(ctx, q, opts.eps) + q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale) + + if opts.largeModelScaling { + q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) + } else { + q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen))) + } + + k := sa.Key.Forward(ctx, hiddenState) + k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) + k = sa.KeyNorm.Forward(ctx, k, opts.eps) + k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale) + + v := sa.Value.Forward(ctx, hiddenState) + v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) + + scaleFactor := 1.0 + kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) + kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize) + + return sa.Output.Forward(ctx, kqv) +} + +func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + ropeBase := m.TextOptions.ropeLocalBase + if (layer+1)%6 == 0 { + ropeBase = m.TextOptions.ropeGlobalBase + } + + return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil +} + +type TextMLP struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` + Gate *nn.Linear `gguf:"ffn_gate"` +} + +func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { + hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + return mlp.Down.Forward(ctx, hiddenState) +} + +type TextLayer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + SelfAttention *TextSelfAttention + PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"` + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP *TextMLP + PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"` +} + +func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { + residual := hiddenState + + hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts) + hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = hiddenState.Add(ctx, residual) + residual = hiddenState + + hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.MLP.Forward(ctx, hiddenState, opts) + hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps) + return hiddenState.Add(ctx, residual) +} + +func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor { + hiddenState := m.TokenEmbedding.Forward(ctx, inputs) + hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize))) + + if len(m.Layers) == gemma27BLayerCount { + m.TextOptions.largeModelScaling = true + } + + for i, layer := range m.Layers { + // gemma alternates between the sliding window (local) and causal (global) + // kv cache every 6 layers + cacheType := cacheTypeSWA + if (i+1)%6 == 0 { + cacheType = cacheTypeCausal + } + cache.SetLayer(i) + wc := cache.(*kvcache.WrapperCache) + wc.SetLayerType(cacheType) + hiddenState = layer.Forward(ctx, i, hiddenState, positions, cache, m.TextOptions) + } + + hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) + hiddenState = m.Output.Forward(ctx, hiddenState) + + // final logit softcap + hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap)) + hiddenState = hiddenState.Tanh(ctx) + hiddenState = hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap)) + + return hiddenState.Rows(ctx, outputs) +} diff --git a/model/models/gemma3/process_image.go b/model/models/gemma3/process_image.go new file mode 100644 index 000000000..5cf963e88 --- /dev/null +++ b/model/models/gemma3/process_image.go @@ -0,0 +1,57 @@ +package gemma3 + +import ( + "image" + + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/imageproc" +) + +type ImageProcessor struct { + imageSize, numChannels int +} + +func newImageProcessor(c ml.Config) ImageProcessor { + return ImageProcessor{ + imageSize: int(c.Uint("vision.image_size")), + numChannels: int(c.Uint("vision.num_channels")), + } +} + +func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 { + var pixelVals []float32 + + bounds := img.Bounds() + var rVals, gVals, bVals []float32 + for y := bounds.Min.Y; y < bounds.Max.Y; y++ { + for x := bounds.Min.X; x < bounds.Max.X; x++ { + c := img.At(x, y) + r, g, b, _ := c.RGBA() + rVal := float32(r>>8) / 255.0 + gVal := float32(g>>8) / 255.0 + bVal := float32(b>>8) / 255.0 + + rVal = (rVal - mean[0]) / std[0] + gVal = (gVal - mean[1]) / std[1] + bVal = (bVal - mean[2]) / std[2] + + rVals = append(rVals, rVal) + gVals = append(gVals, gVal) + bVals = append(bVals, bVal) + } + } + pixelVals = append(pixelVals, rVals...) + pixelVals = append(pixelVals, gVals...) + pixelVals = append(pixelVals, bVals...) + + return pixelVals +} + +func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) { + outputSize := image.Point{p.imageSize, p.imageSize} + newImage := imageproc.Composite(img) + newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear) + + data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD) + return data, nil +} diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 1f27f522d..19a2ab8c4 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -76,14 +76,15 @@ type SelfAttention struct { func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := opts.hiddenSize / opts.numHeads + ropeType := uint32(0) q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale) + q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale) + k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -96,7 +97,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil + return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil } type MLP struct { diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 373589f9e..40c9a9707 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -20,14 +20,15 @@ type TextSelfAttention struct { func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := opts.hiddenSize / opts.numHeads + ropeType := uint32(0) query := sa.Query.Forward(ctx, hiddenState) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) - query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale) + query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) key := sa.Key.Forward(ctx, hiddenState) key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale) + key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -40,8 +41,9 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + // This will only get called for layers in the cache, which are just the self attention layers if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok { - return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil + return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil } return key, nil diff --git a/model/models/models.go b/model/models/models.go index d0b68b320..ce1d2ce03 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -1,6 +1,8 @@ package models import ( + _ "github.com/ollama/ollama/model/models/gemma2" + _ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/llama" _ "github.com/ollama/ollama/model/models/mllama" ) diff --git a/model/process_text.go b/model/process_text.go index 0d75a0ed0..cd1deb659 100644 --- a/model/process_text.go +++ b/model/process_text.go @@ -18,6 +18,15 @@ const ( SpecialEOS ) +const ( + TOKEN_TYPE_NORMAL = iota + 1 + TOKEN_TYPE_UNKNOWN + TOKEN_TYPE_CONTROL + TOKEN_TYPE_USER_DEFINED + TOKEN_TYPE_UNUSED + TOKEN_TYPE_BYTE +) + type TextProcessor interface { Encode(s string, addSpecial bool) ([]int32, error) Decode([]int32) (string, error) @@ -27,7 +36,7 @@ type TextProcessor interface { type Vocabulary struct { Values []string Types []uint32 - Scores []uint32 + Scores []float32 Merges []string BOS, EOS int32 @@ -76,7 +85,7 @@ func (v *Vocabulary) Decode(id int32) string { func (v *Vocabulary) SpecialVocabulary() []string { v.specialOnce.Do(func() { for i := range v.Values { - if v.Types[i] == 3 { + if v.Types[i] == TOKEN_TYPE_CONTROL { v.special = append(v.special, v.Values[i]) } } diff --git a/model/process_text_spm.go b/model/process_text_spm.go new file mode 100644 index 000000000..c0bc973f9 --- /dev/null +++ b/model/process_text_spm.go @@ -0,0 +1,221 @@ +package model + +import ( + "iter" + "log/slog" + "strings" + + "github.com/dlclark/regexp2" + queue "github.com/emirpasic/gods/v2/queues/priorityqueue" +) + +const spmWhitespaceSep = "▁" + +func replaceWhitespaceBySeperator(s string) string { + return strings.ReplaceAll(s, " ", spmWhitespaceSep) +} + +type SentencePieceModel struct { + maxTokenLen int + pre *regexp2.Regexp + vocab *Vocabulary +} + +func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel { + slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5]) + + counter := map[int]int{} + var maxTokenLen int + for cnt := range vocab.Types { + switch vocab.Types[cnt] { + case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED: + maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt])) + fallthrough + default: + counter[int(vocab.Types[cnt])] += 1 + } + } + + slog.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL], + "user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE], + "max token len", maxTokenLen) + + return SentencePieceModel{ + maxTokenLen: maxTokenLen, + pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2), + vocab: vocab, + } +} + +func (spm SentencePieceModel) Is(id int32, special Special) bool { + return spm.vocab.Is(id, special) +} + +func (spm *SentencePieceModel) split(s string) iter.Seq[string] { + return func(yield func(string) bool) { + for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) { + if !yield(m.String()) { + break + } + } + } +} + +func (spm SentencePieceModel) Encode(s string) ([]int32, error) { + fragments := []fragment{{value: s}} + for _, special := range spm.vocab.SpecialVocabulary() { + // TODO: process special tokens concurrently + id := spm.vocab.Encode(special) + for i := 0; i < len(fragments); i++ { + frag := fragments[i] + if len(frag.ids) > 0 { + continue + } + + var middle []fragment + switch i := strings.Index(frag.value, special); { + case i < 0: + middle = append(middle, frag) + case i > 0: + middle = append(middle, fragment{value: frag.value[:i]}) + fallthrough + default: + middle = append(middle, fragment{value: special, ids: []int32{id}}) + if rest := frag.value[i+len(special):]; rest != "" { + middle = append(middle, fragment{value: rest}) + } + } + + fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...) + } + } + slog.Debug("fragments", "frags", fragments) + + var ids []int32 + for _, frag := range fragments { + if len(frag.ids) > 0 { + ids = append(ids, frag.ids...) + continue + } + + for split := range spm.split(frag.value) { + split = replaceWhitespaceBySeperator(split) + + var sb strings.Builder + sb.Write([]byte(split)) + if id := spm.vocab.Encode(sb.String()); id >= 0 { + ids = append(ids, id) + continue + } + + runes := []rune(sb.String()) + pq := queue.NewWith(func(a, b any) int { + priA := a.(*candidate) + priB := b.(*candidate) + if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) { + return -1 + } + return 1 + }) + + merges := make([]merge, len(runes)) + for r := range runes { + merges[r] = merge{ + p: r - 1, + n: r + 1, + runes: []rune{runes[r]}, + } + } + + slog.Debug("tokenizer", "merges", merges) + + pairwise := func(a, b int) *candidate { + if a < 0 || b >= len(runes) { + return nil + } + + left, right := string(merges[a].runes), string(merges[b].runes) + if id := spm.vocab.Encode(left + right); id >= 0 { + return &candidate{ + a: a, + b: b, + score: spm.vocab.Scores[id], + } + } + return nil + } + + for i := range len(runes) - 1 { + if pair := pairwise(i, i+1); pair != nil { + pq.Enqueue(pair) + } + } + + pqv := pq.Values() + for _, v := range pqv { + e := v.(*candidate) + slog.Debug("candidate", "candidate", e) + } + + for !pq.Empty() { + v, _ := pq.Dequeue() + pair := v.(*candidate) + left, right := merges[pair.a], merges[pair.b] + + slog.Debug("pair", "left", left, "right", right) + if len(left.runes) == 0 || len(right.runes) == 0 { + continue + } + + merges[pair.a].runes = append(left.runes, right.runes...) + merges[pair.b].runes = nil + merges[pair.a].n = right.n + if right.n < len(merges) { + merges[right.n].p = pair.a + } + + if pair := pairwise(merges[pair.a].p, pair.a); pair != nil { + pq.Enqueue(pair) + } + + if pair := pairwise(pair.a, merges[pair.a].n); pair != nil { + pq.Enqueue(pair) + } + } + + slog.Debug("merges", "merges", merges) + + for _, merge := range merges { + if len(merge.runes) > 0 { + if id := spm.vocab.Encode(string(merge.runes)); id >= 0 { + ids = append(ids, id) + } else { + slog.Debug("missing token", "token", string(merge.runes)) + } + } + } + } + } + slog.Debug("encoded", "ids", ids) + + return ids, nil +} + +type candidate struct { + a, b int + score float32 +} + +func (spm SentencePieceModel) Decode(ids []int32) (string, error) { + var sb strings.Builder + for _, id := range ids { + data := spm.vocab.Decode(id) + data = strings.ReplaceAll(data, spmWhitespaceSep, " ") + if _, err := sb.WriteString(data); err != nil { + return "", err + } + } + + slog.Debug("decoded", "ids", ids, "text", sb.String()) + return sb.String(), nil +} diff --git a/model/process_text_spm_test.go b/model/process_text_spm_test.go new file mode 100644 index 000000000..72bd629ce --- /dev/null +++ b/model/process_text_spm_test.go @@ -0,0 +1,110 @@ +package model + +import ( + "log/slog" + "os" + "path/filepath" + "slices" + "testing" + + "google.golang.org/protobuf/proto" + + "github.com/ollama/ollama/convert/sentencepiece" +) + +func loadSentencePieceVocab(t *testing.T) SentencePieceModel { + t.Helper() + + bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model")) + if err != nil { + t.Fatal(err) + } + + var spm sentencepiece.ModelProto + if err := proto.Unmarshal(bts, &spm); err != nil { + t.Fatal(err) + } + + preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+` + + var v Vocabulary + + for _, piece := range spm.GetPieces() { + v.Values = append(v.Values, piece.GetPiece()) + v.Scores = append(v.Scores, piece.GetScore()) + switch t := piece.GetType(); t { + case sentencepiece.ModelProto_SentencePiece_UNKNOWN, + sentencepiece.ModelProto_SentencePiece_CONTROL, + sentencepiece.ModelProto_SentencePiece_UNUSED, + sentencepiece.ModelProto_SentencePiece_BYTE: + v.Types = append(v.Types, uint32(t)) + default: + tt := uint32(sentencepiece.ModelProto_SentencePiece_NORMAL) + // todo parse the special tokens file + // - this will roundtrip correctly but the and + // tokens aren't processed + v.Types = append(v.Types, tt) + } + } + + return NewSentencePieceModel(preTokenizer, &v) +} + +func TestSentencePieceEncode(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + slog.SetDefault(logger) + + tokenizer := loadSentencePieceVocab(t) + + t.Run("basic roundtrip", func(t *testing.T) { + t.Parallel() + + cases := []string{ + "hello", + "hello ", + "hello ", + " hello", + " hello ", + " hello ", + "hello world", + "请考试我的软件!12345", + "你好", + "Hello 你好 world!", + } + + for _, want := range cases { + ids, err := tokenizer.Encode(want) + if err != nil { + t.Fatal(err) + } + + if got, err := tokenizer.Decode(ids); err != nil { + t.Fatal(err) + } else if got != want { + t.Errorf("got %q, want %q [%#v]", got, want, ids) + } + } + }) + + t.Run("special tokens", func(t *testing.T) { + type candidate struct { + token string + ids []int32 + } + + cases := []candidate{ + {"", []int32{2}}, + {"", []int32{1}}, + } + + for _, want := range cases { + ids, err := tokenizer.Encode(want.token) + if err != nil { + t.Fatal(err) + } + if !slices.Equal(ids, want.ids) { + t.Errorf("got %#v, want %#v", ids, want.ids) + } + } + }) +} diff --git a/model/testdata/gemma2/tokenizer.model b/model/testdata/gemma2/tokenizer.model new file mode 100644 index 000000000..14a242262 Binary files /dev/null and b/model/testdata/gemma2/tokenizer.model differ