diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index fbefebe2d..29ffa2318 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -211,8 +211,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { // final logit softcap hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap)) hiddenState = hiddenState.Tanh(ctx) - hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)) - return hiddenState.Rows(ctx, outputs), nil + return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil } func init() {