commit 557c641697
parent 863ba57477

    2d rope
@@ -130,6 +130,7 @@ type Tensor interface {
 	Bytes() []byte
 	Floats() []float32
 
+	Neg(ctx Context) Tensor
 	Add(ctx Context, t2 Tensor) Tensor
 	Mul(ctx Context, t2 Tensor) Tensor
 	Mulmat(ctx Context, t2 Tensor) Tensor
@@ -148,6 +149,8 @@ type Tensor interface {
 
 	IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
 
+	Sin(ctx Context) Tensor
+	Cos(ctx Context) Tensor
 	Tanh(ctx Context) Tensor
 	GELU(ctx Context) Tensor
 	SILU(ctx Context) Tensor
@@ -727,6 +727,13 @@ func (t *Tensor) DType() ml.DType {
 	}
 }
 
+func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_neg(ctx.(*Context).ctx, t.t),
+	}
+}
+
 func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return &Tensor{
 		b: t.b,
@@ -870,6 +877,20 @@ func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
 	}
 }
 
+func (t *Tensor) Sin(ctx ml.Context) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_sin(ctx.(*Context).ctx, t.t),
+	}
+}
+
+func (t *Tensor) Cos(ctx ml.Context) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_cos(ctx.(*Context).ctx, t.t),
+	}
+}
+
 func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
 	return &Tensor{
 		b: t.b,
@@ -3,6 +3,7 @@ package mistral3
 import (
 	"image"
 	"math"
+	"slices"
 
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
@@ -14,6 +15,18 @@ type PatchMerger struct {
 	MergingLayer *nn.Linear `gguf:"merging_layer"`
 }
 
+func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
+	x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
+	x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
+	return x2.Neg(ctx).Concat(ctx, x1, 0)
+}
+
+func applyRotaryPositionalEmbedding(ctx ml.Context, query, key, cos, sin ml.Tensor) (ml.Tensor, ml.Tensor) {
+	query = query.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, query).Mul(ctx, sin))
+	key = key.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, key).Mul(ctx, sin))
+	return query, key
+}
+
 func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point, spatialMergeSize int) ml.Tensor {
 	d := visionOutputs.Dim(0)
 	imageGrid := visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Reshape(ctx, size.X, size.Y, d)
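The two helpers above implement the standard rotate-half form of rotary embeddings: a vector is rotated as x*cos + rotateHalf(x)*sin, where rotateHalf([x1, x2]) = [-x2, x1] over the two halves of the head dimension, so element i is paired with element i + headDim/2 and both share the same angle. Below is a small self-contained sketch of the same arithmetic on plain slices (an illustration only, not code from this commit; the real functions operate on ggml tensor views):

package main

import (
	"fmt"
	"math"
)

// rotateHalf returns [-x2, x1] for x = [x1, x2], mirroring the tensor version above.
func rotateHalf(x []float64) []float64 {
	half := len(x) / 2
	out := make([]float64, len(x))
	for i := 0; i < half; i++ {
		out[i] = -x[half+i]
		out[half+i] = x[i]
	}
	return out
}

// applyRotary rotates x by the given angles: x*cos(a) + rotateHalf(x)*sin(a).
func applyRotary(x, angles []float64) []float64 {
	r := rotateHalf(x)
	out := make([]float64, len(x))
	for i := range x {
		out[i] = x[i]*math.Cos(angles[i]) + r[i]*math.Sin(angles[i])
	}
	return out
}

func main() {
	x := []float64{1, 0, 0, 1}
	angles := []float64{math.Pi / 2, 0, math.Pi / 2, 0} // each angle repeated for both halves
	fmt.Println(applyRotary(x, angles))
}

Expressing the rotation with Mul/Add/Neg/Concat graph ops, rather than a fused RoPE kernel, means the angle table can be arbitrary; the 2D table built later in this diff relies on exactly that.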
@@ -58,20 +71,26 @@ type VisionSelfAttention struct {
 	Output *nn.Linear `gguf:"attn_output"`
 }
 
-func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
-	q := sa.Query.Forward(ctx, hiddenState)
-	k := sa.Key.Forward(ctx, hiddenState)
-	v := sa.Value.Forward(ctx, hiddenState)
+func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+	query := sa.Query.Forward(ctx, hiddenStates)
+	key := sa.Key.Forward(ctx, hiddenStates)
+	value := sa.Value.Forward(ctx, hiddenStates)
 
-	q = q.Reshape(ctx, opts.headDim, opts.numHeads, q.Dim(1), batchSize)
-	k = k.Reshape(ctx, opts.headDim, opts.numHeads, k.Dim(1), batchSize)
-	v = v.Reshape(ctx, opts.headDim, opts.numHeads, v.Dim(1), batchSize)
+	query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
+	key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
+	value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
 
-	ropeType := uint32(24) // 2d vision rope
-	q = q.RoPEMulti(ctx, positionIDs, nil, uint32(opts.headDim/2), [4]int{0, opts.headDim / 2, opts.headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
-	k = k.RoPEMulti(ctx, positionIDs, nil, uint32(opts.headDim/2), [4]int{0, opts.headDim / 2, opts.headDim / 2, 0}, ropeType, opts.ropeBase, opts.ropeScale)
+	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
-	attention := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(opts.headDim)), nil)
+	query, key = applyRotaryPositionalEmbedding(ctx, query, key, cos, sin)
+
+	scores := key.Mulmat(ctx, query)
+	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(opts.headDim)))
+	scores = scores.Softmax(ctx)
+
+	attention := value.Mulmat(ctx, scores).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
 	return sa.Output.Forward(ctx, attention)
 }
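Alongside the rotary change, the fused nn.Attention call is replaced by an explicit pipeline: permute so each head gets its own axis, rotate query and key with the precomputed cos/sin, compute scaled scores, softmax them, and weight the values. A rough single-head, row-major sketch of that scaled dot-product attention on plain slices follows (illustration only; the real code works on column-major ggml tensors, so the Mulmat operand order above is not literally the Q·Kᵀ written here):

package main

import (
	"fmt"
	"math"
)

// softmax normalizes a score vector into attention weights.
func softmax(v []float64) []float64 {
	maxScore := math.Inf(-1)
	for _, x := range v {
		maxScore = math.Max(maxScore, x)
	}
	var sum float64
	out := make([]float64, len(v))
	for i, x := range v {
		out[i] = math.Exp(x - maxScore)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

// attention computes softmax(Q·Kᵀ/√d)·V for a single head.
func attention(q, k, v [][]float64) [][]float64 {
	d := float64(len(q[0]))
	out := make([][]float64, len(q))
	for i := range q {
		scores := make([]float64, len(k))
		for j := range k {
			for c := range q[i] {
				scores[j] += q[i][c] * k[j][c]
			}
			scores[j] /= math.Sqrt(d)
		}
		weights := softmax(scores)
		out[i] = make([]float64, len(v[0]))
		for j := range v {
			for c := range v[j] {
				out[i][c] += weights[j] * v[j][c]
			}
		}
	}
	return out
}

func main() {
	q := [][]float64{{1, 0}, {0, 1}}
	k := [][]float64{{1, 0}, {0, 1}}
	v := [][]float64{{1, 2}, {3, 4}}
	fmt.Println(attention(q, k, v))
}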
@@ -82,9 +101,9 @@ type VisionMLP struct {
 	Down *nn.Linear `gguf:"ffn_down"`
 }
 
-func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
-	return mlp.Down.Forward(ctx, hiddenState)
+func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	return mlp.Down.Forward(ctx, hiddenStates)
 }
 
 type VisionEncoderLayer struct {
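Apart from the rename, the MLP is unchanged: it is a GELU-gated feed-forward block, down(GELU(gate(x)) ⊙ up(x)). A scalar sketch of just the gating step, using made-up projection outputs (hypothetical values, not from the commit):

package main

import (
	"fmt"
	"math"
)

// gelu is the tanh approximation commonly used for GELU activations.
func gelu(x float64) float64 {
	return 0.5 * x * (1 + math.Tanh(math.Sqrt(2/math.Pi)*(x+0.044715*math.Pow(x, 3))))
}

func main() {
	gateOut := []float64{0.5, -1.0, 2.0} // hypothetical output of the gate projection
	upOut := []float64{1.0, 2.0, 3.0}    // hypothetical output of the up projection
	hidden := make([]float64, len(gateOut))
	for i := range gateOut {
		hidden[i] = gelu(gateOut[i]) * upOut[i] // elementwise gating
	}
	fmt.Println(hidden) // this vector would then go through the down projection
}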
@@ -94,16 +113,16 @@ type VisionEncoderLayer struct {
 	MLP *VisionMLP
 }
 
-func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
-	residual := hiddenState
-	hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
-	hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, opts)
-	hiddenState = hiddenState.Add(ctx, residual)
-	residual = hiddenState
+func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+	residual := hiddenStates
+	hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
+	hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, opts)
+	hiddenStates = hiddenStates.Add(ctx, residual)
 
-	hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps)
-	hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
-	return hiddenState.Add(ctx, residual)
+	residual = hiddenStates
+	hiddenStates = e.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
+	hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
+	return hiddenStates.Add(ctx, residual)
 }
 
 type VisionModelOptions struct {
@@ -116,7 +135,6 @@ type VisionModelOptions struct {
 	numChannels int
 	eps         float32
 	ropeBase    float32
-	ropeScale   float32
 }
 
 type VisionModel struct {
@@ -127,24 +145,60 @@ type VisionModel struct {
 	*VisionModelOptions
 }
 
+func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor) ml.Tensor {
+	maxPatchesPerSide := m.imageSize / m.patchSize
+	frequencies := m.headDim / 2
+	frequenciesHeight := make([]float32, frequencies/2*maxPatchesPerSide)
+	frequenciesWidth := make([]float32, frequencies/2*maxPatchesPerSide)
+	for i := range frequencies {
+		for j := range maxPatchesPerSide {
+			frequency := float32(j) / float32(math.Pow(float64(m.ropeBase), float64(i)*2/float64(m.headDim)))
+			if i%2 == 0 {
+				frequenciesHeight[i/2*maxPatchesPerSide+j] = frequency
+			} else {
+				frequenciesWidth[i/2*maxPatchesPerSide+j] = frequency
+			}
+		}
+	}
+
+	h, err := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2)
+	if err != nil {
+		panic(err)
+	}
+
+	w, err := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2)
+	if err != nil {
+		panic(err)
+	}
+
+	h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+	w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+
+	h = h.Stack(ctx, 1, slices.Repeat([]ml.Tensor{h}, maxPatchesPerSide-1)...)
+	h = h.Reshape(ctx, frequencies/2, maxPatchesPerSide, maxPatchesPerSide).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	w = w.Stack(ctx, 2, slices.Repeat([]ml.Tensor{w}, maxPatchesPerSide-1)...)
+
+	inverseFrequencies := h.Concat(ctx, w, 0).Reshape(ctx, frequencies, maxPatchesPerSide*maxPatchesPerSide)
+	inverseFrequencies = inverseFrequencies.Concat(ctx, inverseFrequencies, 0)
+	return inverseFrequencies.Rows(ctx, positionIDs)
+}
+
 func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 	numPatchesW := pixelValues.Dim(0) / m.patchSize
 	numPatchesH := pixelValues.Dim(1) / m.patchSize
 	numPatches := numPatchesW * numPatchesH
-	hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
-	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
-	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	hiddenState = m.EncoderNorm.Forward(ctx, hiddenState, m.VisionModelOptions.eps)
+
+	hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
+	hiddenStates = hiddenStates.Reshape(ctx, numPatches, m.hiddenSize)
+	hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+	hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.VisionModelOptions.eps)
 
-	// Prepare position IDs for 2D rope
-	positions := make([]int32, numPatches*4)
-	for h := 0; h < numPatchesH; h++ {
-		for w := 0; w < numPatchesW; w++ {
+	positions := make([]int32, numPatches)
+	for h := range numPatchesH {
+		for w := range numPatchesW {
 			idx := h*numPatchesW + w
-			positions[idx] = 0                     // time (unused)
-			positions[numPatches+idx] = int32(h)   // height
-			positions[numPatches*2+idx] = int32(w) // width
-			positions[numPatches*3+idx] = 0        // extra (unused)
+			positions[idx] = int32(h*m.imageSize/m.patchSize + w)
 		}
 	}
 
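positionalEmbedding builds the per-patch rotary angle table for the maximum patch grid: frequency index i uses the inverse frequency 1/ropeBase^(2i/headDim); even indices are multiplied by the patch row and odd indices by the patch column, the two groups are concatenated into headDim/2 angles per patch, and the whole vector is duplicated to headDim so it lines up with the two halves rotated by rotateHalf. The position ID h*(imageSize/patchSize) + w then selects the table row for each real patch. A scalar re-derivation for a single patch (illustration only; the tensor code above builds the full table with Stack/Concat/Rows):

package main

import (
	"fmt"
	"math"
)

// anglesFor returns the headDim rotary angles for the patch at (row, col),
// mirroring the layout produced by positionalEmbedding.
func anglesFor(row, col, headDim int, ropeBase float64) []float64 {
	frequencies := headDim / 2
	height := make([]float64, 0, frequencies/2)
	width := make([]float64, 0, frequencies/2)
	for i := 0; i < frequencies; i++ {
		invFreq := 1 / math.Pow(ropeBase, float64(i)*2/float64(headDim))
		if i%2 == 0 {
			height = append(height, float64(row)*invFreq) // even indices follow the row
		} else {
			width = append(width, float64(col)*invFreq) // odd indices follow the column
		}
	}

	out := make([]float64, 0, headDim)
	out = append(out, height...)
	out = append(out, width...)
	out = append(out, out[:frequencies]...) // duplicate so both rotate-half halves share angles
	return out
}

func main() {
	// Patch in row 2, column 3 of the grid, with an 8-dim head and the default base.
	fmt.Println(anglesFor(2, 3, 8, 10000))
}

Because the table depends only on grid position, it is computed once per image in Forward and the resulting cos/sin tensors are reused by every encoder layer, as the next hunk shows.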
@@ -153,11 +207,14 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 		panic(err)
 	}
 
+	positionEmbedding := m.positionalEmbedding(ctx, positionIDs)
+	cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
+
 	for _, layer := range m.Layers {
-		hiddenState = layer.Forward(ctx, hiddenState, positionIDs, m.VisionModelOptions)
+		hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
 	}
 
-	return hiddenState
+	return hiddenStates
 }
 
 func newVisionModel(c ml.Config) *VisionModel {
@@ -173,7 +230,6 @@ func newVisionModel(c ml.Config) *VisionModel {
 			numChannels: int(c.Uint("vision.num_channels", 3)),
 			eps:         c.Float("vision.attention.layer_norm_epsilon", 1e-5),
 			ropeBase:    c.Float("vision.rope.freq_base", 10000.0),
-			ropeScale:   c.Float("vision.rope.freq_scale", 1.0),
 		},
 	}
 }