FromFloatSlice and FromIntSlice return an error if the shape doesn't match the passed data or if memory can't be allocated. Since these are inputs, the memory being allocated is system memory rather than VRAM. In many cases the caller can't meaningfully handle the error and panics anyway. Empty and Zeros already panic directly if they can't allocate memory. This change makes things consistent by panicking in the first two cases as well, removing a fair amount of error handling code. This is also consistent with how Go typically handles such situations.
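For illustration, here is a minimal sketch of the calling pattern this implies, using the ml package API as it appears in the file below; loadInput is a hypothetical helper, not part of the change itself:

// loadInput builds an input tensor from host data. Under the updated contract,
// FromFloatSlice returns only the tensor and panics if the shape doesn't match
// the data or if system memory can't be allocated. Previously callers wrote
// t, err := ctx.Input().FromFloatSlice(...) and then usually just panicked on err.
func loadInput(ctx ml.Context, rows, cols int) ml.Tensor {
	data := make([]float32, rows*cols) // fill with real values in practice
	return ctx.Input().FromFloatSlice(data, cols, rows)
}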
package llama4

import (
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

type VisionAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}

// applyVisionRotaryEmbedding applies 2D rotary embedding to the input tensor.
// This is equivalent to the PyTorch implementation using half rotations:
//
// cos, sin = torch.cos(freqs), torch.sin(freqs)
// cos = cos.unsqueeze(-1)
// sin = sin.unsqueeze(-1)
// t = t.reshape(*t.shape[:-1], -1, 2)
// t_out = (t * cos) + (_rotate_half(t) * sin)
// t_out = t_out.flatten(3)
//
// Which is equivalent to the PyTorch implementation using complex numbers:
//
// t_ = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
// freqs_ci = reshape_for_broadcast(freqs_ci=freq_cis, t=t_) # freqs_ci[:,:,None,:]
// freqs_ci = freqs_ci.to(t_.device)
// t_out = torch.view_as_real(t_ * freqs_ci).flatten(3)
//
// Due to 1) the dimensional and 2) the datatype limitations of current backends,
// we need to use a different approach to achieve the same result.
func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
	width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3)

	t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3))

	// t1 = t[..., 0::2]
	t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
	t1 = t1.Reshape(ctx, width/2, height, channels, tiles)

	// t2 = t[..., 1::2]
	t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
	t2 = t2.Reshape(ctx, width/2, height, channels, tiles)

	// cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1)
	cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0)
	cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3))
	cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	cosOut = cosOut.Reshape(ctx, width, height, channels, tiles)

	// sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1)
	sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
	sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3))
	sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	sinOut = sinOut.Reshape(ctx, width, height, channels, tiles)

	return cosOut.Add(ctx, sinOut)
}

func (sa *VisionAttention) Forward(ctx ml.Context, hiddenState, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
	headDim := opts.hiddenSize / opts.numHeads

	query := sa.Query.Forward(ctx, hiddenState)
	key := sa.Key.Forward(ctx, hiddenState)
	value := sa.Value.Forward(ctx, hiddenState)

	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), query.Dim(2))
	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), key.Dim(2))
	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), value.Dim(2))

	query = applyVisionRotaryEmbedding(ctx, query, cos, sin)
	key = applyVisionRotaryEmbedding(ctx, key, cos, sin)

	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), attention.Dim(3))
	return sa.Output.Forward(ctx, attention)
}

type VisionMLP struct {
	FC1 *nn.Linear `gguf:"fc1"`
	FC2 *nn.Linear `gguf:"fc2"`
}

func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
	hiddenStates = mlp.FC1.Forward(ctx, hiddenStates).GELU(ctx)
	hiddenStates = mlp.FC2.Forward(ctx, hiddenStates)
	return hiddenStates
}

type VisionLayer struct {
	InputLayerNorm *nn.LayerNorm `gguf:"attn_norm"`
	*VisionAttention

	PostAttentionNorm *nn.LayerNorm `gguf:"ffn_norm"`
	*VisionMLP        `gguf:"mlp"`
}

func (e *VisionLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
	residual := hiddenStates

	// self attention
	hiddenStates = e.InputLayerNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = e.VisionAttention.Forward(ctx, hiddenStates, cos, sin, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)

	// MLP
	residual = hiddenStates
	hiddenStates = e.PostAttentionNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = e.VisionMLP.Forward(ctx, hiddenStates, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)

	return hiddenStates
}

type VisionAdapter struct {
	FC1 *nn.Linear `gguf:"mlp.fc1"`
	FC2 *nn.Linear `gguf:"mlp.fc2"`
}

func (a *VisionAdapter) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
	patches := hiddenStates.Dim(1)
	patchSize := int(math.Sqrt(float64(patches)))

	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), patchSize, patchSize, hiddenStates.Dim(2))

	channels, width, height, tiles := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)

	// pixel shuffle: scale width (then height) by pixelShuffleRatio and channels
	// by its inverse, folding spatial positions into the channel dimension
	// before the projection MLP
	channels, width = int(float32(channels)/opts.pixelShuffleRatio), int(float32(width)*opts.pixelShuffleRatio)
	hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
	hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)

	channels, height = int(float32(channels)/opts.pixelShuffleRatio), int(float32(height)*opts.pixelShuffleRatio)
	hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
	hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)

	hiddenStates = hiddenStates.Reshape(ctx, channels, width*height, tiles)

	hiddenStates = a.FC1.Forward(ctx, hiddenStates).GELU(ctx)
	hiddenStates = a.FC2.Forward(ctx, hiddenStates).GELU(ctx)
	return hiddenStates
}

type VisionOptions struct {
	hiddenSize, numHeads int
	imageSize, patchSize int

	ropeTheta         float32
	eps               float32
	pixelShuffleRatio float32
}

type PatchEmbedding struct {
	*nn.Linear
}

func (p *PatchEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
	// the empty kernel only supplies the patchSize x patchSize window shape for
	// IM2Col, which unfolds the image into flattened, non-overlapping patches;
	// the actual projection is the linear layer below
	kernel := ctx.Input().Empty(ml.DTypeF32, opts.patchSize, opts.patchSize, hiddenStates.Dim(2))
	hiddenStates = kernel.IM2Col(ctx, hiddenStates, opts.patchSize, opts.patchSize, 0, 0, 1, 1)
	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), hiddenStates.Dim(1)*hiddenStates.Dim(2), hiddenStates.Dim(3))
	return p.Linear.Forward(ctx, hiddenStates)
}

type VisionModel struct {
	Layers []VisionLayer `gguf:"blk"`

	*PatchEmbedding     `gguf:"patch_embedding"`
	ClassEmbedding      ml.Tensor `gguf:"class_embedding"`
	PositionalEmbedding ml.Tensor `gguf:"positional_embedding_vlm"`

	LayerNormPre  *nn.LayerNorm `gguf:"layernorm_pre"`
	LayerNormPost *nn.LayerNorm `gguf:"layernorm_post"`

	*VisionAdapter `gguf:"vision_adapter"`

	*VisionOptions
}

func newVisionModel(c fs.Config) *VisionModel {
	return &VisionModel{
		Layers: make([]VisionLayer, c.Uint("vision.block_count")),
		VisionOptions: &VisionOptions{
			hiddenSize:        int(c.Uint("vision.embedding_length")),
			numHeads:          int(c.Uint("vision.attention.head_count")),
			imageSize:         int(c.Uint("vision.image_size")),
			patchSize:         int(c.Uint("vision.patch_size")),
			ropeTheta:         float32(c.Float("vision.rope.freq_base")),
			eps:               c.Float("vision.layer_norm_epsilon")),
			pixelShuffleRatio: float32(c.Float("vision.pixel_shuffle_ratio")),
		},
	}
}

func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
	hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionOptions)
	hiddenStates = hiddenStates.Concat(ctx, m.ClassEmbedding.Repeat(ctx, 2, hiddenStates.Dim(2)), 1)

	hiddenStates = hiddenStates.Add(ctx, m.PositionalEmbedding)
	hiddenStates = m.LayerNormPre.Forward(ctx, hiddenStates, m.eps)

	cos, sin := m.rotaryEmbedding(ctx)
	for _, layer := range m.Layers {
		hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionOptions)
	}

	hiddenStates = m.LayerNormPost.Forward(ctx, hiddenStates, m.eps)
	// negative padding trims the final (class) token before the adapter
	hiddenStates = hiddenStates.Pad(ctx, 0, -1, 0, 0)
	hiddenStates = m.VisionAdapter.Forward(ctx, hiddenStates, m.VisionOptions)
	return hiddenStates
}

// floorDiv is a helper function to perform floor division. This mimics PyTorch's div(round_mode='floor') function,
// which in turn mimics Python's // operator. For example, floorDiv(-7, 2) is -4, whereas Go's native integer
// division truncates toward zero and gives -3.
func floorDiv[T int | int16 | int32 | int64 | uint | uint16 | uint32 | uint64](a, b T) T {
	if b == 0 {
		panic("division by zero")
	}

	if (a >= 0 && b > 0) || (a <= 0 && b < 0) || a%b == 0 {
		return a / b
	}

	return a/b - 1
}

func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) {
	patchesPerSide := m.imageSize / m.patchSize
	numPatches := patchesPerSide*patchesPerSide + 1

	headDim := m.hiddenSize / m.numHeads
	freqDim := headDim / 2

	// per-patch 2D rope frequencies: x-axis terms fill the first half of freqs,
	// y-axis terms the second half
	freqs := make([]float32, numPatches*freqDim)
	for i := range numPatches - 1 {
		for j := 0; j < freqDim; j += 2 {
			positionX := i*freqDim/2 + j/2
			positionY := (i+numPatches)*freqDim/2 + j/2
			ropeFreq := math.Pow(float64(m.ropeTheta), float64(j)*2/float64(headDim))
			freqs[positionX] = float32(float64(1+i-floorDiv(i, patchesPerSide)*patchesPerSide) / ropeFreq)
			freqs[positionY] = float32(float64(1+floorDiv(i, patchesPerSide)) / ropeFreq)
		}
	}

	ropeFreqs := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2)

	ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches)
	return ropeFreqs.Cos(ctx), ropeFreqs.Sin(ctx)
}