From a8e83a7654fffa169b90fa927e6d19c4c0c765d7 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Mon, 10 Mar 2025 17:17:19 -0700 Subject: [PATCH] Disable causal attention based on batch index Currently we are using positions, which are relative to a sequence and may not be unique. --- kvcache/causal.go | 14 ++++++-------- model/models/gemma3/model_text.go | 8 ++++---- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/kvcache/causal.go b/kvcache/causal.go index 020298005..edf6666da 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -144,6 +144,7 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error { c.curBatchSize = len(opts.Positions) c.curSequences = opts.Sequences c.curPositions = opts.Positions + c.opts.Except = nil var err error c.curLoc, err = c.findStartLoc() @@ -234,7 +235,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { mask := make([]float32, batchSize*length) for i := range c.curBatchSize { - enabled := !slices.Contains(c.opts.Except, c.curPositions[i]) + enabled := !slices.Contains(c.opts.Except, i) for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) || (enabled && c.cells[j].pos > c.curPositions[i]) || @@ -405,15 +406,12 @@ func (c *Causal) SetLayer(layer int) { } type CausalOptions struct { - // Enabled controls whether the causal mask is generated for a particular position. - Except []int32 + // Enabled controls whether the causal mask is generated for a particular index in a batch + Except []int } -// SetCausal enables or disables causal mask generation for subsequent calls to Get. -// This state carries over to future forward passes. The default value is true. -// -// ctx may be set to nil if this is called from outside of a forward pass, for -// example, when initializing the cache. +// SetCausal disables causal mask generation for a particular range of indicies in +// the current batch for subsequent calls to Get. The state resets for the next forward pass. func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) { if !slices.Equal(c.opts.Except, opts.Except) { c.opts = opts diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 2180571eb..5b5e2d6ed 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -173,10 +173,10 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, return hiddenState.Add(ctx, residual) } -func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex, positions []int32) []int32 { +func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int { var embedding ml.Tensor var src, dst, length int - var except []int32 + var except []int for _, image := range multimodal { imageToken := image.Multimodal.(imageToken) @@ -204,7 +204,7 @@ func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []inpu length = 1 } - except = append(except, positions[imageDst]) + except = append(except, imageDst) } if embedding != nil { @@ -219,7 +219,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor hiddenState := m.TokenEmbedding.Forward(ctx, inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize))) - except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal, opts.Positions) + except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal) for i, layer := range m.Layers { // gemma alternates between the sliding window (local) and causal (global)