# Patch summary (descriptive preamble only; not part of any hunk).
#
# model/models/gemma2/model.go
#   (*Model).Shift: the scale argument passed to fast.RoPE changes from
#   m.Options.ropeScale to its reciprocal, 1/m.Options.ropeScale.
#
# model/models/gemma3/model_text.go
#   newTextModel: ropeScale is hard-coded to 1 instead of being read from the
#   "rope.scaling.factor" config key. Per the NOTE added in the hunk, the
#   official QAT weights carry 8 for that key instead of 1. The previous
#   lookup is kept as a commented-out line.
#   (*TextModel).Shift: the scale argument passed to fast.RoPE changes from
#   m.TextConfig.ropeScale to its reciprocal, 1/m.TextConfig.ropeScale.
#
# NOTE(review): presumably the reciprocal matches what the Forward paths pass
# to fast.RoPE; those call sites are not visible in this patch - confirm.
# NOTE(review): this patch was received collapsed onto a single line. Line
# breaks, tab indentation and gofmt column alignment below are reconstructed;
# verify context lines against the target files before applying.
diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go
index 81d41f2ab9..2b16dc62e4 100644
--- a/model/models/gemma2/model.go
+++ b/model/models/gemma2/model.go
@@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil
+	return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil
 }
 
 type MLP struct {
diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index c2a526080d..631baeccdb 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -53,7 +53,10 @@ func newTextModel(c fs.Config) *TextModel {
 			eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06),
 			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
 			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
-			ropeScale:      c.Float("rope.scaling.factor", 1.0),
+			ropeScale:      1,
+			// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
+			//       (8 instead of 1)
+			// ropeScale:      c.Float("rope.scaling.factor", 1.0),
 		},
 	}
 
@@ -113,7 +116,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
 		ropeBase = m.TextConfig.ropeGlobalBase
 	}
 
-	return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
+	return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
 }
 
 type TextMLP struct {