jmorganca 2025-03-23 01:01:23 -07:00
parent 4530661799
commit cfeca27133
2 changed files with 23 additions and 58 deletions


@@ -59,10 +59,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
// Create tensor from image data
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
// TODO (jmorganca): this should be returned from the
// image processor instead of hardcoded
1036,
1036, // TODO (jmorganca): this should be returned from ProcessImage
m.ImageProcessor.numChannels,
)
if err != nil {
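An aside on the TODO above: one possible way to stop hardcoding 1036 is to have the image processor report the dimensions it produced. This is a hypothetical sketch only; the type and field names below are invented for illustration and are not ollama's actual API.

package main

import "fmt"

// processedImage is a hypothetical result type for the image processor that
// carries the pixel data together with the size the image was resized to.
type processedImage struct {
	pixels        []float32
	width, height int
}

func main() {
	// The model code could then pass img.width and img.height to FromFloatSlice
	// instead of the hardcoded 1036.
	img := processedImage{pixels: make([]float32, 2*3*3), width: 2, height: 3}
	fmt.Println(img.width, img.height, len(img.pixels))
}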


@@ -1,6 +1,7 @@
package mistral3
import (
"fmt"
"math"
"github.com/ollama/ollama/ml"
@@ -55,11 +56,9 @@ type MultiModalProjector struct {
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps)
// fmt.Println("visionOutputs after norm", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs)
// fmt.Println("visionOutputs after patch merger", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
visionOutputs = p.Linear1.Forward(ctx, visionOutputs).GELU(ctx)
// fmt.Println("visionOutputs after linear1 and gelu", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
visionOutputs = p.Linear1.Forward(ctx, visionOutputs)
visionOutputs = visionOutputs.GELU(ctx)
return p.Linear2.Forward(ctx, visionOutputs)
}
@@ -79,40 +78,20 @@ type VisionSelfAttention struct {
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
headDim := opts.headDim
q := sa.Query.Forward(ctx, hiddenState)
k := sa.Key.Forward(ctx, hiddenState)
v := sa.Value.Forward(ctx, hiddenState)
// fmt.Println("sa.Query", "shape", sa.Query.Weight.Shape(), "data", ml.Dump(ctx, sa.Query.Weight))
q = q.Reshape(ctx, opts.headDim, opts.numHeads, q.Dim(1), batchSize)
k = k.Reshape(ctx, opts.headDim, opts.numHeads, k.Dim(1), batchSize)
v = v.Reshape(ctx, opts.headDim, opts.numHeads, v.Dim(1), batchSize)
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
ropeType := uint32(24) // 2d vision rope
q = q.RoPEMulti(ctx, positionIDs, nil, uint32(opts.headDim/2), [4]int{0, opts.headDim / 2, opts.headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
k = k.RoPEMulti(ctx, positionIDs, nil, uint32(opts.headDim/2), [4]int{0, opts.headDim / 2, opts.headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
// fmt.Println("query", "shape", query.Shape(), "data", ml.Dump(ctx, query))
// fmt.Println("key", "shape", key.Shape(), "data", ml.Dump(ctx, key))
// fmt.Println("value", "shape", value.Shape(), "data", ml.Dump(ctx, value))
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
// fmt.Println("query permute", "shape", query.Shape(), "data", ml.Dump(ctx, query))
// fmt.Println("key permute", "shape", key.Shape(), "data", ml.Dump(ctx, key))
// fmt.Println("value permute", "shape", value.Shape(), "data", ml.Dump(ctx, value))
// fmt.Println("positionIDs", "shape", positionIDs.Shape(), "data", ml.Dump(ctx, positionIDs))
// Multimodal rope
ropeType := uint32(24)
query = query.RoPEMulti(ctx, positionIDs, nil, uint32(headDim/2), [4]int{0, headDim / 2, headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
key = key.RoPEMulti(ctx, positionIDs, nil, uint32(headDim/2), [4]int{0, headDim / 2, headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
// fmt.Println("query rope", "shape", query.Shape(), "data", ml.Dump(ctx, query))
// fmt.Println("key rope", "shape", key.Shape(), "data", ml.Dump(ctx, key))
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
// fmt.Println("attention", "shape", attention.Shape(), "data", ml.Dump(ctx, attention))
attention := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(opts.headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
// fmt.Println("attention reshape", "shape", attention.Shape(), "data", ml.Dump(ctx, attention))
return sa.Output.Forward(ctx, attention)
}
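A hedged note on the RoPE calls above: ropeType 24 selects the 2D vision rope, and the four section sizes appear to divide each head's rotary dimensions among the four position channels, so half of the dimensions follow the height position and half follow the width position, with the time and extra channels unused. A minimal sketch of that split, assuming headDim = 64 (an example value, not taken from this diff):

package main

import "fmt"

func main() {
	headDim := 64 // assumed example; the real value comes from the model options
	sections := [4]int{0, headDim / 2, headDim / 2, 0}
	channels := []string{"time", "height", "width", "extra"}
	for i, n := range sections {
		// With these sections, 32 rotary dims are driven by the height position
		// and 32 by the width position; time and extra get none.
		fmt.Printf("%-6s -> %d rotary dims\n", channels[i], n)
	}
}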
@@ -130,22 +109,19 @@ func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
type VisionEncoderLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *VisionSelfAttention
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *VisionMLP
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *VisionMLP
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenState
// self attention
hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
// fmt.Println("after attention norm", "eps", opts.eps, "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
fmt.Println("after attention norm", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, opts)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
// feed forward
hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
@@ -177,24 +153,18 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
numPatchesW := pixelValues.Dim(0) / m.patchSize
numPatches := numPatchesH * numPatchesW
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
// fmt.Println("after patch embedding", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
// fmt.Println("after reshape", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// fmt.Println("after permute", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
// TODO: this seems to have incorrect output?
hiddenState = m.EncoderNorm.Forward(ctx, hiddenState, m.VisionModelOptions.eps)
// fmt.Println("after norm", "eps", m.VisionModelOptions.eps, "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
// Generate 4D position IDs (time, height, width, extra) for MROPE
var positions []int32
totalPositions := numPatchesH * numPatchesW
positions := make([]int32, totalPositions*4)
for h := 0; h < numPatchesH; h++ {
for w := 0; w < numPatchesW; w++ {
positions = append(positions, 0) // unused
positions = append(positions, int32(h)) // height
positions = append(positions, int32(w)) // width
positions = append(positions, 0) // unused
index := h*numPatchesW + w
positions[totalPositions+index] = int32(h)
positions[totalPositions*2+index] = int32(w)
}
}
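For clarity, a standalone sketch (not part of the diff) of the layout the indexed loop above produces: rather than appending (time, height, width, extra) per patch, each channel is written into its own contiguous section of length numPatchesH*numPatchesW. For a 2x2 patch grid:

package main

import "fmt"

func main() {
	numPatchesH, numPatchesW := 2, 2
	total := numPatchesH * numPatchesW
	positions := make([]int32, total*4) // time | height | width | extra sections
	for h := 0; h < numPatchesH; h++ {
		for w := 0; w < numPatchesW; w++ {
			i := h*numPatchesW + w
			positions[total+i] = int32(h)   // height section
			positions[total*2+i] = int32(w) // width section
		}
	}
	// Prints [0 0 0 0 0 0 1 1 0 1 0 1 0 0 0 0]: four zeros (time), then row
	// indices, then column indices, then four zeros (extra).
	fmt.Println(positions)
}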
@@ -203,8 +173,6 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
panic(err)
}
// fmt.Println("positionIDs", "shape", positionIDs.Shape(), "data", ml.Dump(ctx, positionIDs))
for _, layer := range m.Layers {
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, m.VisionModelOptions)
}