From 8dd2a81f8c82c7a8c2f0198007c284aa7b0bb69a Mon Sep 17 00:00:00 2001 From: jmorganca Date: Sat, 22 Mar 2025 22:33:39 -0700 Subject: [PATCH] wip --- convert/convert_mistral.go | 67 +++------ ml/backend.go | 3 + ml/backend/ggml/ggml.go | 35 +++++ .../ggml/ggml/src/ggml-metal/ggml-metal.m | 4 + model/models/mistral3/imageproc.go | 8 +- model/models/mistral3/model.go | 25 +++- model/models/mistral3/model_vision.go | 139 +++++++++++++++--- model/models/mistral3/multimodal_proj.go | 38 ----- 8 files changed, 195 insertions(+), 124 deletions(-) delete mode 100644 model/models/mistral3/multimodal_proj.go diff --git a/convert/convert_mistral.go b/convert/convert_mistral.go index 57e3d4ba4..ba7fa3e93 100644 --- a/convert/convert_mistral.go +++ b/convert/convert_mistral.go @@ -116,13 +116,16 @@ func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor { func (p *mistral3Model) Replacements() []string { return []string{ - // Text model replacements - "model.layers", "blk", + "language_model.model.norm", "output_norm", + "language_model.model.", "", + "language_model.", "", + "layers", "blk", + "transformer.layers", "blk", + "vision_tower", "v", + "ln_pre", "encoder_norm", "input_layernorm", "attn_norm", "post_attention_layernorm", "ffn_norm", - "lm_head", "output", - "model.embed_tokens.weight", "token_embd.weight", - "model.norm.weight", "output_norm.weight", + "embed_tokens", "token_embd", "self_attn.q_proj", "attn_q", "self_attn.k_proj", "attn_k", "self_attn.v_proj", "attn_v", @@ -130,50 +133,18 @@ func (p *mistral3Model) Replacements() []string { "mlp.down_proj", "ffn_down", "mlp.gate_proj", "ffn_gate", "mlp.up_proj", "ffn_up", - - // Language model replacements - "language_model.model.embed_tokens", "token_embd", - "language_model.model.layers", "blk", - "language_model.model.layers.*.input_layernorm", "attn_norm", - "language_model.model.layers.*.self_attn.q_proj", "attn_q", - "language_model.model.layers.*.self_attn.k_proj", "attn_k", - "language_model.model.layers.*.self_attn.v_proj", "attn_v", - "language_model.model.layers.*.self_attn.o_proj", "attn_output", - "language_model.model.layers.*.mlp.gate_proj", "ffn_gate", - "language_model.model.layers.*.mlp.down_proj", "ffn_down", - "language_model.model.layers.*.mlp.up_proj", "ffn_up", - "language_model.model.layers.*.post_attention_layernorm", "ffn_norm", - "language_model.lm_head", "output", - "language_model.model.norm", "output_norm", - - // Vision model replacements - map to shorter prefixes - "vision_tower", "v", + "attention.q_proj", "attn_q", + "attention.k_proj", "attn_k", + "attention.v_proj", "attn_v", + "attention.o_proj", "attn_output", + "attention_norm", "attn_norm", + "feed_forward", "mlp", + "feed_forward.gate_proj", "ffn_gate", + "feed_forward.down_proj", "ffn_down", + "feed_forward.up_proj", "ffn_up", "multi_modal_projector", "mm", - - // Vision transformer blocks - these should be updated accordingly - "vision_tower.transformer.layers", "v.blk", - "vision_tower.transformer.layers.*.attention_norm", "v.attn_norm", - "vision_tower.transformer.layers.*.attention.q_proj", "v.attn_q", - "vision_tower.transformer.layers.*.attention.k_proj", "v.attn_k", - "vision_tower.transformer.layers.*.attention.v_proj", "v.attn_v", - "vision_tower.transformer.layers.*.attention.o_proj", "v.attn_output", - "vision_tower.transformer.layers.*.feed_forward.gate_proj", "v.ffn_gate", - "vision_tower.transformer.layers.*.feed_forward.down_proj", "v.ffn_down", - "vision_tower.transformer.layers.*.feed_forward.up_proj", "v.ffn_up", - 
"vision_tower.transformer.layers.*.ffn_norm", "v.ffn_norm", - "vision_tower.ln_pre", "v.encoder_norm", - "vision_tower.patch_conv", "v.patch_conv", - "vision_tower.embeddings", "v.embeddings", - - // Alternative vision model paths - "vision_model.vision_model.embeddings", "v.embeddings", - "vision_model.vision_model", "v", - "vision_model.layers", "v.blk", - - // Multimodal projector components - "multi_modal_projector.patch_merger", "mm.patch_merger", - "multi_modal_projector.norm", "mm.norm", - "multi_modal_projector.linear", "mm.projection", + "ffn_norm", "ffn_norm", + "lm_head", "output", } } diff --git a/ml/backend.go b/ml/backend.go index 354faf432..31670ee07 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -144,6 +144,9 @@ type Tensor interface { Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor + RoPEMulti(ctx Context, positionIDs, ropeFactors Tensor, ropeDim uint32, sections [4]int, ropeType uint32, base, scale float32) Tensor + + IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor Tanh(ctx Context) Tensor GELU(ctx Context) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index f6b017748..3558c7ad3 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -958,6 +958,41 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi } } +func (t *Tensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, sections [4]int, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor { + if ropeFactors == nil { + ropeFactors = &Tensor{b: t.b} + } + + dequant := t.t + if C.ggml_is_quantized(t.t._type) { + dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32) + } + + return &Tensor{ + b: t.b, + t: C.ggml_rope_multi( + ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t, + C.int(ropeDim), + (*C.int)(unsafe.Pointer(§ions[0])), + C.int(ropeType), + 131072, // YaRN n_ctx_train + C.float(ropeBase), + C.float(ropeScale), + 0., // YaRN ext_factor + 1., // YaRN attn_factor + 32., // YaRN beta_fast + 1., // YaRN beta_slow + ), + } +} + +func (t *Tensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_im2col(ctx.(*Context).ctx, t.t, weight.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1), true, C.GGML_TYPE_F32), + } +} + func (t *Tensor) GELU(ctx ml.Context) ml.Tensor { return &Tensor{ b: t.b, diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m index e4c093f9c..d68aab6cd 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m @@ -2186,6 +2186,10 @@ static void ggml_metal_encode_node( } break; case GGML_OP_MUL_MAT: { + if (ne00 != ne10) { + printf("mul_mat, ne00: %d, ne01: %d, ne02: %d, ne03: %d, ne10: %d, ne11: %d, ne12: %d, ne13: %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13); + } + GGML_ASSERT(ne00 == ne10); GGML_ASSERT(ne12 % ne02 == 0); diff --git a/model/models/mistral3/imageproc.go b/model/models/mistral3/imageproc.go index 2caa54091..907c580d0 100644 --- a/model/models/mistral3/imageproc.go +++ b/model/models/mistral3/imageproc.go @@ -21,8 +21,7 @@ func getNumImageTokens(imageSize, patchSize image.Point) image.Point { func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.Point) image.Point 
{ b := img.Bounds() - le := float64(longestEdge) - ratio := math.Max(float64(b.Max.Y)/le, float64(b.Max.X)/le) + ratio := math.Max(float64(b.Max.Y)/float64(longestEdge), float64(b.Max.X)/float64(longestEdge)) newSize := img.Bounds().Max @@ -80,17 +79,14 @@ func newImageProcessor(c ml.Config) ImageProcessor { imageSize: int(c.Uint("vision.image_size", 1540)), patchSize: int(c.Uint("vision.patch_size", 14)), numChannels: int(c.Uint("vision.num_channels", 3)), - longestEdge: int(c.Uint("vision.longest_edge", 1024)), + longestEdge: int(c.Uint("vision.longest_edge", 1540)), } } func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, error) { outputSize := getResizeOutputImageSize(img, p.longestEdge, image.Point{p.patchSize, p.patchSize}) - newImage := imageproc.Composite(img) newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear) - data := imageproc.Normalize(newImage, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true) - return data, nil } diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index 80e8f381e..b9ac2ab21 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -2,6 +2,7 @@ package mistral3 import ( "bytes" + "fmt" "image" "slices" @@ -59,19 +60,28 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er // Create tensor from image data pixelValues, err := ctx.Input().FromFloatSlice(f32s, m.ImageProcessor.imageSize, - m.ImageProcessor.imageSize, + + // TODO (jmorganca): this should be returned from the + // image processor instead of hardcoded + 1036, m.ImageProcessor.numChannels, ) if err != nil { return nil, err } + fmt.Println("pixelValues", "shape", pixelValues.Shape(), "data", ml.Dump(ctx, pixelValues)) + // Forward pass through vision model visionOutputs := m.VisionModel.Forward(ctx, pixelValues) + // fmt.Println("visionOutputs", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs)) + // Project to text embedding space visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps) + // fmt.Println("visionOutputs after projector", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs)) + return visionOutputs, nil } @@ -85,16 +95,15 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { inputMultimodal := inp.Multimodal.(ml.Tensor) // Add special image tokens - using the imageTokenIndex from config - result = append(result, - input.Input{Token: int32(m.MultiModalProjector.imageTokenIndex)}, // Image token - input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // Image data - ) - - // Add image token placeholders - result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...) + result = append(result, input.Input{Token: 10}) // [IMG] + result = append(result, input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}) // image data + result = append(result, slices.Repeat([]input.Input{{Token: 10}}, inputMultimodal.Dim(1)-1)...) 
// [IMG] placeholders + result = append(result, input.Input{Token: 13}) // [IMG_END] } } + fmt.Println("post tokenize", "result", result) + return result, nil } diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go index 2826efe81..7f19aeb2c 100644 --- a/model/models/mistral3/model_vision.go +++ b/model/models/mistral3/model_vision.go @@ -1,6 +1,7 @@ package mistral3 import ( + "fmt" "math" "github.com/ollama/ollama/ml" @@ -9,31 +10,109 @@ import ( var batchSize int = 1 +type PatchMerger struct { + MergingLayer *nn.Linear `gguf:"merging_layer"` +} + +func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor { + // TODO: pass these in + w := 110 + h := 74 + // tokensPerImage := w * h + d := visionOutputs.Dim(0) + + // TODO: handle multiple images, this currently assumes one + fmt.Println("patchmerger visionOutputs", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs)) + + // Reshape to [h, w, hidden_size] + imageGrid := visionOutputs.Reshape(ctx, h, w, d) + fmt.Println("imageGrid", "shape", imageGrid.Shape(), "data", ml.Dump(ctx, imageGrid)) + + // TODO: load from ml.Config + spatialMergeSize := 2 + kernel := ctx.Output().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d, 1) + fmt.Println("kernel", "shape", kernel.Shape(), "data", ml.Dump(ctx, kernel)) + + patches := kernel.IM2Col(ctx, imageGrid, spatialMergeSize, spatialMergeSize, 0, 0, 1, 1) + fmt.Println("patches", "shape", patches.Shape(), "data", ml.Dump(ctx, patches)) + + fmt.Println("creating reshaped", d*spatialMergeSize*spatialMergeSize, "x", patches.Dim(1)*patches.Dim(2)) + reshaped := patches.Reshape(ctx, d*spatialMergeSize*spatialMergeSize, patches.Dim(1)*patches.Dim(2)) + fmt.Println("reshaped", "shape", reshaped.Shape(), "data", ml.Dump(ctx, reshaped)) + + return pm.MergingLayer.Forward(ctx, reshaped) +} + +type MultiModalProjector struct { + Norm *nn.RMSNorm `gguf:"norm"` + Linear1 *nn.Linear `gguf:"linear_1"` + Linear2 *nn.Linear `gguf:"linear_2"` + PatchMerger *PatchMerger `gguf:"patch_merger"` + + spatialMergeSize int + imageTokenIndex int + hasBias bool +} + +func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor { + visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps) + fmt.Println("visionOutputs after norm", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs)) + visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs) + fmt.Println("visionOutputs after patch merger", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs)) + visionOutputs = p.Linear1.Forward(ctx, visionOutputs).GELU(ctx) + fmt.Println("visionOutputs after linear1 and gelu", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs)) + return p.Linear2.Forward(ctx, visionOutputs) +} + +func newMultiModalProjector(c ml.Config) *MultiModalProjector { + return &MultiModalProjector{ + spatialMergeSize: int(c.Uint("spatial_merge_size", 2)), + imageTokenIndex: int(c.Uint("image_token_index", 10)), + hasBias: c.Bool("mm.projector_bias", false), + } +} + type VisionSelfAttention struct { - Query *nn.Linear `gguf:"attn_q"` - Key *nn.Linear `gguf:"attn_k"` - Value *nn.Linear `gguf:"attn_v"` - Output *nn.Linear `gguf:"attn_output"` - RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` } func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, 
positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor { headDim := opts.headDim + // fmt.Println("sa.Query", "shape", sa.Query.Weight.Shape(), "data", ml.Dump(ctx, sa.Query.Weight)) + query := sa.Query.Forward(ctx, hiddenState) key := sa.Key.Forward(ctx, hiddenState) value := sa.Value.Forward(ctx, hiddenState) - query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) - key = key.Reshape(ctx, headDim, opts.numHeads, batchSize) - value = value.Reshape(ctx, headDim, opts.numHeads, batchSize) + // fmt.Println("query", "shape", query.Shape(), "data", ml.Dump(ctx, query)) + // fmt.Println("key", "shape", key.Shape(), "data", ml.Dump(ctx, key)) + // fmt.Println("value", "shape", value.Shape(), "data", ml.Dump(ctx, value)) - ropeType := uint32(0) - query = query.RoPE(ctx, positionIDs, sa.RopeFactors, uint32(headDim), ropeType, opts.ropeBase, opts.ropeScale) - key = key.RoPE(ctx, positionIDs, sa.RopeFactors, uint32(headDim), ropeType, opts.ropeBase, opts.ropeScale) + query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize) + key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize) + value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize) + + // fmt.Println("query permute", "shape", query.Shape(), "data", ml.Dump(ctx, query)) + // fmt.Println("key permute", "shape", key.Shape(), "data", ml.Dump(ctx, key)) + // fmt.Println("value permute", "shape", value.Shape(), "data", ml.Dump(ctx, value)) + // fmt.Println("positionIDs", "shape", positionIDs.Shape(), "data", ml.Dump(ctx, positionIDs)) + + // Multimodal rope + ropeType := uint32(24) + query = query.RoPEMulti(ctx, positionIDs, nil, uint32(headDim/2), [4]int{0, headDim / 2, headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale) + key = key.RoPEMulti(ctx, positionIDs, nil, uint32(headDim/2), [4]int{0, headDim / 2, headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale) + + // fmt.Println("query rope", "shape", query.Shape(), "data", ml.Dump(ctx, query)) + // fmt.Println("key rope", "shape", key.Shape(), "data", ml.Dump(ctx, key)) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil) + // fmt.Println("attention", "shape", attention.Shape(), "data", ml.Dump(ctx, attention)) attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) + // fmt.Println("attention reshape", "shape", attention.Shape(), "data", ml.Dump(ctx, attention)) return sa.Output.Forward(ctx, attention) } @@ -54,7 +133,7 @@ type VisionEncoderLayer struct { SelfAttention *VisionSelfAttention FFNNorm *nn.RMSNorm `gguf:"ffn_norm"` - MLP *VisionMLP `gguf:"mlp"` + MLP *VisionMLP } func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor { @@ -62,6 +141,7 @@ func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml // self attention hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps) + // fmt.Println("after attention norm", "eps", opts.eps, "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6})) hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, opts) hiddenState = hiddenState.Add(ctx, residual) residual = hiddenState @@ -87,25 +167,36 @@ type VisionModelOptions struct { type VisionModel struct { PatchEmbedding *nn.Conv2D `gguf:"patch_conv"` - EncoderNorm *nn.LayerNorm `gguf:"encoder_norm"` + EncoderNorm *nn.RMSNorm `gguf:"encoder_norm"` Layers []VisionEncoderLayer `gguf:"blk"` *VisionModelOptions } func 
(m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { - numPatchesH := m.imageSize / m.patchSize - numPatchesW := m.imageSize / m.patchSize + numPatchesH := pixelValues.Dim(1) / m.patchSize + numPatchesW := pixelValues.Dim(0) / m.patchSize numPatches := numPatchesH * numPatchesW - hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1) + // fmt.Println("after patch embedding", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState)) hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize) + // fmt.Println("after reshape", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState)) hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + // fmt.Println("after permute", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState)) - // Create position IDs - positions := make([]int32, numPatches) - for i := range positions { - positions[i] = int32(i) + // TODO: this seems to have incorrect output? + hiddenState = m.EncoderNorm.Forward(ctx, hiddenState, m.VisionModelOptions.eps) + // fmt.Println("after norm", "eps", m.VisionModelOptions.eps, "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6})) + + // Generate 4D position IDs (time, height, width, extra) for MROPE + var positions []int32 + for h := 0; h < numPatchesH; h++ { + for w := 0; w < numPatchesW; w++ { + positions = append(positions, 0) // unused + positions = append(positions, int32(h)) // height + positions = append(positions, int32(w)) // width + positions = append(positions, 0) // unused + } } positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions)) @@ -113,14 +204,14 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { panic(err) } - // Apply encoder normalization - hiddenState = m.EncoderNorm.Forward(ctx, hiddenState, m.eps) + // fmt.Println("positionIDs", "shape", positionIDs.Shape(), "data", ml.Dump(ctx, positionIDs)) - // Process through transformer layers for _, layer := range m.Layers { hiddenState = layer.Forward(ctx, hiddenState, positionIDs, m.VisionModelOptions) } + // fmt.Println("after layers", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState)) + return hiddenState } @@ -135,7 +226,7 @@ func newVisionModel(c ml.Config) *VisionModel { imageSize: int(c.Uint("vision.image_size", 1540)), patchSize: int(c.Uint("vision.patch_size", 14)), numChannels: int(c.Uint("vision.num_channels", 3)), - eps: c.Float("vision.attention.layer_norm_epsilon", 1e-05), + eps: c.Float("vision.attention.layer_norm_epsilon", 1e-5), ropeBase: c.Float("vision.rope.freq_base", 10000.0), ropeScale: c.Float("vision.rope.freq_scale", 1.0), }, diff --git a/model/models/mistral3/multimodal_proj.go b/model/models/mistral3/multimodal_proj.go deleted file mode 100644 index 7de40abd7..000000000 --- a/model/models/mistral3/multimodal_proj.go +++ /dev/null @@ -1,38 +0,0 @@ -package mistral3 - -import ( - "github.com/ollama/ollama/ml" - "github.com/ollama/ollama/ml/nn" -) - -type MultiModalProjector struct { - Norm *nn.RMSNorm `gguf:"norm"` - Projection *nn.Linear `gguf:"projection"` - - spatialMergeSize int - imageTokenIndex int - hasBias bool -} - -func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor { - // Apply normalization - visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps) - - // If the spatial merge size is > 1, average pool the patches - if p.spatialMergeSize > 1 { - // Implementation depends 
on how the model handles spatial merging - // For simplicity, we'll use a spatial pooling approach - visionOutputs = visionOutputs.AvgPool2D(ctx, p.spatialMergeSize, p.spatialMergeSize, 0) - } - - // Project to text embedding dimension - return p.Projection.Forward(ctx, visionOutputs) -} - -func newMultiModalProjector(c ml.Config) *MultiModalProjector { - return &MultiModalProjector{ - spatialMergeSize: int(c.Uint("spatial_merge_size", 2)), - imageTokenIndex: int(c.Uint("image_token_index", 10)), - hasBias: c.Bool("mm.projector_bias", false), - } -}
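A note on the magic numbers above: the 1036 hardcoded in EncodeMultimodal and the w = 110 / h = 74 hardcoded in PatchMerger.Forward all follow from the 1540-pixel longest edge and the 14-pixel patch size. A minimal, standalone sketch of that bookkeeping (illustrative names only; this is not code from the patch):

package main

import "fmt"

func main() {
	const (
		patchSize        = 14 // vision.patch_size
		spatialMergeSize = 2  // spatial_merge_size
	)

	// Example input the TODOs reference: a 1540x1036 image, i.e. both
	// edges already multiples of the patch size after resizing.
	width, height := 1540, 1036

	patchesW := width / patchSize  // 110, the hardcoded w in PatchMerger
	patchesH := height / patchSize // 74, the hardcoded h in PatchMerger
	fmt.Println("patch grid:", patchesW, "x", patchesH, "=", patchesW*patchesH, "vision tokens")

	// IM2Col with a spatialMergeSize x spatialMergeSize kernel and matching
	// stride gathers each 2x2 block of patch embeddings into one column, so
	// the merging_layer sees 4*hiddenSize features per output token.
	mergedW := patchesW / spatialMergeSize
	mergedH := patchesH / spatialMergeSize
	fmt.Println("after merge:", mergedW, "x", mergedH, "=", mergedW*mergedH, "projected tokens")
}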
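The vision self-attention switches to ggml's multi-section rope (ropeType 24) with sections {0, headDim/2, headDim/2, 0}: nothing is allocated to the time or extra sections, and the rotary dimensions are split evenly between the patch row and column. A sketch of the position layout VisionModel.Forward builds (the helper name is illustrative, not part of the patch):

// buildVisionPositions mirrors the loop added to VisionModel.Forward: four
// position components per patch, in the order (time, height, width, extra).
// Only the height and width slots are rotated, given the
// {0, headDim / 2, headDim / 2, 0} sections passed to RoPEMulti.
func buildVisionPositions(numPatchesH, numPatchesW int) []int32 {
	positions := make([]int32, 0, 4*numPatchesH*numPatchesW)
	for h := 0; h < numPatchesH; h++ {
		for w := 0; w < numPatchesW; w++ {
			positions = append(positions, 0)        // time: unused
			positions = append(positions, int32(h)) // patch row
			positions = append(positions, int32(w)) // patch column
			positions = append(positions, 0)        // extra: unused
		}
	}
	return positions
}

One thing worth double-checking against the "incorrect output" TODO: if ggml_rope_multi expects the four components as four consecutive blocks of n_tokens (all time values, then all heights, then all widths, then all extras), as in llama.cpp's Qwen2-VL path, this interleaved-per-token layout would need to be regrouped before being passed to RoPEMulti.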