diff --git a/convert/convert_gemma3.go b/convert/convert_gemma3.go
index f43248dbd..fab5fcd85 100644
--- a/convert/convert_gemma3.go
+++ b/convert/convert_gemma3.go
@@ -1,6 +1,10 @@
 package convert
 
-import "github.com/ollama/ollama/fs/ggml"
+import (
+	"cmp"
+
+	"github.com/ollama/ollama/fs/ggml"
+)
 
 type gemma3Model struct {
 	gemmaModel
@@ -61,9 +65,9 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
 		kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
 		kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
 		kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
-		kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels
+		kv["gemma3.vision.num_channels"] = cmp.Or(p.VisionModel.NumChannels, 3)
 		kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
-		kv["gemma3.vision.attention.layer_norm_epsilon"] = p.VisionModel.LayerNormEpsilon
+		kv["gemma3.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, 1e-6)
 	}
 
 	kv["tokenizer.ggml.bos_token_id"] = uint32(2)
diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go
index f9beccc24..a2e9c7f43 100644
--- a/model/models/gemma3/model.go
+++ b/model/models/gemma3/model.go
@@ -88,13 +88,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 		return nil, err
 	}
 
-	positionIDs, err := ctx.FromIntSlice([]int32{0}, 1)
-	if err != nil {
-		return nil, err
-	}
-
-	visionOutputs := m.VisionModel.Forward(ctx, pixelValues, positionIDs)
-
+	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
 	visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
 	patchesPerImage := m.ImageProcessor.imageSize / m.ImageProcessor.patchSize
 	kernelSize := patchesPerImage * patchesPerImage / 256
diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index ec46ba9a8..9a09bf1fe 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -169,14 +169,14 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, multimodal []input.MultimodalIndex, cache kvcache.Cache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
+
 	if multimodal != nil {
 		visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
 		offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
 		hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(0))
 	}
 
-	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
-
 	if len(m.Layers) == gemma27BLayerCount {
 		m.TextOptions.largeModelScaling = true
 	}
 
diff --git a/model/models/gemma3/model_vision.go b/model/models/gemma3/model_vision.go
index 49f9a5d29..a508f65bd 100644
--- a/model/models/gemma3/model_vision.go
+++ b/model/models/gemma3/model_vision.go
@@ -2,7 +2,6 @@ package gemma3
 
 import (
 	"math"
-	"slices"
 
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
@@ -69,52 +68,6 @@ func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts
 	return hiddenState.Add(ctx, residual)
 }
 
-type VisionEncoder struct {
-	Layers []VisionEncoderLayer
-}
-
-func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermediateLayersIndices []uint32, opts *VisionModelOptions) (ml.Tensor, []ml.Tensor) {
-	var intermediateHiddenStates []ml.Tensor
-	for i, layer := range e.Layers {
-		if slices.Contains(intermediateLayersIndices, uint32(i)) {
-			intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int{1}, hiddenState.Shape()...)...))
-		}
-
-		hiddenState = layer.Forward(ctx, hiddenState, opts)
-	}
-
-	return hiddenState, intermediateHiddenStates
-}
-
-type PrecomputedAspectRatioEmbedding struct {
-	Embedding *nn.Embedding
-	Gate      ml.Tensor `gguf:"gate"`
-}
-
-func (e *PrecomputedAspectRatioEmbedding) Forward(ctx ml.Context, hiddenState ml.Tensor, aspectRatioIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
-	embeddings := e.Embedding.Forward(ctx, aspectRatioIDs)
-	embeddings = embeddings.Reshape(ctx, opts.hiddenSize, 1, opts.numTiles)
-	if e.Gate != nil {
-		embeddings = embeddings.Mul(ctx, e.Gate)
-	}
-
-	return hiddenState.Add(ctx, embeddings)
-}
-
-type PrecomputedPositionEmbedding struct {
-	PositionEmbedding     *nn.Embedding `gguf:"position_embd"`
-	PositionEmbeddingGate ml.Tensor     `gguf:"position_embd.gate"`
-}
-
-func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, numPositions int, opts *VisionModelOptions) ml.Tensor {
-	positionEmbedding := e.PositionEmbedding.Forward(ctx, positionIDs)
-	if e.PositionEmbeddingGate != nil {
-		positionEmbedding = positionEmbedding.Mul(ctx, e.PositionEmbeddingGate)
-	}
-
-	return hiddenState.Add(ctx, positionEmbedding)
-}
-
 type VisionModelOptions struct {
 	hiddenSize, numHeads, numTiles int
 	imageSize, patchSize           int
@@ -126,22 +79,31 @@ type VisionModel struct {
 	PositionEmbedding *nn.Embedding `gguf:"position_embedding"`
 	PostLayerNorm     *nn.LayerNorm `gguf:"post_layernorm"`
 
-	Encoder *VisionEncoder `gguf:"blk"`
+	Layers []VisionEncoderLayer `gguf:"blk"`
 
 	*VisionModelOptions
 }
 
-func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs ml.Tensor) ml.Tensor {
+func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 	numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
 
 	hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
 	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
 	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
 
-	positions := m.PositionEmbedding.Forward(ctx, positionIDs)
-	hiddenState = hiddenState.Add(ctx, positions)
+	positions := make([]int32, numPatches)
+	for i := range positions {
+		positions[i] = int32(i)
+	}
 
-	for _, layer := range m.Encoder.Layers {
+	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
+	if err != nil {
+		panic(err)
+	}
+
+	hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))
+
+	for _, layer := range m.Layers {
 		hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
 	}
 
@@ -151,7 +113,7 @@ func newVisionModel(c ml.Config) *VisionModel {
 	return &VisionModel{
-		Encoder: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count"))},
+		Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
 		VisionModelOptions: &VisionModelOptions{
 			hiddenSize: int(c.Uint("vision.embedding_length")),
 			numHeads:   int(c.Uint("vision.attention.head_count")),
diff --git a/model/models/gemma3/process_image.go b/model/models/gemma3/process_image.go
index a32e60e58..961794044 100644
--- a/model/models/gemma3/process_image.go
+++ b/model/models/gemma3/process_image.go
@@ -23,9 +23,8 @@ func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
 	var pixelVals []float32
 
 	bounds := img.Bounds()
-	var rVals, gVals, bVals []float32
-	for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
-		for x := bounds.Min.X; x < bounds.Max.X; x++ {
+	for x := bounds.Min.X; x < bounds.Max.X; x++ {
+		for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
 			c := img.At(x, y)
 			r, g, b, _ := c.RGBA()
 			rVal := float32(r>>8) / 255.0
@@ -36,14 +35,9 @@ func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
 			gVal = (gVal - mean[1]) / std[1]
 			bVal = (bVal - mean[2]) / std[2]
 
-			rVals = append(rVals, rVal)
-			gVals = append(gVals, gVal)
-			bVals = append(bVals, bVal)
+			pixelVals = append(pixelVals, rVal, gVal, bVal)
 		}
 	}
-	pixelVals = append(pixelVals, rVals...)
-	pixelVals = append(pixelVals, gVals...)
-	pixelVals = append(pixelVals, bVals...)
 
 	return pixelVals
 }
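
Note on the cmp.Or idiom introduced in convert_gemma3.go above: cmp.Or, from the Go 1.22+ standard library, returns the first of its arguments that is not the zero value of the type, which makes it a compact way to fall back to a default when an optional config field was never set. Below is a minimal, self-contained sketch of that behavior; the visionConfig type is a hypothetical stand-in for the parsed model config, not a type from this patch.

package main

import (
	"cmp"
	"fmt"
)

// visionConfig stands in for a parsed model config whose optional
// fields are left at their zero values when absent from the source.
type visionConfig struct {
	NumChannels      uint32
	LayerNormEpsilon float32
}

func main() {
	var cfg visionConfig // nothing set, so both fields are zero

	// cmp.Or returns its first non-zero argument, so unset fields
	// fall back to the defaults used in the patch (3 channels, 1e-6 eps).
	fmt.Println(cmp.Or(cfg.NumChannels, 3))         // 3
	fmt.Println(cmp.Or(cfg.LayerNormEpsilon, 1e-6)) // 1e-06

	cfg.NumChannels = 1
	fmt.Println(cmp.Or(cfg.NumChannels, 3)) // 1: an explicit value wins
}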