From 14c8594baf883200d6cac3d6b678218236e78015 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Tue, 1 Apr 2025 14:03:48 -0700
Subject: [PATCH] ml: structured rope config to allow specifying context len

This commit refactors the Rotary Position Embedding (RoPE) implementation
across the codebase to use a structured configuration approach instead of
individual parameters.

Key changes:
- Add new RoPEConfig struct with fields for dimension, type, base frequency, and scaling
- Add RopeType enum to formalize different RoPE implementation variants
- Add YarnConfig struct and related configuration for YaRN (Yet Another RoPE extensioN) context extension
- Update RoPE method signature across all tensor interfaces and implementations
- Refactor all model implementations (llama, gemma2, gemma3, mllama) to use the new configuration structure

This change improves code organization, makes the RoPE configuration more
explicit, and provides better support for different RoPE variants and context
extension methods.
---
 kvcache/causal_test.go            |  2 +-
 ml/backend.go                     | 49 ++++++++++++++++++++++++++++-
 ml/backend/ggml/ggml.go           | 47 +++++++++++++++++++++-------
 model/models/gemma2/model.go      | 19 +++++++-----
 model/models/gemma3/model_text.go | 51 +++++++++++++++++++------------
 model/models/llama/model.go       | 21 +++++++------
 model/models/mllama/model_text.go | 21 +++++++------
 7 files changed, 153 insertions(+), 57 deletions(-)

diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go
index b1dc7d779..255de92ab 100644
--- a/kvcache/causal_test.go
+++ b/kvcache/causal_test.go
@@ -462,7 +462,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0
 	panic("not implemented")
 }
 
-func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
+func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
 	panic("not implemented")
 }
 
diff --git a/ml/backend.go b/ml/backend.go
index cfb18d6a9..b71e99326 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -118,6 +118,53 @@ type Context interface {
 	Layer(int) Context
 }
 
+// RopeType represents different RoPE (Rotary Position Embedding) implementation types
+type RopeType int
+
+// Available RoPE implementation types
+const (
+	RopeTypeNormal RopeType = iota // Standard RoPE implementation
+	RopeTypeNeox                   // NeoX-style RoPE implementation
+	RopeTypeMRoPE                  // Multi-scale RoPE implementation
+	RopeTypeVision                 // Vision-specific RoPE implementation
+)
+
+type YarnConfig struct {
+	YarnCtxTrain   int     // Context size used during training (for YaRN scaling)
+	YarnExtFactor  float32 // Extension factor for YaRN
+	YarnAttnFactor float32 // Attention scaling factor for YaRN
+	YarnBetaFast   float32 // Fast decay parameter for YaRN
+	YarnBetaSlow   float32 // Slow decay parameter for YaRN
+}
+
+// DefaultYarnConfig returns a default configuration for YaRN (Yet Another RoPE extensioN)
+func DefaultYarnConfig(nCtx int32) *YarnConfig {
+	return &YarnConfig{
+		YarnCtxTrain:   int(nCtx),
+		YarnExtFactor:  0.0,
+		YarnAttnFactor: 1.0,
+		YarnBetaFast:   32.0,
+		YarnBetaSlow:   1.0,
+	}
+}
+
+// RoPEConfig holds configuration for Rotary Position Embedding
+type RoPEConfig struct {
+	// Dim is the dimensionality for applying rotary embeddings
+	Dim uint32
+
+	// Type specifies the RoPE implementation variant
+	Type RopeType
+
+	// Base controls frequency decay for the embeddings
+	Base float32
+
+	// Scale allows scaling the effective context length
+	Scale float32
+ + *YarnConfig +} + type Tensor interface { Dim(n int) int Stride(n int) int @@ -141,7 +188,7 @@ type Tensor interface { AvgPool2D(ctx Context, k, s int, p float32) Tensor Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor - RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor + RoPE(ctx Context, positionIDs, ropeFactors Tensor, config RoPEConfig) Tensor Tanh(ctx Context) Tensor GELU(ctx Context) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index b6f59ae0e..05daa389b 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -907,6 +907,8 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { } } +// GGML RoPE types +// These are the types used in the C implementation of RoPE const ( ropeTypeNorm C.int = 0 ropeTypeNeox C.int = 2 @@ -914,7 +916,8 @@ const ( ropeTypeVision C.int = 24 ) -func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor { +// RoPE applies Rotary Position Embeddings to the tensor +func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor { if ropeFactors == nil { ropeFactors = &Tensor{b: t.b} } @@ -924,19 +927,41 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32) } + if config.YarnConfig == nil { + config.YarnConfig = ml.DefaultYarnConfig(131072) // 131072 is the default for LLaMA, so it is common at the time of writing + } + + // Map Go RopeType to C implementation constants + var ropeTypeC C.int + switch config.Type { + case ml.RopeTypeNormal: + ropeTypeC = ropeTypeNorm + case ml.RopeTypeNeox: + ropeTypeC = ropeTypeNeox + case ml.RopeTypeMRoPE: + ropeTypeC = ropeTypeMrope + case ml.RopeTypeVision: + ropeTypeC = ropeTypeVision + default: + ropeTypeC = ropeTypeNorm + } + return &Tensor{ b: t.b, t: C.ggml_rope_ext( - ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t, - C.int(ropeDim), - C.int(ropeType), - 131072, // YaRN n_ctx_train - C.float(ropeBase), - C.float(ropeScale), - 0., // YaRN ext_factor - 1., // YaRN attn_factor - 32., // YaRN beta_fast - 1., // YaRN beta_slow + ctx.(*Context).ctx, + dequant, + positionIDs.(*Tensor).t, + ropeFactors.(*Tensor).t, + C.int(config.Dim), + ropeTypeC, + C.int(config.YarnCtxTrain), + C.float(config.Base), + C.float(config.Scale), + C.float(config.YarnExtFactor), + C.float(config.YarnAttnFactor), + C.float(config.YarnBetaFast), + C.float(config.YarnBetaSlow), ), } } diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 67c69ee86..02875d980 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -13,10 +13,11 @@ import ( type Options struct { hiddenSize, numHeads, numKVHeads int attnKeyLen, attnValLen int - eps, ropeBase, ropeScale float32 + eps float32 attnLogitSoftcap float32 finalLogitSoftcap float32 largeModelScaling bool + ropeConfig ml.RoPEConfig } type Model struct { @@ -55,10 +56,15 @@ func New(c ml.Config) (model.Model, error) { attnKeyLen: int(c.Uint("attention.key_length")), attnValLen: int(c.Uint("attention.value_length")), eps: c.Float("attention.layer_norm_rms_epsilon"), - ropeBase: c.Float("rope.freq_base", 10000.0), - ropeScale: c.Float("rope.freq_scale", 1.0), attnLogitSoftcap: c.Float("attn_logit_softcapping"), finalLogitSoftcap: c.Float("final_logit_softcapping"), + ropeConfig: 
ml.RoPEConfig{
+			Base:       c.Float("rope.freq_base", 10000.0),
+			Scale:      c.Float("rope.freq_scale", 1.0),
+			Dim:        c.Uint("attention.key_length"),
+			Type:       ml.RopeTypeNeox,
+			YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+		},
 		},
 	}
 
@@ -78,11 +84,10 @@ type SelfAttention struct {
 
 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
-	ropeType := uint32(2)
 
 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
 
 	if opts.largeModelScaling {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -92,7 +97,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
 
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -122,7 +127,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
+	return key.RoPE(ctx, shift, nil, m.ropeConfig), nil
 }
 
 type MLP struct {
diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index 7d8b6577e..98f386847 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -13,9 +13,11 @@ import (
 type TextOptions struct {
 	hiddenSize, numHeads, numKVHeads int
 	attnKeyLen, attnValLen           int
-	eps, ropeScale                   float32
-	ropeLocalBase, ropeGlobalBase    float32
+	eps                              float32
 	largeModelScaling                bool
+
+	ropeLocalConfig  ml.RoPEConfig
+	ropeGlobalConfig ml.RoPEConfig
 }
 
 type TextModel struct {
@@ -56,15 +58,27 @@ func newTextModel(c ml.Config) *TextModel {
 		),
 		Layers: make([]TextLayer, numBlocks),
 		TextOptions: &TextOptions{
-			hiddenSize:     int(c.Uint("embedding_length")),
-			numHeads:       int(c.Uint("attention.head_count")),
-			numKVHeads:     int(c.Uint("attention.head_count_kv")),
-			attnKeyLen:     int(c.Uint("attention.key_length", 256)),
-			attnValLen:     int(c.Uint("attention.value_length", 256)),
-			eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06),
-			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
-			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
-			ropeScale:      c.Float("rope.freq_scale", 1.0),
+			hiddenSize: int(c.Uint("embedding_length")),
+			numHeads:   int(c.Uint("attention.head_count")),
+			numKVHeads: int(c.Uint("attention.head_count_kv")),
+			attnKeyLen: int(c.Uint("attention.key_length", 256)),
+			attnValLen: int(c.Uint("attention.value_length", 256)),
+			eps:        c.Float("attention.layer_norm_rms_epsilon", 1e-06),
+
+			ropeLocalConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.local.freq_base", 10000.0),
+				Scale:      c.Float("rope.freq_scale", 1.0),
+				Dim:        c.Uint("attention.key_length", 256),
+				Type:       ml.RopeTypeNeox,
+				YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
+			},
+			ropeGlobalConfig: ml.RoPEConfig{
+				Base:       c.Float("rope.global.freq_base", 1000000.0),
+				Scale: 
c.Float("rope.freq_scale", 1.0), + Dim: c.Uint("attention.key_length", 256), + Type: ml.RopeTypeNeox, + YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))), + }, }, } @@ -86,17 +100,16 @@ type TextSelfAttention struct { func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { batchSize := hiddenState.Dim(1) - ropeType := uint32(2) - ropeBase := opts.ropeLocalBase + ropeConfig := opts.ropeLocalConfig if (layer+1)%gemmaGlobalCacheCount == 0 { - ropeBase = opts.ropeGlobalBase + ropeConfig = opts.ropeGlobalConfig } q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) q = sa.QueryNorm.Forward(ctx, q, opts.eps) - q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale) + q = q.RoPE(ctx, positionIDs, nil, ropeConfig) if opts.largeModelScaling { q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) @@ -107,7 +120,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) k = sa.KeyNorm.Forward(ctx, k, opts.eps) - k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale) + k = k.RoPE(ctx, positionIDs, nil, ropeConfig) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) @@ -120,12 +133,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - ropeBase := m.TextOptions.ropeLocalBase + ropeConfig := m.ropeLocalConfig if (layer+1)%gemmaGlobalCacheCount == 0 { - ropeBase = m.TextOptions.ropeGlobalBase + ropeConfig = m.ropeGlobalConfig } - return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil + return key.RoPE(ctx, shift, nil, ropeConfig), nil } type TextMLP struct { diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 5c173997b..1a8ed26c4 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -14,8 +14,8 @@ import ( type Options struct { hiddenSize, numHeads, numKVHeads int - eps, ropeBase, ropeScale float32 - ropeDim uint32 + eps float32 + ropeConfig ml.RoPEConfig } type Model struct { @@ -54,9 +54,13 @@ func New(c ml.Config) (model.Model, error) { numHeads: int(c.Uint("attention.head_count")), numKVHeads: int(c.Uint("attention.head_count_kv")), eps: c.Float("attention.layer_norm_rms_epsilon"), - ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), - ropeDim: c.Uint("rope.dimension_count"), + ropeConfig: ml.RoPEConfig{ + Base: c.Float("rope.freq_base"), + Scale: c.Float("rope.freq_scale", 1), + Dim: c.Uint("rope.dimension_count"), + Type: ml.RopeTypeNormal, + YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))), + }, }, } @@ -76,15 +80,14 @@ type SelfAttention struct { func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := opts.hiddenSize / opts.numHeads - ropeType := uint32(0) q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) - q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, 
opts.ropeScale) + q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -97,7 +100,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil + return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil } type MLP struct { diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 1cf30d89b..2cc700c0d 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -20,15 +20,14 @@ type TextSelfAttention struct { func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := opts.hiddenSize / opts.numHeads - ropeType := uint32(0) query := sa.Query.Forward(ctx, hiddenState) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) - query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig) key := sa.Key.Forward(ctx, hiddenState) key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig) value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) @@ -43,7 +42,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { // This will only get called for layers in the cache, which are just the self attention layers if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok { - return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil + return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeConfig), nil } return key, nil @@ -198,8 +197,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, type TextModelOptions struct { hiddenSize, numHeads, numKVHeads int - eps, ropeBase, ropeScale float32 - ropeDim uint32 + eps float32 + ropeConfig ml.RoPEConfig crossAttentionLayers []uint32 } @@ -240,10 +239,14 @@ func newTextModel(c ml.Config) *TextModel { numHeads: int(c.Uint("attention.head_count")), numKVHeads: int(c.Uint("attention.head_count_kv")), eps: c.Float("attention.layer_norm_rms_epsilon"), - ropeBase: c.Float("rope.freq_base"), - ropeScale: c.Float("rope.freq_scale", 1), - ropeDim: c.Uint("rope.dimension_count"), crossAttentionLayers: c.Uints("attention.cross_attention_layers"), + ropeConfig: ml.RoPEConfig{ + Base: c.Float("rope.freq_base"), + Scale: c.Float("rope.freq_scale", 1), + Dim: c.Uint("rope.dimension_count"), + Type: ml.RopeTypeNormal, + YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 
131072))), + }, }, } }
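
The snippet below is an illustrative sketch (not part of the patch) of how a model can build the new ml.RoPEConfig and call the updated Tensor.RoPE signature; the import path follows the ollama module layout and the helper name is hypothetical.

package example

import "github.com/ollama/ollama/ml"

// applyRoPE rotates query and key tensors using one structured config instead
// of the previous loose (dim, ropeType, base, scale) arguments.
func applyRoPE(ctx ml.Context, q, k, positions ml.Tensor, headDim uint32) (ml.Tensor, ml.Tensor) {
	cfg := ml.RoPEConfig{
		Dim:   headDim,           // rotary dimension, typically the attention head size
		Type:  ml.RopeTypeNormal, // NeoX-style models would use ml.RopeTypeNeox
		Base:  10000.0,           // "rope.freq_base"
		Scale: 1.0,               // "rope.freq_scale"
		// Leaving YarnConfig nil lets the ggml backend fall back to
		// ml.DefaultYarnConfig(131072), per the backend change above.
	}

	q = q.RoPE(ctx, positions, nil, cfg) // nil rope factors, as in the gemma models
	k = k.RoPE(ctx, positions, nil, cfg)
	return q, k
}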
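
For reference, here is a dependency-free mirror (a sketch, not ollama code) of the RopeType-to-GGML mapping added in ml/backend/ggml/ggml.go. The values 0, 2, and 24 appear in the diff; the MROPE value of 8 is not visible in the hunk and is assumed from upstream GGML. Unknown values fall back to the standard type, matching the backend's default branch.

package main

import "fmt"

type ropeType int

const (
	ropeTypeNormal ropeType = iota
	ropeTypeNeox
	ropeTypeMRoPE
	ropeTypeVision
)

// ggmlRopeType mirrors the switch in (*Tensor).RoPE.
func ggmlRopeType(t ropeType) int {
	switch t {
	case ropeTypeNeox:
		return 2 // ropeTypeNeox in ggml.go
	case ropeTypeMRoPE:
		return 8 // ropeTypeMrope (assumed value, per upstream GGML)
	case ropeTypeVision:
		return 24 // ropeTypeVision in ggml.go
	default:
		return 0 // ropeTypeNorm in ggml.go
	}
}

func main() {
	for _, t := range []ropeType{ropeTypeNormal, ropeTypeNeox, ropeTypeMRoPE, ropeTypeVision} {
		fmt.Printf("RopeType %d -> ggml rope type %d\n", t, ggmlRopeType(t))
	}
}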
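
The refactored models all follow the same loading pattern: rope settings are read once from model metadata into the config struct, with the trained context length threaded through YarnConfig rather than hard-coded in the backend. A minimal sketch of that pattern, lifted from the llama changes above (the standalone helper itself is hypothetical):

package example

import "github.com/ollama/ollama/ml"

// ropeConfigFromMetadata builds a RoPEConfig the way llama.New does in this
// patch: GGUF keys for base, scale, and rotary dimension, plus the model's
// trained context length for YaRN.
func ropeConfigFromMetadata(c ml.Config) ml.RoPEConfig {
	return ml.RoPEConfig{
		Base:       c.Float("rope.freq_base"),
		Scale:      c.Float("rope.freq_scale", 1),
		Dim:        c.Uint("rope.dimension_count"),
		Type:       ml.RopeTypeNormal,
		YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
	}
}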
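
Gemma 3 is the one model that keeps two configs: a local one for most layers and a global one for every Nth layer, selected in both Forward and Shift. A small sketch of that selection; the option fields match the patch, while gemmaGlobalCacheCount and the helper are illustrative (the constant's value is not shown in this diff and is assumed here).

package example

import "github.com/ollama/ollama/ml"

// gemmaGlobalCacheCount is assumed for illustration; the real constant lives
// elsewhere in the gemma3 package.
const gemmaGlobalCacheCount = 6

type textOptions struct {
	ropeLocalConfig  ml.RoPEConfig
	ropeGlobalConfig ml.RoPEConfig
}

// ropeConfigForLayer mirrors the (layer+1)%gemmaGlobalCacheCount == 0 check
// used in TextSelfAttention.Forward and TextModel.Shift.
func ropeConfigForLayer(layer int, opts *textOptions) ml.RoPEConfig {
	if (layer+1)%gemmaGlobalCacheCount == 0 {
		return opts.ropeGlobalConfig
	}
	return opts.ropeLocalConfig
}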
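
llama and mllama differ from the gemma models only in that they pass a per-layer rope-factors tensor (sa.RopeFactors) along with the config; the ggml backend substitutes an empty tensor when that argument is nil. A sketch of the call shape after this patch (the wrapper function is hypothetical):

package example

import "github.com/ollama/ollama/ml"

// rotateWithFactors applies RoPE to query and key tensors, forwarding an
// optional frequency-factors tensor such as llama's sa.RopeFactors.
func rotateWithFactors(ctx ml.Context, q, k, positions, ropeFactors ml.Tensor, cfg ml.RoPEConfig) (ml.Tensor, ml.Tensor) {
	q = q.RoPE(ctx, positions, ropeFactors, cfg) // ropeFactors may be nil
	k = k.RoPE(ctx, positions, ropeFactors, cfg)
	return q, k
}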
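
Finally, the point of the commit subject — specifying the context length — is that a caller is no longer tied to the backend's 131072 default. A minimal sketch, assuming YaRN parameter semantics as in the GGML rope implementation (a non-zero ext_factor enables YaRN interpolation beyond the trained window); the 8192 figure and the helper are purely illustrative.

package example

import "github.com/ollama/ollama/ml"

// extendedRoPEConfig builds a config for a hypothetical model trained on 8K
// tokens, enabling YaRN so it can be run with a longer context window.
func extendedRoPEConfig(dim uint32, base float32) ml.RoPEConfig {
	yarn := ml.DefaultYarnConfig(8192) // trained context length of the hypothetical model
	yarn.YarnExtFactor = 1.0           // assumed: non-zero enables YaRN scaling in the backend
	return ml.RoPEConfig{
		Dim:        dim,
		Type:       ml.RopeTypeNeox,
		Base:       base,
		Scale:      1.0,
		YarnConfig: yarn,
	}
}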