add gemma vision encoder

This commit is contained in:
Michael Yang 2025-03-06 12:16:54 -08:00
parent 5f74d1fd47
commit 4b037a97dc
10 changed files with 337 additions and 34 deletions

View File

@ -13,13 +13,13 @@ import (
)
type ModelParameters struct {
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"`
TextModel TextParameters `json:"text_config"`
}
type TextParameters struct {
VocabSize uint32 `json:"vocab_size"`
VocabSize uint32 `json:"vocab_size"`
}
type AdapterParameters struct {

View File

@ -4,8 +4,17 @@ import "github.com/ollama/ollama/fs/ggml"
type gemma3Model struct {
gemmaModel
TextModel gemma3TextModel `json:"text_config"`
VisionModel gemma3VisionModel `json:"vision_config"`
TextModel gemma3TextModel `json:"text_config"`
VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"` // attention.head_count 16
LayerNormEpsilon float32 `json:"layer_norm_eps"` // attention.layer_norm_epsilon 1e-05
NumHiddenLayers uint32 `json:"num_hidden_layers"` // block_count 32
HiddenSize uint32 `json:"hidden_size"` // embedding_length 1280
IntermediateSize uint32 `json:"intermediate_size"` // feed_forward_length 5120
ImageSize uint32 `json:"image_size"` // image_size 560
NumChannels uint32 `json:"num_channels"` // num_channels 3
PatchSize uint32 `json:"patch_size"` // patch_size 14
} `json:"vision_config"`
}
type gemma3TextModel struct {
@ -24,12 +33,6 @@ type gemma3TextModel struct {
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
}
type gemma3VisionModel struct {
ImageSize uint32 `json:"image_size"`
NumChannels uint32 `json:"num_channels"`
HiddenLayers uint32 `json:"num_hidden_layers"`
}
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3"
@ -46,11 +49,18 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv["gemma3.text.final_logit_softcapping"] = p.TextModel.FinalLogitSoftcap
kv["gemma3.text.rope.local.freq_base"] = p.TextModel.RopeLocalTheta
kv["gemma3.text.rope.global.freq_base"] = p.TextModel.RopeGlobalTheta
kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels
kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["gemma3.vision.attention.layer_norm_epsilon"] = p.VisionModel.LayerNormEpsilon
kv["tokenizer.ggml.bos_token_id"] = uint32(2)
kv["tokenizer.ggml.eot_token_id"] = uint32(1)
kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels
kv["gemma3.vision.block_count"] = p.VisionModel.HiddenLayers
return kv
}
@ -59,11 +69,11 @@ func (p *gemma3Model) Replacements() []string {
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"vision_model.vision_model", "v",
"vision_tower.vision_model.embeddings", "v",
"vision_tower.vision_model", "v",
"language_model.", "",
"model.layers", "blk",
"encoder.layers", "blk",
"vision_tower.vision_model.embeddings", "v",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm",
@ -71,11 +81,14 @@ func (p *gemma3Model) Replacements() []string {
"self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"self_attn.out_proj", "attn_output",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"post_attention_layernorm", "post_attention_norm",
"pre_feedforward_layernorm", "ffn_norm",
"post_feedforward_layernorm", "post_ffw_norm",
"input_projection_weight", "input_projection.weight",
"multi_modal_projector", "mm",
}
}

View File

@ -135,7 +135,9 @@ type Tensor interface {
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor
AvgPool1D(ctx Context, k, s, p int) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
Tanh(ctx Context) Tensor

View File

@ -947,6 +947,13 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
}
}
func (t *Tensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_pool_1d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(s), C.int(p)),
}
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
var kqMask *C.struct_ggml_tensor
if mask != nil {

View File

@ -1,10 +1,15 @@
package gemma3
import (
"fmt"
"bytes"
"encoding/binary"
"hash/fnv"
"image"
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
@ -13,19 +18,30 @@ type Model struct {
model.Base
model.SentencePieceModel
//*VisionModel `gguf:"v,vision"`
*VisionModel `gguf:"v,vision"`
*TextModel
//Projector *nn.Linear `gguf:"mm.0"`
*MultiModalProjector `gguf:"mm"`
ImageProcessor
}
var _ model.MultimodalProcessor = (*Model)(nil)
type MultiModalProjector struct {
SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
InputProjection *nn.Linear `gguf:"mm_input_projection"`
}
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
visionOutputs = p.InputProjection.Weight.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mulmat(ctx, visionOutputs)
return visionOutputs
}
func New(c ml.Config) (model.Model, error) {
// Verify unified config
if c.Uint("vision.block_count") == 0 {
return nil, fmt.Errorf("non-unified vision model not supported")
}
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
@ -40,8 +56,8 @@ func New(c ml.Config) (model.Model, error) {
},
),
ImageProcessor: newImageProcessor(c),
//VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
}
slidingWindowLen := int32(c.Uint("text.attention.sliding_window"))
@ -50,7 +66,78 @@ func New(c ml.Config) (model.Model, error) {
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
)
if err != nil {
return nil, err
}
positionIDs, err := ctx.FromIntSlice([]int32{0}, 1)
if err != nil {
return nil, err
}
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, positionIDs)
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
patchesPerImage := m.ImageProcessor.imageSize / m.ImageProcessor.patchSize
kernelSize := patchesPerImage * patchesPerImage / 256
visionOutputs = visionOutputs.AvgPool1D(ctx, kernelSize, kernelSize, 0)
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps)
return visionOutputs, nil
}
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
var images []input.Input
fnvHash := fnv.New64a()
for i := range inputs {
if inputs[i].Multimodal == nil {
if len(images) > 0 {
inputs[i].Multimodal = images[0].Multimodal
inputs[i].MultimodalHash = images[0].MultimodalHash
for j := 1; j < len(images); j++ {
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
inputs[i].MultimodalHash = fnvHash.Sum64()
}
images = nil
}
} else {
images = append(images, inputs[i])
inputs[i].Token = -1
}
}
inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
return inputs, nil
}
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
var embeddings ml.Tensor
if opts.Multimodal != nil {
embeddings = opts.Multimodal[0].Multimodal.(ml.Tensor)
}
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
@ -66,7 +153,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
return nil, err
}
return m.TextModel.Forward(ctx, inputs, positions, outputs, m.Cache), nil
return m.TextModel.Forward(ctx, inputs, positions, embeddings, outputs, m.Cache), nil
}
func init() {

View File

@ -160,9 +160,12 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor {
if embeddings == nil {
embeddings = m.TokenEmbedding.Forward(ctx, inputs)
}
hiddenState := embeddings.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount {
m.TextOptions.largeModelScaling = true

View File

@ -0,0 +1,171 @@
package gemma3
import (
"math"
"slices"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
var batchSize int = 1
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize).Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scores := key.Mulmat(ctx, query)
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
scores = scores.Softmax(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
hiddenState = sa.Output.Forward(ctx, attention)
return hiddenState
}
type VisionMLP struct {
FC1 *nn.Linear `gguf:"fc1"`
FC2 *nn.Linear `gguf:"fc2"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
hiddenState = mlp.FC2.Forward(ctx, hiddenState)
return hiddenState
}
type VisionEncoderLayer struct {
LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
SelfAttention *VisionSelfAttention
LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
MLP *VisionMLP `gguf:"mlp"`
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenState
// self attention
hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
// feed forward
hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
type VisionEncoder struct {
Layers []VisionEncoderLayer
}
func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermediateLayersIndices []uint32, opts *VisionModelOptions) (ml.Tensor, []ml.Tensor) {
var intermediateHiddenStates []ml.Tensor
for i, layer := range e.Layers {
if slices.Contains(intermediateLayersIndices, uint32(i)) {
intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int{1}, hiddenState.Shape()...)...))
}
hiddenState = layer.Forward(ctx, hiddenState, opts)
}
return hiddenState, intermediateHiddenStates
}
type PrecomputedAspectRatioEmbedding struct {
Embedding *nn.Embedding
Gate ml.Tensor `gguf:"gate"`
}
func (e *PrecomputedAspectRatioEmbedding) Forward(ctx ml.Context, hiddenState ml.Tensor, aspectRatioIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
embeddings := e.Embedding.Forward(ctx, aspectRatioIDs)
embeddings = embeddings.Reshape(ctx, opts.hiddenSize, 1, opts.numTiles)
if e.Gate != nil {
embeddings = embeddings.Mul(ctx, e.Gate)
}
return hiddenState.Add(ctx, embeddings)
}
type PrecomputedPositionEmbedding struct {
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
PositionEmbeddingGate ml.Tensor `gguf:"position_embd.gate"`
}
func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, numPositions int, opts *VisionModelOptions) ml.Tensor {
positionEmbedding := e.PositionEmbedding.Forward(ctx, positionIDs)
if e.PositionEmbeddingGate != nil {
positionEmbedding = positionEmbedding.Mul(ctx, e.PositionEmbeddingGate)
}
return hiddenState.Add(ctx, positionEmbedding)
}
type VisionModelOptions struct {
hiddenSize, numHeads, numTiles int
imageSize, patchSize int
eps float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_embedding"`
PositionEmbedding *nn.Embedding `gguf:"position_embedding"`
PostLayerNorm *nn.LayerNorm `gguf:"post_layernorm"`
Encoder *VisionEncoder `gguf:"blk"`
*VisionModelOptions
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs ml.Tensor) ml.Tensor {
numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
positions := m.PositionEmbedding.Forward(ctx, positionIDs)
hiddenState = hiddenState.Add(ctx, positions)
for _, layer := range m.Encoder.Layers {
hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
}
hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
}
func newVisionModel(c ml.Config) *VisionModel {
return &VisionModel{
Encoder: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count"))},
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.attention.head_count")),
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
eps: c.Float("vision.attention.layer_norm_epsilon"),
},
}
}

View File

@ -8,12 +8,13 @@ import (
)
type ImageProcessor struct {
imageSize, numChannels int
imageSize, patchSize, numChannels int
}
func newImageProcessor(c ml.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
numChannels: int(c.Uint("vision.num_channels")),
}
}

View File

@ -144,8 +144,6 @@ func (p *ImageProcessor) splitToTiles(img image.Image, numTilesSize image.Point)
return images
}
// remove the "alpha" channel by drawing over a prefilled image
//
// remove the "alpha" channel by drawing over a prefilled image
//
//nolint:unused

View File

@ -21,6 +21,8 @@ type SentencePieceModel struct {
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePieceModel)(nil)
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
@ -61,7 +63,7 @@ func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
}
}
func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
@ -196,7 +198,26 @@ func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
}
}
}
slog.Debug("encoded", "ids", ids)
if addSpecial && len(ids) > 0 {
if spm.vocab.AddBOS {
if ids[0] == spm.vocab.BOS {
slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
}
slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
ids = append([]int32{spm.vocab.BOS}, ids...)
}
if spm.vocab.AddEOS {
if ids[len(ids)-1] == spm.vocab.EOS {
slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
}
slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
ids = append(ids, spm.vocab.EOS)
}
}
return ids, nil
}