From 6da8b6a87902c8cf875028329ade512c8d859059 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Fri, 7 Mar 2025 14:20:31 -0800 Subject: [PATCH] kvcache: Support non-causal attention Models can disable causality for all or part of their processing while continuing to store data in the KV cache. --- kvcache/causal.go | 40 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/kvcache/causal.go b/kvcache/causal.go index 3d1c71db1..d519cf602 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -20,6 +20,7 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e type Causal struct { DType ml.DType Capacity int32 + causal bool windowSize int32 // config controls mostly backend-specific optimizations @@ -42,6 +43,12 @@ type Causal struct { // locations in the cache that are needed for this batch curCellRange cellRange + // curSequences is the sequences corresponding to this pass's entries in the cache + curSequences []int + + // curPositions is the positions corresponding to this pass's entries in the cache + curPositions []int32 + // ** cache metadata ** // for each possible location in the cache, stores the position and set of sequences @@ -71,6 +78,7 @@ type cellRange struct { func NewCausalCache(shift shiftFn) *Causal { return &Causal{ + causal: true, windowSize: math.MaxInt32, shiftFn: shift, ctxs: make(map[int]ml.Context), @@ -81,6 +89,7 @@ func NewCausalCache(shift shiftFn) *Causal { func NewSWACache(windowSize int32, shift shiftFn) *Causal { return &Causal{ + causal: true, windowSize: windowSize, shiftFn: shift, ctxs: make(map[int]ml.Context), @@ -133,6 +142,8 @@ func (c *Causal) Close() { func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error { c.curBatchSize = len(positions) + c.curSequences = seqs + c.curPositions = positions var err error c.curLoc, err = c.findStartLoc() @@ -171,7 +182,7 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err c.cellRanges[seq] = seqRange } - c.curMask, err = c.buildMask(ctx, positions, seqs) + c.curMask, err = c.buildMask(ctx) return err } @@ -212,7 +223,7 @@ func roundUp(length, pad int) int { // Builds a mask of history x batch indicating whether for each token in the batch the // token in the history should apply. This is based on both the sequence and causality (the // position of the history is not ahead of the token in the batch). -func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) { +func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { // Align and pad the two dimensions as required by the backend batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding) @@ -224,8 +235,9 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te for i := range c.curBatchSize { for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { - if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] || - c.cells[j].pos < positions[i]-c.windowSize { + if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) || + (c.causal && c.cells[j].pos > c.curPositions[i]) || + c.cells[j].pos < c.curPositions[i]-c.windowSize { mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1)) } } @@ -391,6 +403,26 @@ func (c *Causal) SetLayer(layer int) { c.curLayer = layer } +// SetCausal enables or disables causal mask generation for subsequent calls to Get. +// This state carries over to future forward passes. The default value is true. +// +// ctx may be set to nil if this is called from outside of a forward pass, for +// example, when initializing the cache. +func (c *Causal) SetCausal(ctx ml.Context, causal bool) { + if c.causal != causal { + c.causal = causal + + if ctx != nil { + var err error + c.curMask, err = c.buildMask(ctx) + if err != nil { + // This error should never occur because we have previously built a mask with the same shape + panic(fmt.Errorf("SetCausal: %w", err)) + } + } + } +} + func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { key := c.keys[c.curLayer] value := c.values[c.curLayer]