kvcache: Support non-causal attention

Models can disable causality for all or part of their processing
while continuing to store data in the KV cache.
This commit is contained in:
Jesse Gross 2025-03-07 14:20:31 -08:00 committed by Jesse Gross
parent 0daaaef8c9
commit 6da8b6a879

View File

@ -20,6 +20,7 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
type Causal struct {
DType ml.DType
Capacity int32
causal bool
windowSize int32
// config controls mostly backend-specific optimizations
@ -42,6 +43,12 @@ type Causal struct {
// locations in the cache that are needed for this batch
curCellRange cellRange
// curSequences is the sequences corresponding to this pass's entries in the cache
curSequences []int
// curPositions is the positions corresponding to this pass's entries in the cache
curPositions []int32
// ** cache metadata **
// for each possible location in the cache, stores the position and set of sequences
@ -71,6 +78,7 @@ type cellRange struct {
func NewCausalCache(shift shiftFn) *Causal {
return &Causal{
causal: true,
windowSize: math.MaxInt32,
shiftFn: shift,
ctxs: make(map[int]ml.Context),
@ -81,6 +89,7 @@ func NewCausalCache(shift shiftFn) *Causal {
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
return &Causal{
causal: true,
windowSize: windowSize,
shiftFn: shift,
ctxs: make(map[int]ml.Context),
@ -133,6 +142,8 @@ func (c *Causal) Close() {
func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
c.curBatchSize = len(positions)
c.curSequences = seqs
c.curPositions = positions
var err error
c.curLoc, err = c.findStartLoc()
@ -171,7 +182,7 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err
c.cellRanges[seq] = seqRange
}
c.curMask, err = c.buildMask(ctx, positions, seqs)
c.curMask, err = c.buildMask(ctx)
return err
}
@ -212,7 +223,7 @@ func roundUp(length, pad int) int {
// Builds a mask of history x batch indicating whether for each token in the batch the
// token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
// Align and pad the two dimensions as required by the backend
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
@ -224,8 +235,9 @@ func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Te
for i := range c.curBatchSize {
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
c.cells[j].pos < positions[i]-c.windowSize {
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
(c.causal && c.cells[j].pos > c.curPositions[i]) ||
c.cells[j].pos < c.curPositions[i]-c.windowSize {
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
}
}
@ -391,6 +403,26 @@ func (c *Causal) SetLayer(layer int) {
c.curLayer = layer
}
// SetCausal enables or disables causal mask generation for subsequent calls to Get.
// This state carries over to future forward passes. The default value is true.
//
// ctx may be set to nil if this is called from outside of a forward pass, for
// example, when initializing the cache.
func (c *Causal) SetCausal(ctx ml.Context, causal bool) {
if c.causal != causal {
c.causal = causal
if ctx != nil {
var err error
c.curMask, err = c.buildMask(ctx)
if err != nil {
// This error should never occur because we have previously built a mask with the same shape
panic(fmt.Errorf("SetCausal: %w", err))
}
}
}
}
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]