diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go
index b1dc7d779..255de92ab 100644
--- a/kvcache/causal_test.go
+++ b/kvcache/causal_test.go
@@ -462,7 +462,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0
 	panic("not implemented")
 }
 
-func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
+func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
 	panic("not implemented")
 }
 
diff --git a/ml/backend.go b/ml/backend.go
index cfb18d6a9..b71e99326 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -118,6 +118,53 @@ type Context interface {
 	Layer(int) Context
 }
 
+// RopeType represents different RoPE (Rotary Position Embedding) implementation types
+type RopeType int
+
+// Available RoPE implementation types
+const (
+	RopeTypeNormal RopeType = iota // Standard RoPE implementation
+	RopeTypeNeox                   // NeoX-style RoPE implementation
+	RopeTypeMRoPE                  // Multimodal RoPE (M-RoPE) implementation
+	RopeTypeVision                 // Vision-specific RoPE implementation
+)
+
+type YarnConfig struct {
+	YarnCtxTrain   int     // Context size used during training (for YaRN scaling)
+	YarnExtFactor  float32 // Extension factor for YaRN
+	YarnAttnFactor float32 // Attention scaling factor for YaRN
+	YarnBetaFast   float32 // Fast decay parameter for YaRN
+	YarnBetaSlow   float32 // Slow decay parameter for YaRN
+}
+
+// DefaultYarnConfig returns a default configuration for YaRN (Yet another RoPE extensioN) context scaling
+func DefaultYarnConfig(nCtx int32) *YarnConfig {
+	return &YarnConfig{
+		YarnCtxTrain:   int(nCtx),
+		YarnExtFactor:  0.0,
+		YarnAttnFactor: 1.0,
+		YarnBetaFast:   32.0,
+		YarnBetaSlow:   1.0,
+	}
+}
+
+// RoPEConfig holds configuration for Rotary Position Embedding
+type RoPEConfig struct {
+	// Dim is the number of dimensions to which rotary embeddings are applied
+	Dim uint32
+
+	// Type specifies the RoPE implementation variant
+	Type RopeType
+
+	// Base is the frequency base used to compute the rotation angles
+	Base float32
+
+	// Scale is the frequency scaling factor, which stretches the effective context length
+	Scale float32
+
+	*YarnConfig
+}
+
 type Tensor interface {
 	Dim(n int) int
 	Stride(n int) int
@@ -141,7 +188,7 @@ type Tensor interface {
 	AvgPool2D(ctx Context, k, s int, p float32) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
 
-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
+	RoPE(ctx Context, positionIDs, ropeFactors Tensor, config RoPEConfig) Tensor
 
 	Tanh(ctx Context) Tensor
 	GELU(ctx Context) Tensor
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index b6f59ae0e..05daa389b 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -907,6 +907,8 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	}
 }
 
+// GGML RoPE types
+// These are the types used in the C implementation of RoPE
 const (
 	ropeTypeNorm   C.int = 0
 	ropeTypeNeox   C.int = 2
@@ -914,7 +916,8 @@ const (
 	ropeTypeVision C.int = 24
 )
 
-func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
+// RoPE applies Rotary Position Embeddings to the tensor
+func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
 	if ropeFactors == nil {
 		ropeFactors = &Tensor{b: t.b}
 	}
@@ -924,19 +927,41 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
 		dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
 	}
 
+	if config.YarnConfig == nil {
+		config.YarnConfig = ml.DefaultYarnConfig(131072) // 131072 is the default training context length for recent Llama models, so it is a common fallback at the time of writing
+	}
+
+	// Map Go RopeType to C implementation constants
+	var ropeTypeC C.int
+	switch config.Type {
+	case ml.RopeTypeNormal:
+		ropeTypeC = ropeTypeNorm
+	case ml.RopeTypeNeox:
+		ropeTypeC = ropeTypeNeox
+	case ml.RopeTypeMRoPE:
+		ropeTypeC = ropeTypeMrope
+	case ml.RopeTypeVision:
+		ropeTypeC = ropeTypeVision
+	default:
+		ropeTypeC = ropeTypeNorm
+	}
+
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_rope_ext(
-			ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
-			C.int(ropeDim),
-			C.int(ropeType),
-			131072, // YaRN n_ctx_train
-			C.float(ropeBase),
-			C.float(ropeScale),
-			0., // YaRN ext_factor
-			1., // YaRN attn_factor
-			32., // YaRN beta_fast
-			1., // YaRN beta_slow
+			ctx.(*Context).ctx,
+			dequant,
+			positionIDs.(*Tensor).t,
+			ropeFactors.(*Tensor).t,
+			C.int(config.Dim),
+			ropeTypeC,
+			C.int(config.YarnCtxTrain),
+			C.float(config.Base),
+			C.float(config.Scale),
+			C.float(config.YarnExtFactor),
+			C.float(config.YarnAttnFactor),
+			C.float(config.YarnBetaFast),
+			C.float(config.YarnBetaSlow),
 		),
 	}
 }
diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go
index 67c69ee86..02875d980 100644
--- a/model/models/gemma2/model.go
+++ b/model/models/gemma2/model.go
@@ -13,10 +13,11 @@ import (
 type Options struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
-	eps, ropeBase, ropeScale         float32
+	eps                              float32
 	attnLogitSoftcap                 float32
 	finalLogitSoftcap                float32
 	largeModelScaling                bool
+	ropeConfig                       ml.RoPEConfig
 }
 
 type Model struct {
@@ -55,10 +56,15 @@ func New(c ml.Config) (model.Model, error) {
 			attnKeyLen:         int(c.Uint("attention.key_length")),
 			attnValLen:         int(c.Uint("attention.value_length")),
 			eps:                c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:           c.Float("rope.freq_base", 10000.0),
-			ropeScale:          c.Float("rope.freq_scale", 1.0),
 			attnLogitSoftcap:   c.Float("attn_logit_softcapping"),
 			finalLogitSoftcap:  c.Float("final_logit_softcapping"),
+			ropeConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.freq_base", 10000.0),
+				Scale:      c.Float("rope.freq_scale", 1.0),
+				Dim:        c.Uint("attention.key_length"),
+				Type:       ml.RopeTypeNeox,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
 		},
 	}
 
@@ -78,11 +84,10 @@ type SelfAttention struct {
 
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	ropeType := uint32(2)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
 
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -92,7 +97,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -122,7 +127,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
+	return key.RoPE(ctx, shift, nil, m.ropeConfig), nil
 }
 
 type MLP struct {
diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index 7d8b6577e..98f386847 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -13,9 +13,11 @@ import (
 type TextOptions struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
-	eps, ropeScale                   float32
-	ropeLocalBase, ropeGlobalBase    float32
+	eps                              float32
 	largeModelScaling                bool
+
+	ropeLocalConfig  ml.RoPEConfig
+	ropeGlobalConfig ml.RoPEConfig
 }
 
 type TextModel struct {
@@ -56,15 +58,27 @@ func newTextModel(c ml.Config) *TextModel {
 		),
 		Layers: make([]TextLayer, numBlocks),
 		TextOptions: &TextOptions{
-			hiddenSize:     int(c.Uint("embedding_length")),
-			numHeads:       int(c.Uint("attention.head_count")),
-			numKVHeads:     int(c.Uint("attention.head_count_kv")),
-			attnKeyLen:     int(c.Uint("attention.key_length", 256)),
-			attnValLen:     int(c.Uint("attention.value_length", 256)),
-			eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06),
-			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
-			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
-			ropeScale:      c.Float("rope.freq_scale", 1.0),
+			hiddenSize: int(c.Uint("embedding_length")),
+			numHeads:   int(c.Uint("attention.head_count")),
+			numKVHeads: int(c.Uint("attention.head_count_kv")),
+			attnKeyLen: int(c.Uint("attention.key_length", 256)),
+			attnValLen: int(c.Uint("attention.value_length", 256)),
+			eps:        c.Float("attention.layer_norm_rms_epsilon", 1e-06),
+
+			ropeLocalConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.local.freq_base", 10000.0),
+				Scale:      c.Float("rope.freq_scale", 1.0),
+				Dim:        c.Uint("attention.key_length", 256),
+				Type:       ml.RopeTypeNeox,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
+			ropeGlobalConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.global.freq_base", 1000000.0),
+				Scale:      c.Float("rope.freq_scale", 1.0),
+				Dim:        c.Uint("attention.key_length", 256),
+				Type:       ml.RopeTypeNeox,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
 		},
 	}
 
@@ -86,17 +100,16 @@ type TextSelfAttention struct {
 
 func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	ropeType := uint32(2)
 
-	ropeBase := opts.ropeLocalBase
+	ropeConfig := opts.ropeLocalConfig
 	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeBase = opts.ropeGlobalBase
+		ropeConfig = opts.ropeGlobalConfig
 	}
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
 	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
-	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, nil, ropeConfig)
 
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -107,7 +120,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
 	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
-	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, nil, ropeConfig)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -120,12 +133,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	ropeBase := m.TextOptions.ropeLocalBase
+	ropeConfig := m.ropeLocalConfig
 	if (layer+1)%gemmaGlobalCacheCount == 0 {
-		ropeBase = m.TextOptions.ropeGlobalBase
+		ropeConfig = m.ropeGlobalConfig
 	}
 
-	return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
+	return key.RoPE(ctx, shift, nil, ropeConfig), nil
 }
 
 type TextMLP struct {
diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index 5c173997b..1a8ed26c4 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -14,8 +14,8 @@ import (
 
 type Options struct {
 	hiddenSize, numHeads, numKVHeads int
-	eps, ropeBase, ropeScale         float32
-	ropeDim                          uint32
+	eps                              float32
+	ropeConfig                       ml.RoPEConfig
 }
 
 type Model struct {
@@ -54,9 +54,13 @@ func New(c ml.Config) (model.Model, error) {
 			numHeads:   int(c.Uint("attention.head_count")),
 			numKVHeads: int(c.Uint("attention.head_count_kv")),
 			eps:        c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:   c.Float("rope.freq_base"),
-			ropeScale:  c.Float("rope.freq_scale", 1),
-			ropeDim:    c.Uint("rope.dimension_count"),
+			ropeConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.freq_base"),
+				Scale:      c.Float("rope.freq_scale", 1),
+				Dim:        c.Uint("rope.dimension_count"),
+				Type:       ml.RopeTypeNormal,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
 		},
 	}
 
@@ -76,15 +80,14 @@ type SelfAttention struct {
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
-	ropeType := uint32(0)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -97,7 +100,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
+	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
 }
 
 type MLP struct {
diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go
index 1cf30d89b..2cc700c0d 100644
--- a/model/models/mllama/model_text.go
+++ b/model/models/mllama/model_text.go
@@ -20,15 +20,14 @@ type TextSelfAttention struct {
 func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
-	ropeType := uint32(0)
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -43,7 +42,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	// This will only get called for layers in the cache, which are just the self attention layers
 	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
-		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
+		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeConfig), nil
 	}
 
 	return key, nil
@@ -198,8 +197,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
 
 type TextModelOptions struct {
 	hiddenSize, numHeads, numKVHeads int
-	eps, ropeBase, ropeScale         float32
-	ropeDim                          uint32
+	eps                              float32
+	ropeConfig                       ml.RoPEConfig
 
 	crossAttentionLayers []uint32
 }
@@ -240,10 +239,14 @@ func newTextModel(c ml.Config) *TextModel {
 			numHeads:             int(c.Uint("attention.head_count")),
 			numKVHeads:           int(c.Uint("attention.head_count_kv")),
 			eps:                  c.Float("attention.layer_norm_rms_epsilon"),
-			ropeBase:             c.Float("rope.freq_base"),
-			ropeScale:            c.Float("rope.freq_scale", 1),
-			ropeDim:              c.Uint("rope.dimension_count"),
 			crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
+			ropeConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.freq_base"),
+				Scale:      c.Float("rope.freq_scale", 1),
+				Dim:        c.Uint("rope.dimension_count"),
+				Type:       ml.RopeTypeNormal,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
 		},
 	}
 }
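The patch above replaces the positional (dim, ropeType, base, scale) arguments with a single ml.RoPEConfig that also carries the YaRN parameters previously hard-coded in the ggml backend. For illustration only, a minimal sketch of the resulting call pattern follows; it assumes the github.com/ollama/ollama/ml import path, and buildRopeConfig/applyRope are hypothetical helper names, not functions added by this change.

// Sketch only: illustrates the RoPEConfig call pattern introduced by this patch.
// Assumes the github.com/ollama/ollama/ml import path; buildRopeConfig and
// applyRope are hypothetical helpers, not part of the patch.
package sketch

import "github.com/ollama/ollama/ml"

// buildRopeConfig mirrors what the model constructors above do: read the RoPE
// keys from the model config once and keep a single ml.RoPEConfig around.
func buildRopeConfig(c ml.Config) ml.RoPEConfig {
	return ml.RoPEConfig{
		Base:       c.Float("rope.freq_base", 10000.0),
		Scale:      c.Float("rope.freq_scale", 1.0),
		Dim:        c.Uint("rope.dimension_count"),
		Type:       ml.RopeTypeNormal, // use ml.RopeTypeNeox for NeoX-style models such as Gemma
		YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
	}
}

// applyRope shows the per-layer call sites: the same config is passed for
// queries and keys (and cache shifts) instead of separate dim/type/base/scale arguments.
func applyRope(ctx ml.Context, q, k, positions ml.Tensor, cfg ml.RoPEConfig) (ml.Tensor, ml.Tensor) {
	return q.RoPE(ctx, positions, nil, cfg), k.RoPE(ctx, positions, nil, cfg)
}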