From 4183bb0574a28b73276efef944107d0c45d79c95 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Wed, 30 Jul 2025 14:42:57 -0700 Subject: [PATCH] kvcache: Enable SWA to retain additional entries Models that use sliding window attention can only resume a sequence from the cache if it falls within the saved windows. This works well if the next message picks up where the old one left off. However, it generally prevents a partial prefix match unless the entire conversation falls within the sliding window. This can be a problem with reasoning models where the traces are supposed to be removed from future messages, forcing the entire history to be re-evaluated. This change allows models to specify that a larger amount of the history be retained in memory, to allow more partial resumption. It still respects the window that the model was trained on for token generation. --- kvcache/causal.go | 117 +++++++++++++++++++++++++-------------- kvcache/causal_test.go | 121 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 196 insertions(+), 42 deletions(-) diff --git a/kvcache/causal.go b/kvcache/causal.go index 496eeaa64b..56c9360031 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -19,9 +19,16 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e // The tensors are of shape embed dim, kv heads, batch size // The mask is of shape history size, batch size type Causal struct { - DType ml.DType - windowSize int32 - chunkSize int32 + DType ml.DType + + // swaWindowSize is the number of tokens that will be included in the mask + // during attention operations. swaMemorySize is the number of tokens that + // will be retained in memory for partial prefix caching. Set to math.MaxInt32 + // for unlimited or if sliding window attention is not being used. 
+ swaWindowSize int32 + swaMemorySize int32 + + chunkSize int32 opts CausalOptions @@ -88,32 +95,41 @@ type cellRange struct { func NewCausalCache(shift shiftFn) *Causal { return &Causal{ - windowSize: math.MaxInt32, - shiftFn: shift, - ctxs: make(map[int]ml.Context), - keys: make(map[int]ml.Tensor), - values: make(map[int]ml.Tensor), + shiftFn: shift, + ctxs: make(map[int]ml.Context), + keys: make(map[int]ml.Tensor), + values: make(map[int]ml.Tensor), } } func NewSWACache(windowSize int32, shift shiftFn) *Causal { return &Causal{ - windowSize: windowSize, - shiftFn: shift, - ctxs: make(map[int]ml.Context), - keys: make(map[int]ml.Tensor), - values: make(map[int]ml.Tensor), + swaWindowSize: windowSize, + shiftFn: shift, + ctxs: make(map[int]ml.Context), + keys: make(map[int]ml.Tensor), + values: make(map[int]ml.Tensor), + } +} + +func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal { + return &Causal{ + swaWindowSize: windowSize, + swaMemorySize: memorySize, + shiftFn: shift, + ctxs: make(map[int]ml.Context), + keys: make(map[int]ml.Tensor), + values: make(map[int]ml.Tensor), } } func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal { return &Causal{ - windowSize: math.MaxInt32, - chunkSize: chunkSize, - shiftFn: shift, - ctxs: make(map[int]ml.Context), - keys: make(map[int]ml.Tensor), - values: make(map[int]ml.Tensor), + chunkSize: chunkSize, + shiftFn: shift, + ctxs: make(map[int]ml.Context), + keys: make(map[int]ml.Tensor), + values: make(map[int]ml.Tensor), } } @@ -138,11 +154,25 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity c.config.MaskDType = ml.DTypeF32 } + if c.swaWindowSize == 0 { + c.swaWindowSize = math.MaxInt32 + } + if c.swaMemorySize == 0 { + c.swaMemorySize = c.swaWindowSize + } + if int(c.swaMemorySize) > capacity { + c.swaMemorySize = math.MaxInt32 + } + + if c.swaMemorySize < c.swaWindowSize { + panic(fmt.Errorf("sliding window memory (%v) must be at least as large as 
the window (%v)", c.swaMemorySize, c.swaWindowSize)) + } + var cacheSize int - if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) { + if c.swaMemorySize == math.MaxInt32 { cacheSize = maxSequences * capacity } else { - cacheSize = (maxSequences * int(c.windowSize)) + maxBatch + cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch } cacheSize = roundUp(cacheSize, c.config.CachePadding) c.cells = make([]cacheCell, cacheSize) @@ -187,7 +217,6 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e return err } - c.curCellRange = newRange() for i, pos := range batch.Positions { seq := batch.Sequences[i] @@ -198,19 +227,12 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e seqRange = newRange() } - if c.curLoc+i > seqRange.max { - seqRange.max = c.curLoc + i - } - if seqRange.max > c.curCellRange.max { - c.curCellRange.max = seqRange.max - } + seqRange.min = min(seqRange.min, c.curLoc+i) + c.curCellRange.min = min(c.curCellRange.min, c.curLoc+i) + + seqRange.max = max(seqRange.max, c.curLoc+i) + c.curCellRange.max = max(c.curCellRange.max, c.curLoc+i) - if c.curLoc+i < seqRange.min { - seqRange.min = c.curLoc + i - } - if seqRange.min < c.curCellRange.min { - c.curCellRange.min = seqRange.min - } c.cellRanges[seq] = seqRange } } else { @@ -252,7 +274,16 @@ func (c *Causal) findStartLoc() (int, error) { } func (c *Causal) updateSlidingWindow() { - if c.windowSize == math.MaxInt32 { + c.curCellRange = newRange() + + if c.swaMemorySize == math.MaxInt32 { + for _, seq := range c.curSequences { + if seqRange, ok := c.cellRanges[seq]; ok { + c.curCellRange.min = min(c.curCellRange.min, seqRange.min) + c.curCellRange.max = max(c.curCellRange.max, seqRange.max) + } + } + return } @@ -282,12 +313,16 @@ func (c *Causal) updateSlidingWindow() { for i := oldRange.min; i <= oldRange.max; i++ { if slices.Contains(c.cells[i].sequences, seq) { - if c.cells[i].pos < pos-c.windowSize { + if c.cells[i].pos 
< pos-c.swaMemorySize { c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq }) } else { newRange.min = min(newRange.min, i) newRange.max = max(newRange.max, i) } + if c.cells[i].pos >= pos-c.swaWindowSize { + c.curCellRange.min = min(c.curCellRange.min, i) + c.curCellRange.max = max(c.curCellRange.max, i) + } } } @@ -327,7 +362,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor { if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) || (enabled && c.cells[j].pos > c.curPositions[i]) || c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize || - c.cells[j].pos < c.curPositions[i]-c.windowSize { + c.cells[j].pos < c.curPositions[i]-c.swaWindowSize { mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1)) } } @@ -485,6 +520,8 @@ func (c *Causal) defrag() { c.cellRanges[seq] = seqRange } + + c.updateSlidingWindow() } func (c *Causal) SetLayer(layer int) { @@ -610,7 +647,7 @@ func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) { } func (c *Causal) CanResume(seq int, pos int32) bool { - if c.windowSize == math.MaxInt32 { + if c.swaMemorySize == math.MaxInt32 { return true } @@ -632,8 +669,8 @@ func (c *Causal) CanResume(seq int, pos int32) bool { return false } - lastWindowStart := max(0, last-c.windowSize) - posWindowStart := max(0, pos-c.windowSize) + lastWindowStart := max(0, last-c.swaMemorySize) + posWindowStart := max(0, pos-c.swaWindowSize) return posWindowStart >= lastWindowStart } diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 5b1dbe868f..0d8cea79f7 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -60,6 +60,8 @@ func TestSWA(t *testing.T) { cache.Init(backend, ml.DTypeF16, 1, 16, 16) + x := float32(math.Inf(-1)) + tests := []testCase{ { name: "FirstBatch", @@ -69,7 +71,12 @@ func TestSWA(t *testing.T) { pos: []int32{0, 1, 2, 3}, expected: []float32{1, 2, 3, 4}, expectedShape: []int{1, 1, 4}, - expectedMask: []float32{0, 
float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, + expectedMask: []float32{ + 0, x, x, x, + 0, 0, x, x, + x, 0, 0, x, + x, x, 0, 0, + }, }, { name: "SecondBatch", @@ -79,7 +86,53 @@ func TestSWA(t *testing.T) { pos: []int32{4, 5}, expected: []float32{5, 6, 3, 4}, expectedShape: []int{1, 1, 4}, - expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))}, + expectedMask: []float32{ + 0, x, x, 0, + 0, 0, x, x, + }, + }, + } + + testCache(t, backend, cache, tests) +} + +func TestSWAMem(t *testing.T) { + backend := &testBackend{} + cache := NewSWAMemCache(1, 3, nil) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 1, 16, 16) + + x := float32(math.Inf(-1)) + + tests := []testCase{ + { + name: "FirstBatch", + in: []float32{1, 2, 3, 4}, + inShape: []int{1, 1, 4}, + seqs: []int{0, 0, 0, 0}, + pos: []int32{0, 1, 2, 3}, + expected: []float32{1, 2, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{ + 0, x, x, x, + 0, 0, x, x, + x, 0, 0, x, + x, x, 0, 0, + }, + }, + { + name: "SecondBatch", + in: []float32{5, 6}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 0}, + pos: []int32{4, 5}, + expected: []float32{4, 5, 6}, + expectedShape: []int{1, 1, 3}, + expectedMask: []float32{ + 0, 0, x, + x, 0, 0, + }, }, } @@ -437,6 +490,70 @@ func TestCanResume(t *testing.T) { } } +func TestCanResumeSWAMem(t *testing.T) { + backend := &testBackend{} + windowSize := int32(4) + memSize := int32(5) + cache := NewSWAMemCache(windowSize, memSize, nil) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 1, 16, 16) + + context := backend.NewContext() + defer context.Close() + + err := cache.StartForward(context, input.Batch{ + Positions: []int32{0, 1, 2, 3, 4, 5}, + Sequences: []int{0, 0, 0, 0, 0, 0}, + }, false) + if err != nil 
{ + t.Fatalf("StartForward failed: %v", err) + } + + cache.SetLayer(0) + tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6}, 1, 1, 6) + cache.Put(context, tensor, tensor) + + // shift window by adding positions 6 and 7 + err = cache.StartForward(context, input.Batch{ + Positions: []int32{6, 7}, + Sequences: []int{0, 0}, + }, false) + if err != nil { + t.Fatalf("StartForward failed: %v", err) + } + + cache.SetLayer(0) + tensor = context.FromFloatSlice([]float32{7, 8}, 1, 1, 2) + cache.Put(context, tensor, tensor) + + // positions 6 and 7 can resume: retaining 5 entries (one more than the 4-token window) keeps position 6's window in memory + if cache.CanResume(0, 0) { + t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)") + } + if cache.CanResume(0, 1) { + t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)") + } + if cache.CanResume(0, 2) { + t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)") + } + if cache.CanResume(0, 3) { + t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)") + } + if cache.CanResume(0, 4) { + t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)") + } + if cache.CanResume(0, 5) { + t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)") + } + if !cache.CanResume(0, 6) { + t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)") + } + if !cache.CanResume(0, 7) { + t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)") + } +} + type testBackend struct { ml.Backend }