From 4346c2409dec7bb182aa91d236f0759d592a1b42 Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Fri, 7 Mar 2025 09:30:10 -0800
Subject: [PATCH] fix drift from main

---
 kvcache/causal_test.go              |  4 ++++
 model/models/gemma2/model.go        | 36 ++++++++++++++++++++---------
 model/models/gemma3/model_text.go   | 25 +++++++++++++-------
 model/models/gemma3/model_vision.go |  2 +-
 model/process_text_spm_test.go      |  4 ++--
 5 files changed, 49 insertions(+), 22 deletions(-)

diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go
index 0c9e000ef..431a79b53 100644
--- a/kvcache/causal_test.go
+++ b/kvcache/causal_test.go
@@ -441,6 +441,10 @@ func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
 	panic("not implemented")
 }
 
+func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
+	panic("not implemented")
+}
+
 func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
 	panic("not implemented")
 }
diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go
index 2ad9c5681..a82d68d37 100644
--- a/model/models/gemma2/model.go
+++ b/model/models/gemma2/model.go
@@ -64,6 +64,7 @@ func New(c ml.Config) (model.Model, error) {
 
 	slidingWindowLen := int32(c.Uint("attention.sliding_window"))
 	m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
+	m.Cache.SetConfig(ml.CacheConfig{})
 
 	return &m, nil
 }
@@ -84,7 +85,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
 
 	if opts.largeModelScaling {
-		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize / opts.numHeads)))
+		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
 	} else {
 		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
 	}
@@ -99,8 +100,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	cache.Put(ctx, k, v)
 	k, v, mask := cache.Get(ctx)
 
-	q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	q = q.Permute(ctx, 0, 2, 1, 3)
+	k = k.Permute(ctx, 0, 2, 1, 3)
 	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
 	kq := k.Mulmat(ctx, q)
@@ -144,12 +145,20 @@ type Layer struct {
 	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
 }
 
-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
 	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token positions
+	// we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
@@ -170,6 +179,11 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
+	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
+	if err != nil {
+		return nil, err
+	}
+
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
 
@@ -182,7 +196,13 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		m.Cache.SetLayer(i)
 		wc := m.Cache.(*kvcache.WrapperCache)
 		wc.SetLayerType(cacheType)
-		hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
+
+		var lastLayerOutputs ml.Tensor
+		if i == len(m.Layers)-1 {
+			lastLayerOutputs = outputs
+		}
+
+		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
 	}
 
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
@@ -192,12 +212,6 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 	hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
 	hiddenState = hiddenState.Tanh(ctx)
 	hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
-
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
-	if err != nil {
-		return nil, err
-	}
-
 	return hiddenState.Rows(ctx, outputs), nil
 }
 
diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index 7ee392766..0896f044b 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -66,9 +66,6 @@ func newTextModel(c ml.Config) *TextModel {
 		},
 	}
 
-	slidingWindowLen := int32(c.Uint("text.attention.sliding_window"))
-	m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
-
 	return &m
 }
 
@@ -145,12 +142,20 @@ type TextLayer struct {
 	PostMLPNorm       *nn.RMSNorm `gguf:"post_ffw_norm"`
 }
 
-func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
+func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
 	residual := hiddenState
 
 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
 	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
+
+	// In the final layer (outputs != nil), optimize by pruning to just the token positions
+	// we need logits for.
+	if outputs != nil {
+		hiddenState = hiddenState.Rows(ctx, outputs)
+		residual = residual.Rows(ctx, outputs)
+	}
+
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
 
@@ -181,7 +186,13 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outpu
 		cache.SetLayer(i)
 		wc := cache.(*kvcache.WrapperCache)
 		wc.SetLayerType(cacheType)
-		hiddenState = layer.Forward(ctx, i, hiddenState, positions, cache, m.TextOptions)
+
+		var lastLayerOutputs ml.Tensor
+		if i == len(m.Layers)-1 {
+			lastLayerOutputs = outputs
+		}
+
+		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
 	}
 
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
@@ -190,7 +201,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outpu
 	// final logit softcap
 	hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextOptions.finalLogitSoftcap))
 	hiddenState = hiddenState.Tanh(ctx)
-	hiddenState = hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
-
-	return hiddenState.Rows(ctx, outputs)
+	return hiddenState.Scale(ctx, float64(m.TextOptions.finalLogitSoftcap))
 }
diff --git a/model/models/gemma3/model_vision.go b/model/models/gemma3/model_vision.go
index 13cca3348..ee6e3b6f0 100644
--- a/model/models/gemma3/model_vision.go
+++ b/model/models/gemma3/model_vision.go
@@ -53,7 +53,7 @@ func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Visio
 }
 
 type VisionEncoderLayer struct {
-	LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
+	LayerNorm1    *nn.LayerNorm `gguf:"layer_norm1"`
 	SelfAttention *VisionSelfAttention
 
 	LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
diff --git a/model/process_text_spm_test.go b/model/process_text_spm_test.go
index 72bd629ce..13e28cc5f 100644
--- a/model/process_text_spm_test.go
+++ b/model/process_text_spm_test.go
@@ -73,7 +73,7 @@ func TestSentencePieceEncode(t *testing.T) {
 	}
 
 	for _, want := range cases {
-		ids, err := tokenizer.Encode(want)
+		ids, err := tokenizer.Encode(want, true)
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -98,7 +98,7 @@ func TestSentencePieceEncode(t *testing.T) {
 	}
 
 	for _, want := range cases {
-		ids, err := tokenizer.Encode(want.token)
+		ids, err := tokenizer.Encode(want.token, true)
 		if err != nil {
 			t.Fatal(err)
 		}