mirror of
https://github.com/ollama/ollama.git
synced 2025-04-12 21:59:22 +02:00
wip
This commit is contained in:
parent
4530661799
commit
cfeca27133
@ -59,10 +59,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
||||
// Create tensor from image data
|
||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
|
||||
m.ImageProcessor.imageSize,
|
||||
|
||||
// TODO (jmorganca): this should be returned from the
|
||||
// image processor instead of hardcoded
|
||||
1036,
|
||||
1036, // TODO (jmorganca): this should be returned from ProcessImage
|
||||
m.ImageProcessor.numChannels,
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package mistral3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
@ -55,11 +56,9 @@ type MultiModalProjector struct {
|
||||
|
||||
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
|
||||
visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps)
|
||||
// fmt.Println("visionOutputs after norm", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
|
||||
visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs)
|
||||
// fmt.Println("visionOutputs after patch merger", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
|
||||
visionOutputs = p.Linear1.Forward(ctx, visionOutputs).GELU(ctx)
|
||||
// fmt.Println("visionOutputs after linear1 and gelu", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs))
|
||||
visionOutputs = p.Linear1.Forward(ctx, visionOutputs)
|
||||
visionOutputs = visionOutputs.GELU(ctx)
|
||||
return p.Linear2.Forward(ctx, visionOutputs)
|
||||
}
|
||||
|
||||
@ -79,40 +78,20 @@ type VisionSelfAttention struct {
|
||||
}
|
||||
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
headDim := opts.headDim
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
|
||||
// fmt.Println("sa.Query", "shape", sa.Query.Weight.Shape(), "data", ml.Dump(ctx, sa.Query.Weight))
|
||||
q = q.Reshape(ctx, opts.headDim, opts.numHeads, q.Dim(1), batchSize)
|
||||
k = k.Reshape(ctx, opts.headDim, opts.numHeads, k.Dim(1), batchSize)
|
||||
v = v.Reshape(ctx, opts.headDim, opts.numHeads, v.Dim(1), batchSize)
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
ropeType := uint32(24) // 2d vision rope
|
||||
q = q.RoPEMulti(ctx, positionIDs, nil, uint32(opts.headDim/2), [4]int{0, opts.headDim / 2, opts.headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
k = k.RoPEMulti(ctx, positionIDs, nil, uint32(opts.headDim/2), [4]int{0, opts.headDim / 2, opts.headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
// fmt.Println("query", "shape", query.Shape(), "data", ml.Dump(ctx, query))
|
||||
// fmt.Println("key", "shape", key.Shape(), "data", ml.Dump(ctx, key))
|
||||
// fmt.Println("value", "shape", value.Shape(), "data", ml.Dump(ctx, value))
|
||||
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
|
||||
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
|
||||
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
|
||||
|
||||
// fmt.Println("query permute", "shape", query.Shape(), "data", ml.Dump(ctx, query))
|
||||
// fmt.Println("key permute", "shape", key.Shape(), "data", ml.Dump(ctx, key))
|
||||
// fmt.Println("value permute", "shape", value.Shape(), "data", ml.Dump(ctx, value))
|
||||
// fmt.Println("positionIDs", "shape", positionIDs.Shape(), "data", ml.Dump(ctx, positionIDs))
|
||||
|
||||
// Multimodal rope
|
||||
ropeType := uint32(24)
|
||||
query = query.RoPEMulti(ctx, positionIDs, nil, uint32(headDim/2), [4]int{0, headDim / 2, headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
key = key.RoPEMulti(ctx, positionIDs, nil, uint32(headDim/2), [4]int{0, headDim / 2, headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
// fmt.Println("query rope", "shape", query.Shape(), "data", ml.Dump(ctx, query))
|
||||
// fmt.Println("key rope", "shape", key.Shape(), "data", ml.Dump(ctx, key))
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
|
||||
// fmt.Println("attention", "shape", attention.Shape(), "data", ml.Dump(ctx, attention))
|
||||
attention := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(opts.headDim)), nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
// fmt.Println("attention reshape", "shape", attention.Shape(), "data", ml.Dump(ctx, attention))
|
||||
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
@ -130,22 +109,19 @@ func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Visio
|
||||
type VisionEncoderLayer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
SelfAttention *VisionSelfAttention
|
||||
|
||||
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP *VisionMLP
|
||||
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP *VisionMLP
|
||||
}
|
||||
|
||||
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
// self attention
|
||||
hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
// fmt.Println("after attention norm", "eps", opts.eps, "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
|
||||
fmt.Println("after attention norm", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
|
||||
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, opts)
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
|
||||
// feed forward
|
||||
hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
|
||||
return hiddenState.Add(ctx, residual)
|
||||
@ -177,24 +153,18 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
||||
numPatchesW := pixelValues.Dim(0) / m.patchSize
|
||||
numPatches := numPatchesH * numPatchesW
|
||||
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
|
||||
// fmt.Println("after patch embedding", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
|
||||
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
|
||||
// fmt.Println("after reshape", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
|
||||
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
// fmt.Println("after permute", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState))
|
||||
|
||||
// TODO: this seems to have incorrect output?
|
||||
hiddenState = m.EncoderNorm.Forward(ctx, hiddenState, m.VisionModelOptions.eps)
|
||||
// fmt.Println("after norm", "eps", m.VisionModelOptions.eps, "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6}))
|
||||
|
||||
// Generate 4D position IDs (time, height, width, extra) for MROPE
|
||||
var positions []int32
|
||||
totalPositions := numPatchesH * numPatchesW
|
||||
positions := make([]int32, totalPositions*4)
|
||||
|
||||
for h := 0; h < numPatchesH; h++ {
|
||||
for w := 0; w < numPatchesW; w++ {
|
||||
positions = append(positions, 0) // unused
|
||||
positions = append(positions, int32(h)) // height
|
||||
positions = append(positions, int32(w)) // width
|
||||
positions = append(positions, 0) // unused
|
||||
index := h*numPatchesW + w
|
||||
positions[totalPositions+index] = int32(h)
|
||||
positions[totalPositions*2+index] = int32(w)
|
||||
}
|
||||
}
|
||||
|
||||
@ -203,8 +173,6 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// fmt.Println("positionIDs", "shape", positionIDs.Shape(), "data", ml.Dump(ctx, positionIDs))
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, m.VisionModelOptions)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user