From e6b561005ec681ed972836047b119586888859d9 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Mon, 31 Mar 2025 13:42:15 -0700 Subject: [PATCH] fix patch batch --- model/models/mistral3/model.go | 54 ++++++++++++++++++++++++--- model/models/mistral3/model_text.go | 1 - model/models/mistral3/model_vision.go | 42 --------------------- 3 files changed, 49 insertions(+), 48 deletions(-) diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index c2db3335f..2f16e7dac 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -7,6 +7,7 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" ) @@ -41,6 +42,47 @@ func New(c ml.Config) (model.Model, error) { return m, nil } +type PatchMerger struct { + MergingLayer *nn.Linear `gguf:"merging_layer"` +} + +func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point, spatialMergeSize int) ml.Tensor { + d := visionOutputs.Dim(0) + imageGrid := visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Reshape(ctx, size.X, size.Y, d) + kernel := ctx.Input().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d) + patches := kernel.IM2Col(ctx, imageGrid, spatialMergeSize, spatialMergeSize, 0, 0, 1, 1) + reshaped := patches.Reshape(ctx, d*spatialMergeSize*spatialMergeSize, patches.Dim(1)*patches.Dim(2)) + return pm.MergingLayer.Forward(ctx, reshaped) +} + +type MultiModalProjector struct { + Norm *nn.RMSNorm `gguf:"norm"` + Linear1 *nn.Linear `gguf:"linear_1"` + Linear2 *nn.Linear `gguf:"linear_2"` + PatchMerger *PatchMerger `gguf:"patch_merger"` + + spatialMergeSize int + eps float32 + patchSize int +} + +func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point) (ml.Tensor, image.Point) { + visionOutputs = p.Norm.Forward(ctx, visionOutputs, p.eps) + patchSizes := image.Point{size.X / p.patchSize, size.Y / p.patchSize} + visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs, patchSizes, p.spatialMergeSize) + visionOutputs = p.Linear1.Forward(ctx, visionOutputs) + visionOutputs = visionOutputs.GELU(ctx) + return p.Linear2.Forward(ctx, visionOutputs), image.Point{patchSizes.X / p.spatialMergeSize, patchSizes.Y / p.spatialMergeSize} +} + +func newMultiModalProjector(c ml.Config) *MultiModalProjector { + return &MultiModalProjector{ + spatialMergeSize: int(c.Uint("spatial_merge_size", 2)), + eps: c.Float("text_config.rms_norm_eps", 1e-5), + patchSize: int(c.Uint("vision.patch_size", 14)), + } +} + func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { if len(m.VisionModel.Layers) == 0 { return nil, model.ErrNoVisionModel @@ -80,19 +122,21 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er // that can be processed together. func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { var result []input.Input - for _, inp := range inputs { if inp.Multimodal == nil { result = append(result, inp) } else { inputMultimodal := inp.Multimodal.([]ml.Tensor) for i, row := range inputMultimodal { - result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.Dim(1)}) // Image data - result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Dim(1)-1)...) // [IMG] + // [IMG] + result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.Dim(1) + 1}) + result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Dim(1)-1)...) if i == len(inputMultimodal)-1 { - result = append(result, input.Input{Token: 13}) // [IMG_END] + // [IMG_END] + result = append(result, input.Input{Token: 13}) } else { - result = append(result, input.Input{Token: 12}) // [IMG_BREAK] + // [IMG_BREAK] + result = append(result, input.Input{Token: 12}) } } } diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index f280b1340..03b8caa0c 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -111,7 +111,6 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor // image embeddings for _, image := range batch.Multimodal { visionOutputs := image.Multimodal.(ml.Tensor) - // TODO (jmorganca): this fails on metal ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1)))) } diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go index a68d28a33..26120afcc 100644 --- a/model/models/mistral3/model_vision.go +++ b/model/models/mistral3/model_vision.go @@ -1,7 +1,6 @@ package mistral3 import ( - "image" "math" "slices" @@ -11,10 +10,6 @@ import ( var batchSize int = 1 -type PatchMerger struct { - MergingLayer *nn.Linear `gguf:"merging_layer"` -} - func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor { x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)) x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)) @@ -25,43 +20,6 @@ func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Te return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin)) } -func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point, spatialMergeSize int) ml.Tensor { - d := visionOutputs.Dim(0) - imageGrid := visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Reshape(ctx, size.X, size.Y, d) - kernel := ctx.Input().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d) - patches := kernel.IM2Col(ctx, imageGrid, spatialMergeSize, spatialMergeSize, 0, 0, 1, 1) - reshaped := patches.Reshape(ctx, d*spatialMergeSize*spatialMergeSize, patches.Dim(1)*patches.Dim(2)) - return pm.MergingLayer.Forward(ctx, reshaped) -} - -type MultiModalProjector struct { - Norm *nn.RMSNorm `gguf:"norm"` - Linear1 *nn.Linear `gguf:"linear_1"` - Linear2 *nn.Linear `gguf:"linear_2"` - PatchMerger *PatchMerger `gguf:"patch_merger"` - - spatialMergeSize int - eps float32 - patchSize int -} - -func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point) (ml.Tensor, image.Point) { - visionOutputs = p.Norm.Forward(ctx, visionOutputs, p.eps) - patchSizes := image.Point{size.X / p.patchSize, size.Y / p.patchSize} - visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs, patchSizes, p.spatialMergeSize) - visionOutputs = p.Linear1.Forward(ctx, visionOutputs) - visionOutputs = visionOutputs.GELU(ctx) - return p.Linear2.Forward(ctx, visionOutputs), image.Point{patchSizes.X / p.spatialMergeSize, patchSizes.Y / p.spatialMergeSize} -} - -func newMultiModalProjector(c ml.Config) *MultiModalProjector { - return &MultiModalProjector{ - spatialMergeSize: int(c.Uint("spatial_merge_size", 2)), - eps: c.Float("text_config.rms_norm_eps", 1e-5), - patchSize: int(c.Uint("vision.patch_size", 14)), - } -} - type VisionSelfAttention struct { Query *nn.Linear `gguf:"attn_q"` Key *nn.Linear `gguf:"attn_k"`