mirror of
https://github.com/ollama/ollama.git
synced 2025-11-11 14:56:45 +01:00
* Enable CUDA Graphs for gemma3n. Similar to https://github.com/ggml-org/llama.cpp/pull/14741, though ollama has a slightly different model graph than llama.cpp, which requires different workaround checks.
* Remove residual check by reshaping differently in the gemma3n model. This should make the heuristics more robust.
360 lines
14 KiB
Go
360 lines
14 KiB
Go
package gemma3n
|
|
|
|
import (
|
|
"cmp"
|
|
"math"
|
|
|
|
"github.com/ollama/ollama/fs"
|
|
"github.com/ollama/ollama/kvcache"
|
|
"github.com/ollama/ollama/ml"
|
|
"github.com/ollama/ollama/ml/nn"
|
|
"github.com/ollama/ollama/ml/nn/fast"
|
|
"github.com/ollama/ollama/ml/nn/rope"
|
|
"github.com/ollama/ollama/model/input"
|
|
)
|
|
|
|
type TextModel struct {
|
|
TokenEmbedding *TextScaledWordEmbedding `gguf:"token_embd"`
|
|
|
|
*PerLayerProjector
|
|
|
|
AltupEmbd *nn.Linear `gguf:"altup_proj"`
|
|
AltupUnembd *nn.Linear `gguf:"altup_unembd_proj"`
|
|
|
|
TextLayers []TextLayer `gguf:"blk"`
|
|
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
|
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
|
|
|
TextOptions
|
|
}
|
|
|
|
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
|
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
|
// Create a tensor of a single float32 value of 1.0 to use for altup correction
|
|
one := ctx.Input().FromFloatSlice([]float32{1.0}, 1)
|
|
|
|
inputs := m.TokenEmbedding.Forward(ctx, batch.Inputs, math.Sqrt(float64(m.hiddenSize)))
|
|
inputsPerLayer := m.PerLayerProjector.Forward(ctx, batch, inputs, &m.TextOptions)
|
|
|
|
targetMagnitude := inputs.Sqr(ctx).Mean(ctx).Sqrt(ctx)
|
|
targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
|
|
|
|
hiddenState := inputs.Repeat(ctx, 2, m.altupInputs-1)
|
|
altupProj := m.AltupEmbd.Forward(ctx, hiddenState)
|
|
altupProj = altupProj.Mul(ctx, targetMagnitude.Div(ctx, altupProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
|
|
|
|
hiddenStates := inputs.Concat(ctx, altupProj, 2)
|
|
|
|
firstSharedKeyValue := m.hiddenLayers - m.sharedKeyValueLayers
|
|
for i, layer := range m.TextLayers {
|
|
if i < firstSharedKeyValue {
|
|
cache.SetLayer(i)
|
|
} else if m.isLocal(i) {
|
|
cache.SetLayer(firstSharedKeyValue - 2)
|
|
} else {
|
|
cache.SetLayer(firstSharedKeyValue - 1)
|
|
}
|
|
|
|
var layerType int
|
|
ropeBase := m.ropeBase
|
|
if m.isLocal(i) {
|
|
layerType = 1
|
|
ropeBase = m.ropeBaseLocal
|
|
}
|
|
|
|
cache.(*kvcache.WrapperCache).SetLayerType(layerType)
|
|
|
|
// inputPerLayer = inputsPerLayer[:, i, :]
|
|
inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2))
|
|
hiddenStates = layer.Forward(ctx, hiddenStates, inputPerLayer, positions, one, cache, i >= firstSharedKeyValue, ropeBase, float64(m.activationSparsityScale[i]), &m.TextOptions)
|
|
}
|
|
|
|
// hiddenStates = hiddenStates[:, :, 0]
|
|
hiddenStates0 := hiddenStates.View(ctx, 0, hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1))
|
|
targetMagnitude = hiddenStates0.Sqr(ctx).Mean(ctx).Sqrt(ctx)
|
|
targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
|
|
|
|
// hiddenState = hiddenStates[:, :, 1:]
|
|
hiddenState = hiddenStates.View(ctx, hiddenStates.Stride(2), hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), m.altupInputs-1)
|
|
altupUnembdProj := m.AltupUnembd.Forward(ctx, hiddenState)
|
|
altupUnembdProj = altupUnembdProj.Mul(ctx, targetMagnitude.Div(ctx, altupUnembdProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
|
|
|
|
hiddenStates = hiddenStates0.Concat(ctx, altupUnembdProj, 2)
|
|
|
|
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
|
|
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
|
hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
|
|
|
|
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
|
return m.Output.Forward(ctx, hiddenStates), nil
|
|
}
|
|
|
|
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
|
ropeBase := m.ropeBase
|
|
if m.isLocal(layer) {
|
|
ropeBase = m.ropeBaseLocal
|
|
}
|
|
|
|
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
|
|
}
|
|
|
|
type TextScaledWordEmbedding struct {
|
|
*nn.Embedding
|
|
}
|
|
|
|
func (e TextScaledWordEmbedding) Forward(ctx ml.Context, inputIDs ml.Tensor, scale float64) ml.Tensor {
|
|
return e.Embedding.Forward(ctx, inputIDs).Scale(ctx, scale)
|
|
}
|
|
|
|
type PerLayerProjector struct {
|
|
TokenEmbedding *TextScaledWordEmbedding `gguf:"per_layer_token_embd"`
|
|
Projector *nn.Linear `gguf:"per_layer_model_proj"`
|
|
Norm *nn.RMSNorm `gguf:"per_layer_proj_norm"`
|
|
}
|
|
|
|
func (p PerLayerProjector) Forward(ctx ml.Context, batch input.Batch, inputs ml.Tensor, opts *TextOptions) ml.Tensor {
|
|
inputsPerLayer := p.TokenEmbedding.Forward(ctx, batch.Inputs, math.Sqrt(float64(opts.hiddenSizePerLayerInput)))
|
|
inputsPerLayer = inputsPerLayer.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, batch.Inputs.Dim(0), batch.Inputs.Dim(1))
|
|
|
|
perLayerProjection := p.Projector.Forward(ctx, inputs)
|
|
perLayerProjection = perLayerProjection.Scale(ctx, math.Sqrt(float64(opts.hiddenSize)))
|
|
perLayerProjection = perLayerProjection.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
|
|
perLayerProjection = p.Norm.Forward(ctx, perLayerProjection, opts.eps)
|
|
|
|
if inputsPerLayer != nil {
|
|
perLayerProjection = perLayerProjection.Add(ctx, inputsPerLayer)
|
|
perLayerProjection = perLayerProjection.Scale(ctx, 1/math.Sqrt(2))
|
|
}
|
|
|
|
return perLayerProjection
|
|
}
|
|
|
|
type TextLayer struct {
|
|
*AltUp
|
|
*Laurel
|
|
|
|
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
|
Attention *TextAttention
|
|
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
|
|
|
|
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
|
MLP *TextMLP
|
|
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
|
|
|
|
PerLayerInputGate *nn.Linear `gguf:"inp_gate"`
|
|
PerLayerProjection *nn.Linear `gguf:"proj"`
|
|
PostPerLayerNorm *nn.RMSNorm `gguf:"post_norm"`
|
|
}
|
|
|
|
func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, positions, one ml.Tensor, cache kvcache.Cache, sharedKV bool, ropeBase float32, activationSparsityScale float64, opts *TextOptions) ml.Tensor {
|
|
predictions := d.Predict(ctx, hiddenStates, opts)
|
|
active := opts.altupActive(ctx, predictions)
|
|
|
|
attn := d.AttentionNorm.Forward(ctx, active, opts.eps)
|
|
laurel := d.Laurel.Forward(ctx, attn, opts)
|
|
|
|
attn = d.Attention.Forward(ctx, attn, positions, cache, sharedKV, ropeBase, opts)
|
|
attn = d.PostAttentionNorm.Forward(ctx, attn, opts.eps)
|
|
attn = active.Add(ctx, attn)
|
|
attn = attn.Add(ctx, laurel).Scale(ctx, 1/math.Sqrt(2))
|
|
|
|
mlp := d.MLPNorm.Forward(ctx, attn, opts.eps)
|
|
mlp = d.MLP.Forward(ctx, mlp, activationSparsityScale)
|
|
mlp = d.PostMLPNorm.Forward(ctx, mlp, opts.eps)
|
|
mlp = attn.Add(ctx, mlp)
|
|
|
|
predictions = d.Correct(ctx, predictions, mlp, one, opts)
|
|
active = opts.altupActive(ctx, predictions)
|
|
if opts.altupCorrectScale {
|
|
active = d.ScaleCorrectedOutput(ctx, active)
|
|
}
|
|
|
|
active = d.PerLayerInputGate.Forward(ctx, active)
|
|
active = active.GELU(ctx)
|
|
active = active.Mul(ctx, perLayerInput)
|
|
|
|
active = d.PerLayerProjection.Forward(ctx, active)
|
|
active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
|
|
|
|
// inactive := predictions[:, :, 1:]
|
|
inactive := predictions.View(ctx, predictions.Stride(2), predictions.Dim(0), predictions.Stride(1), predictions.Dim(1), predictions.Stride(2), predictions.Dim(2)-1)
|
|
active = inactive.Add(ctx, active)
|
|
|
|
predictions0 := predictions.View(ctx, 0, predictions.Dim(0), predictions.Stride(1), predictions.Dim(1))
|
|
return predictions0.Concat(ctx, active, 2)
|
|
}
|
|
|
|
type AltUp struct {
|
|
CorrectionScale ml.Tensor `gguf:"altup_correct_scale.weight"`
|
|
PredictionCoefficient *nn.Linear `gguf:"altup_predict_coef"`
|
|
CorrectionCoefficient *nn.Linear `gguf:"altup_correct_coef"`
|
|
Router *nn.Linear `gguf:"altup_router"`
|
|
RouterNorm *nn.RMSNorm `gguf:"altup_router_norm"`
|
|
}
|
|
|
|
func (a AltUp) computeRouterModalities(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
|
routerInputs := a.RouterNorm.Forward(ctx, hiddenStates, opts.eps).Scale(ctx, 1.0/float64(opts.hiddenSize))
|
|
return a.Router.Forward(ctx, routerInputs).Tanh(ctx)
|
|
}
|
|
|
|
func (a AltUp) Predict(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
|
modalities := a.computeRouterModalities(ctx, opts.altupActive(ctx, hiddenStates), opts)
|
|
|
|
coefficients := a.PredictionCoefficient.Forward(ctx, modalities)
|
|
coefficients = coefficients.Reshape(ctx, opts.altupInputs, opts.altupInputs, coefficients.Dim(1), coefficients.Dim(2))
|
|
|
|
predictions := coefficients.Mulmat(ctx, hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx))
|
|
predictions = predictions.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
|
return predictions.Add(ctx, hiddenStates)
|
|
}
|
|
|
|
func (a AltUp) Correct(ctx ml.Context, predictions, activated, one ml.Tensor, opts *TextOptions) ml.Tensor {
|
|
innovation := activated.Sub(ctx, opts.altupActive(ctx, predictions))
|
|
innovation = innovation.Repeat(ctx, 2, opts.altupInputs)
|
|
|
|
modalities := a.computeRouterModalities(ctx, activated, opts)
|
|
coefficients := a.CorrectionCoefficient.Forward(ctx, modalities)
|
|
coefficients = coefficients.Add(ctx, one)
|
|
|
|
coefficients = coefficients.Reshape(ctx, 1, coefficients.Dim(0), coefficients.Dim(1))
|
|
coefficients = coefficients.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
|
|
corrected := innovation.Mul(ctx, coefficients)
|
|
corrected = corrected.Add(ctx, predictions)
|
|
return corrected
|
|
}
|
|
|
|
func (a AltUp) ScaleCorrectedOutput(ctx ml.Context, predictions ml.Tensor) ml.Tensor {
|
|
return predictions.Mul(ctx, a.CorrectionScale)
|
|
}
|
|
|
|
type Laurel struct {
|
|
LinearLeft *nn.Linear `gguf:"laurel_l"`
|
|
LinearRight *nn.Linear `gguf:"laurel_r"`
|
|
PostLaurelNorm *nn.RMSNorm `gguf:"laurel_post_norm"`
|
|
}
|
|
|
|
func (l Laurel) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
|
residual := hiddenStates
|
|
hiddenStates = l.LinearLeft.Forward(ctx, hiddenStates)
|
|
hiddenStates = l.LinearRight.Forward(ctx, hiddenStates)
|
|
hiddenStates = l.PostLaurelNorm.Forward(ctx, hiddenStates, opts.eps)
|
|
return hiddenStates.Add(ctx, residual)
|
|
}
|
|
|
|
type TextAttention struct {
|
|
Query *nn.Linear `gguf:"attn_q"`
|
|
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
|
Key *nn.Linear `gguf:"attn_k"`
|
|
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
|
Value *nn.Linear `gguf:"attn_v"`
|
|
Output *nn.Linear `gguf:"attn_output"`
|
|
}
|
|
|
|
func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, sharedKV bool, ropeBase float32, opts *TextOptions) ml.Tensor {
|
|
batchSize := hiddenStates.Dim(1)
|
|
|
|
query := attn.Query.Forward(ctx, hiddenStates)
|
|
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
|
|
query = attn.QueryNorm.Forward(ctx, query, opts.eps)
|
|
query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
|
|
|
var key, value ml.Tensor
|
|
if !sharedKV {
|
|
key = attn.Key.Forward(ctx, hiddenStates)
|
|
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
|
key = attn.KeyNorm.Forward(ctx, key, opts.eps)
|
|
key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
|
|
|
value = attn.Value.Forward(ctx, hiddenStates)
|
|
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
|
value = value.RMSNorm(ctx, nil, opts.eps)
|
|
}
|
|
|
|
attention := nn.Attention(ctx, query, key, value, 1., cache)
|
|
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
|
|
return attn.Output.Forward(ctx, attention)
|
|
}
|
|
|
|
type TextMLP struct {
|
|
Gate *nn.Linear `gguf:"ffn_gate"`
|
|
Up *nn.Linear `gguf:"ffn_up"`
|
|
Down *nn.Linear `gguf:"ffn_down"`
|
|
}
|
|
|
|
func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSparsityScale float64) ml.Tensor {
|
|
upStates := mlp.Up.Forward(ctx, hiddenStates)
|
|
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates)
|
|
if activationSparsityScale > 0 {
|
|
mean := hiddenStates.Mean(ctx)
|
|
std := hiddenStates.Stddev(ctx).Scale(ctx, activationSparsityScale)
|
|
cutoff := mean.Add(ctx, std)
|
|
hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
|
|
}
|
|
|
|
hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates)
|
|
hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
|
|
return hiddenStates
|
|
}
|
|
|
|
// TextOptions carries the hyperparameters of the text model. The values are
// filled in by newTextModel from the model configuration.
type TextOptions struct {
	// hiddenLayers is the number of decoder blocks.
	hiddenLayers int
	// hiddenSize is the width of the main embeddings.
	hiddenSize int
	// hiddenSizePerLayerInput is the width of the per-layer inputs.
	hiddenSizePerLayerInput int
	// numHeads and numKVHeads are the head counts used to reshape queries
	// and keys/values respectively.
	numHeads   int
	numKVHeads int
	// keyLength and valueLength, when non-zero, take precedence over
	// hiddenSize/numHeads in headDim.
	keyLength   int
	valueLength int
	// sharedKeyValueLayers is the number of trailing layers that do not get
	// a cache slot of their own.
	sharedKeyValueLayers int

	// altupActiveIndex selects the active AltUp stream.
	altupActiveIndex int
	// altupInputs is the total number of AltUp streams.
	altupInputs int
	// altupCorrectScale enables ScaleCorrectedOutput in TextLayer.Forward.
	altupCorrectScale bool

	// eps is the epsilon passed to every RMS norm.
	eps float32
	// ropeBase and ropeBaseLocal are the RoPE bases for non-local and local
	// layers; ropeScale is shared by both.
	ropeBase      float32
	ropeBaseLocal float32
	ropeScale     float32

	// slidingWindowPattern[i] reports whether layer i is a local layer.
	slidingWindowPattern []bool
	// activationSparsityScale[i] is the MLP sparsity scale of layer i.
	activationSparsityScale []float32
}
|
|
|
|
func (o *TextOptions) altupActive(ctx ml.Context, t ml.Tensor) ml.Tensor {
|
|
// t[:, :, o.altupActiveIndex]
|
|
return t.View(ctx, o.altupActiveIndex*t.Stride(2), t.Dim(0), t.Stride(1), t.Dim(1))
|
|
}
|
|
|
|
func (o *TextOptions) headDim() int {
|
|
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
|
|
}
|
|
|
|
func (o *TextOptions) isLocal(i int) bool {
|
|
return o.slidingWindowPattern[i]
|
|
}
|
|
|
|
func newTextModel(c fs.Config) *TextModel {
|
|
return &TextModel{
|
|
TextLayers: make([]TextLayer, c.Uint("block_count")),
|
|
TextOptions: TextOptions{
|
|
hiddenLayers: int(c.Uint("block_count")),
|
|
hiddenSize: int(c.Uint("embedding_length")),
|
|
hiddenSizePerLayerInput: int(c.Uint("embedding_length_per_layer_input")),
|
|
numHeads: int(c.Uint("attention.head_count")),
|
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
|
keyLength: int(c.Uint("attention.key_length")),
|
|
valueLength: int(c.Uint("attention.value_length")),
|
|
sharedKeyValueLayers: int(c.Uint("attention.shared_kv_layers")),
|
|
|
|
altupActiveIndex: int(c.Uint("altup.active_idx")),
|
|
altupInputs: int(c.Uint("altup.num_inputs")),
|
|
|
|
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
|
ropeBase: c.Float("rope.freq_base", 1_000_000),
|
|
ropeBaseLocal: c.Float("rope.freq_base_local", 10_000),
|
|
ropeScale: c.Float("rope.freq_scale", 1.0),
|
|
|
|
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
|
|
activationSparsityScale: c.Floats("activation_sparsity_scale"),
|
|
},
|
|
}
|
|
}
|