From a70820daa0a25024cfd857a528e717e3ac00a8e0 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Wed, 12 Mar 2025 10:17:57 -0700
Subject: [PATCH] models/gemma3: remove final logit softcap (#9692)

Softcap isn't in the whitepaper/implementation for the language model
so we should remove it. There is no discernible difference in output
with it removed.
---
 model/models/gemma3/model_text.go | 27 ++++++++++-----------------
 1 file changed, 10 insertions(+), 17 deletions(-)

diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index 5b5e2d6ed..7a88c0921 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -15,7 +15,6 @@ type TextOptions struct {
 	attnKeyLen, attnValLen        int
 	eps, ropeScale                float32
 	ropeLocalBase, ropeGlobalBase float32
-	finalLogitSoftcap             float32
 
 	largeModelScaling bool
 }
@@ -57,16 +56,15 @@ func newTextModel(c ml.Config) *TextModel {
 		),
 		Layers: make([]TextLayer, numBlocks),
 		TextOptions: &TextOptions{
-			hiddenSize:         int(c.Uint("embedding_length")),
-			numHeads:           int(c.Uint("attention.head_count")),
-			numKVHeads:         int(c.Uint("attention.head_count_kv")),
-			attnKeyLen:         int(c.Uint("attention.key_length", 256)),
-			attnValLen:         int(c.Uint("attention.value_length", 256)),
-			eps:                c.Float("attention.layer_norm_rms_epsilon", 1e-06),
-			ropeLocalBase:      c.Float("rope.local.freq_base", 10000.0),
-			ropeGlobalBase:     c.Float("rope.global.freq_base", 1000000.0),
-			ropeScale:          c.Float("rope.freq_scale", 1.0),
-			finalLogitSoftcap:  c.Float("final_logit_softcapping", 30.0),
+			hiddenSize:     int(c.Uint("embedding_length")),
+			numHeads:       int(c.Uint("attention.head_count")),
+			numKVHeads:     int(c.Uint("attention.head_count_kv")),
+			attnKeyLen:     int(c.Uint("attention.key_length", 256)),
+			attnValLen:     int(c.Uint("attention.value_length", 256)),
+			eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-06),
+			ropeLocalBase:  c.Float("rope.local.freq_base", 10000.0),
+			ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
+			ropeScale:      c.Float("rope.freq_scale", 1.0),
 		},
 	}
 
@@ -245,10 +243,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 	}
 
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
-	hiddenState = m.Output.Forward(ctx, hiddenState)
-
-	// final logit softcap
-	hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap))
-	hiddenState = hiddenState.Tanh(ctx)
-	return hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
+	return m.Output.Forward(ctx, hiddenState)
 }
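
For reference, the removed "final logit softcap" is the usual logit
softcapping transform, logits = softcap * tanh(logits / softcap), with the
default softcap of 30.0 that was read from "final_logit_softcapping"; the
deleted Scale/Tanh/Scale sequence computes exactly this. Below is a minimal
standalone Go sketch of the same math. The function and variable names and
the sample logit values are illustrative only and are not part of this
patch or the model code:

	package main

	import (
		"fmt"
		"math"
	)

	// softcap applies capVal * tanh(v / capVal) elementwise, the same
	// transform the removed Scale/Tanh/Scale sequence applied to the
	// final logits.
	func softcap(logits []float64, capVal float64) []float64 {
		out := make([]float64, len(logits))
		for i, v := range logits {
			out[i] = capVal * math.Tanh(v/capVal)
		}
		return out
	}

	func main() {
		// Large-magnitude logits are squashed toward +/-30; logits that
		// are small relative to the cap pass through nearly unchanged.
		fmt.Println(softcap([]float64{-100, -5, 0, 5, 100}, 30.0))
	}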