FromFloatSlice and FromIntSlice return an error if the shape doesn't match the passed data or if memory can't be allocated. Since these are inputs, the memory being allocated is system memory rather than VRAM. In many cases, the caller can't really handle the error and panics anyway. Empty and Zeros already panic directly if they can't allocate memory. This makes things consistent by panicking in the first two cases as well, removing a fair amount of error handling code. This is also consistent with how Go typically handles these situations.
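
A minimal caller-side sketch of the effect (the "before" signature is assumed from the description above; the "after" call matches the usage in this file):

    // Before (hypothetical caller, old error-returning signature):
    //   t, err := ctx.Input().FromFloatSlice(data, len(data))
    //   if err != nil {
    //       panic(err) // callers typically couldn't do anything else
    //   }

    // After: a shape mismatch or failed system-memory allocation
    // panics inside FromFloatSlice itself.
    t := ctx.Input().FromFloatSlice(data, len(data))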

package mistral3

import (
    "math"

    "github.com/ollama/ollama/fs"
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/ml/nn"
)

var batchSize int = 1
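
// rotateHalf implements the rotate-half step of rotary position embeddings:
// the input is split in two along dim 0 and reassembled as (-x2, x1).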
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
    x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
    x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
    return x2.Neg(ctx).Concat(ctx, x1, 0)
}
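
// applyRotaryPositionalEmbedding applies RoPE as t*cos + rotateHalf(t)*sin.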
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
    return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
}
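
// VisionSelfAttention is the multi-head self-attention block of a vision
// encoder layer; the gguf tags map each projection to its tensor name.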
type VisionSelfAttention struct {
    Query  *nn.Linear `gguf:"attn_q"`
    Key    *nn.Linear `gguf:"attn_k"`
    Value  *nn.Linear `gguf:"attn_v"`
    Output *nn.Linear `gguf:"attn_output"`
}
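
// Forward projects the hidden states to queries, keys, and values, splits
// them into heads, applies rotary embeddings to queries and keys, and runs
// scaled dot-product attention before the output projection.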
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
    query := sa.Query.Forward(ctx, hiddenStates)
    key := sa.Key.Forward(ctx, hiddenStates)
    value := sa.Value.Forward(ctx, hiddenStates)

    query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
    key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
    value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)

    query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
    key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)

    attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
    attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
    return sa.Output.Forward(ctx, attention)
}
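
// VisionMLP is the gated (SwiGLU-style) feed-forward block.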
type VisionMLP struct {
    Gate *nn.Linear `gguf:"ffn_gate"`
    Up   *nn.Linear `gguf:"ffn_up"`
    Down *nn.Linear `gguf:"ffn_down"`
}
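
// Forward computes down(silu(gate(x)) * up(x)).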
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
    hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
    return mlp.Down.Forward(ctx, hiddenStates)
}
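
// VisionEncoderLayer is one pre-norm transformer block: attention and MLP,
// each preceded by RMSNorm and wrapped in a residual connection.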
type VisionEncoderLayer struct {
    AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
    SelfAttention *VisionSelfAttention

    FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
    MLP     *VisionMLP
}

func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
    residual := hiddenStates
    hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
    hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, opts)
    hiddenStates = hiddenStates.Add(ctx, residual)

    residual = hiddenStates
    hiddenStates = e.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
    hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
    return hiddenStates.Add(ctx, residual)
}
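
// VisionModelOptions carries the vision tower hyperparameters read from the
// model configuration.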
type VisionModelOptions struct {
    hiddenSize       int
    numHeads         int
    headDim          int
    intermediateSize int
    imageSize        int
    patchSize        int
    numChannels      int
    eps              float32
    ropeBase         float32
}
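
// VisionModel is the vision encoder: a convolutional patch embedding
// followed by a stack of transformer layers with 2D rotary embeddings.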
type VisionModel struct {
    PatchEmbedding *nn.Conv2D           `gguf:"patch_conv"`
    EncoderNorm    *nn.RMSNorm          `gguf:"encoder_norm"`
    Layers         []VisionEncoderLayer `gguf:"blk"`

    *VisionModelOptions
}
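
// positionalEmbedding builds the 2D RoPE frequency table: even-indexed
// frequency bands encode the patch row and odd-indexed bands the patch
// column. The table is laid out per patch position and then gathered with
// Rows using the flattened position IDs.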
func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor) ml.Tensor {
    maxPatchesPerSide := m.imageSize / m.patchSize
    frequencies := m.headDim / 2
    frequenciesHeight := make([]float32, frequencies/2*maxPatchesPerSide)
    frequenciesWidth := make([]float32, frequencies/2*maxPatchesPerSide)
    for i := range frequencies {
        for j := range maxPatchesPerSide {
            frequency := float32(j) / float32(math.Pow(float64(m.ropeBase), float64(i)*2/float64(m.headDim)))
            if i%2 == 0 {
                frequenciesHeight[i/2*maxPatchesPerSide+j] = frequency
            } else {
                frequenciesWidth[i/2*maxPatchesPerSide+j] = frequency
            }
        }
    }

    h := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2)
    w := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2)

    h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
    w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

    h = h.Repeat(ctx, 1, maxPatchesPerSide)
    h = h.Reshape(ctx, frequencies/2, maxPatchesPerSide, maxPatchesPerSide).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
    w = w.Repeat(ctx, 2, maxPatchesPerSide)

    inverseFrequencies := h.Concat(ctx, w, 0).Reshape(ctx, frequencies, maxPatchesPerSide*maxPatchesPerSide)
    inverseFrequencies = inverseFrequencies.Concat(ctx, inverseFrequencies, 0)
    return inverseFrequencies.Rows(ctx, positionIDs)
}
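
// Forward patchifies the input image with the convolutional embedding,
// flattens the patches into a sequence, attaches 2D rotary position
// embeddings, and runs the encoder layers.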
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
    numPatchesW := pixelValues.Dim(0) / m.patchSize
    numPatchesH := pixelValues.Dim(1) / m.patchSize
    numPatches := numPatchesW * numPatchesH

    hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
    hiddenStates = hiddenStates.Reshape(ctx, numPatches, m.hiddenSize)
    hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
    hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.VisionModelOptions.eps)

    // Prepare position IDs for 2D rope
    positions := make([]int32, numPatches)
    for h := range numPatchesH {
        for w := range numPatchesW {
            idx := h*numPatchesW + w
            positions[idx] = int32(h*m.imageSize/m.patchSize + w)
        }
    }

    positionIDs := ctx.Input().FromIntSlice(positions, len(positions))

    positionEmbedding := m.positionalEmbedding(ctx, positionIDs)
    cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
    cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
    sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))

    for _, layer := range m.Layers {
        hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
    }

    return hiddenStates
}
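
// newVisionModel reads the vision hyperparameters from the model config,
// falling back to the listed defaults when a key is absent.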
func newVisionModel(c fs.Config) *VisionModel {
    return &VisionModel{
        Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
        VisionModelOptions: &VisionModelOptions{
            hiddenSize:       int(c.Uint("vision.embedding_length", 1024)),
            numHeads:         int(c.Uint("vision.attention.head_count", 16)),
            headDim:          int(c.Uint("vision.attention.key_length", 64)),
            intermediateSize: int(c.Uint("vision.feed_forward_length", 4096)),
            imageSize:        int(c.Uint("vision.image_size", 1540)),
            patchSize:        int(c.Uint("vision.patch_size", 14)),
            numChannels:      int(c.Uint("vision.num_channels", 3)),
            eps:              c.Float("vision.attention.layer_norm_epsilon", 1e-5),
            ropeBase:         c.Float("vision.rope.freq_base", 10000.0),
        },
    }
}