kvcache: Enable SWA to retain additional entries

Models that use sliding window attention can only resume a sequence
from the cache if it falls within the saved windows. This works well
if the next message picks up where the old one left off. However, it
generally prevents a partial prefix match unless the entire conversation
falls within the sliding window.

This can be a problem with reasoning models where the traces are
supposed to be removed from future messages, forcing the entire
history to be re-evaluated.

This change allows a model to specify that a larger amount of the
history be retained in memory, enabling partial prefix matches in
more cases. Token generation still respects the window size that
the model was trained with.
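
As a rough usage sketch (the 512/4096 sizes and the shift callback are
placeholders, not values from this change):

// A SWA model previously tied cache memory to the mask window:
cache := kvcache.NewSWACache(512, shift)

// Now the mask can stay at the trained 512 tokens while 4096 tokens are
// kept resident, so a prefix match further back can still resume:
cache = kvcache.NewSWAMemCache(512, 4096, shift)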
Jesse Gross
2025-07-30 14:42:57 -07:00
committed by Jesse Gross
parent ff89ba90bc
commit 4183bb0574
2 changed files with 196 additions and 42 deletions

kvcache/causal.go

@@ -19,9 +19,16 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
 // The tensors are of shape embed dim, kv heads, batch size
 // The mask is of shape history size, batch size
 type Causal struct {
-	DType      ml.DType
-	windowSize int32
-	chunkSize  int32
+	DType ml.DType
+
+	// swaWindowSize is the number of tokens that will be included in the mask
+	// during attention operations. swaMemorySize is the number of tokens that
+	// will be retained in memory for partial prefix caching. Set to math.MaxInt32
+	// for unlimited or if sliding window attention is not being used.
+	swaWindowSize int32
+	swaMemorySize int32
+
+	chunkSize int32

 	opts CausalOptions
@@ -88,32 +95,41 @@ type cellRange struct {

 func NewCausalCache(shift shiftFn) *Causal {
 	return &Causal{
-		windowSize: math.MaxInt32,
-		shiftFn:    shift,
-		ctxs:       make(map[int]ml.Context),
-		keys:       make(map[int]ml.Tensor),
-		values:     make(map[int]ml.Tensor),
+		shiftFn: shift,
+		ctxs:    make(map[int]ml.Context),
+		keys:    make(map[int]ml.Tensor),
+		values:  make(map[int]ml.Tensor),
 	}
 }

 func NewSWACache(windowSize int32, shift shiftFn) *Causal {
 	return &Causal{
-		windowSize: windowSize,
-		shiftFn:    shift,
-		ctxs:       make(map[int]ml.Context),
-		keys:       make(map[int]ml.Tensor),
-		values:     make(map[int]ml.Tensor),
+		swaWindowSize: windowSize,
+		shiftFn:       shift,
+		ctxs:          make(map[int]ml.Context),
+		keys:          make(map[int]ml.Tensor),
+		values:        make(map[int]ml.Tensor),
+	}
+}
+
+func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal {
+	return &Causal{
+		swaWindowSize: windowSize,
+		swaMemorySize: memorySize,
+		shiftFn:       shift,
+		ctxs:          make(map[int]ml.Context),
+		keys:          make(map[int]ml.Tensor),
+		values:        make(map[int]ml.Tensor),
 	}
 }

 func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
 	return &Causal{
-		windowSize: math.MaxInt32,
-		chunkSize:  chunkSize,
-		shiftFn:    shift,
-		ctxs:       make(map[int]ml.Context),
-		keys:       make(map[int]ml.Tensor),
-		values:     make(map[int]ml.Tensor),
+		chunkSize: chunkSize,
+		shiftFn:   shift,
+		ctxs:      make(map[int]ml.Context),
+		keys:      make(map[int]ml.Tensor),
+		values:    make(map[int]ml.Tensor),
 	}
 }
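
For reference, the effective settings of each constructor once Init (next
hunk) resolves the zero values:

// NewCausalCache(shift)              -> window = unlimited, memory = unlimited
// NewSWACache(w, shift)              -> window = w, memory = w
// NewSWAMemCache(w, m, shift)        -> window = w, memory = m (Init panics if m < w)
// NewChunkedAttentionCache(c, shift) -> window = unlimited, mask chunked in units of c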
@@ -138,11 +154,25 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
 		c.config.MaskDType = ml.DTypeF32
 	}

+	if c.swaWindowSize == 0 {
+		c.swaWindowSize = math.MaxInt32
+	}
+	if c.swaMemorySize == 0 {
+		c.swaMemorySize = c.swaWindowSize
+	}
+	if int(c.swaMemorySize) > capacity {
+		c.swaMemorySize = math.MaxInt32
+	}
+
+	if c.swaMemorySize < c.swaWindowSize {
+		panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize))
+	}
+
 	var cacheSize int
-	if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) {
+	if c.swaMemorySize == math.MaxInt32 {
 		cacheSize = maxSequences * capacity
 	} else {
-		cacheSize = (maxSequences * int(c.windowSize)) + maxBatch
+		cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch
 	}
 	cacheSize = roundUp(cacheSize, c.config.CachePadding)
 	c.cells = make([]cacheCell, cacheSize)
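
A worked example of the sizing above (numbers are arbitrary):

// maxSequences = 2, swaMemorySize = 512, maxBatch = 64, CachePadding = 256:
//   cacheSize = (2 * 512) + 64 = 1088, rounded up to 1280 cells
// swaMemorySize = math.MaxInt32 with capacity = 4096:
//   cacheSize = 2 * 4096 = 8192 cells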
@@ -187,7 +217,6 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
 		return err
 	}

-	c.curCellRange = newRange()
 	for i, pos := range batch.Positions {
 		seq := batch.Sequences[i]
@@ -198,19 +227,12 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
 				seqRange = newRange()
 			}

-			if c.curLoc+i > seqRange.max {
-				seqRange.max = c.curLoc + i
-			}
-			if seqRange.max > c.curCellRange.max {
-				c.curCellRange.max = seqRange.max
-			}
-
-			if c.curLoc+i < seqRange.min {
-				seqRange.min = c.curLoc + i
-			}
-			if seqRange.min < c.curCellRange.min {
-				c.curCellRange.min = seqRange.min
-			}
+			seqRange.min = min(seqRange.min, c.curLoc+i)
+			c.curCellRange.min = min(c.curCellRange.min, c.curLoc+i)
+
+			seqRange.max = max(seqRange.max, c.curLoc+i)
+			c.curCellRange.max = max(c.curCellRange.max, c.curLoc+i)

 			c.cellRanges[seq] = seqRange
 		}
 	} else {
@@ -252,7 +274,16 @@ func (c *Causal) findStartLoc() (int, error) {
 }

 func (c *Causal) updateSlidingWindow() {
-	if c.windowSize == math.MaxInt32 {
+	c.curCellRange = newRange()
+
+	if c.swaMemorySize == math.MaxInt32 {
+		for _, seq := range c.curSequences {
+			if seqRange, ok := c.cellRanges[seq]; ok {
+				c.curCellRange.min = min(c.curCellRange.min, seqRange.min)
+				c.curCellRange.max = max(c.curCellRange.max, seqRange.max)
+			}
+		}
+
 		return
 	}
@@ -282,12 +313,16 @@
 		for i := oldRange.min; i <= oldRange.max; i++ {
 			if slices.Contains(c.cells[i].sequences, seq) {
-				if c.cells[i].pos < pos-c.windowSize {
+				if c.cells[i].pos < pos-c.swaMemorySize {
 					c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
 				} else {
 					newRange.min = min(newRange.min, i)
 					newRange.max = max(newRange.max, i)
+
+					if c.cells[i].pos >= pos-c.swaWindowSize {
+						c.curCellRange.min = min(c.curCellRange.min, i)
+						c.curCellRange.max = max(c.curCellRange.max, i)
+					}
 				}
 			}
 		}
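
Eviction and masking now use different cutoffs: cells older than
swaMemorySize are dropped outright, while cells older than swaWindowSize
merely fall outside the current attention range. A self-contained toy of
the distinction (not the real cache code; the window/memory values mirror
TestSWAMem below):

package main

import "fmt"

func main() {
	const window, memory int32 = 1, 3
	pos := int32(4) // lowest position in the incoming batch

	for p := int32(0); p <= pos; p++ {
		switch {
		case p < pos-memory:
			fmt.Printf("pos %d: evicted from the cache\n", p)
		case p < pos-window:
			fmt.Printf("pos %d: retained for prefix reuse, masked out of attention\n", p)
		default:
			fmt.Printf("pos %d: retained and visible to attention\n", p)
		}
	}
}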
@@ -327,7 +362,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
 			if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
 				(enabled && c.cells[j].pos > c.curPositions[i]) ||
 				c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
-				c.cells[j].pos < c.curPositions[i]-c.windowSize {
+				c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
 				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
 			}
 		}
@@ -485,6 +520,8 @@ func (c *Causal) defrag() {
 		c.cellRanges[seq] = seqRange
 	}
+
+	c.updateSlidingWindow()
 }

 func (c *Causal) SetLayer(layer int) {
@@ -610,7 +647,7 @@ func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
 }

 func (c *Causal) CanResume(seq int, pos int32) bool {
-	if c.windowSize == math.MaxInt32 {
+	if c.swaMemorySize == math.MaxInt32 {
 		return true
 	}
@@ -632,8 +669,8 @@ func (c *Causal) CanResume(seq int, pos int32) bool {
 		return false
 	}

-	lastWindowStart := max(0, last-c.windowSize)
-	posWindowStart := max(0, pos-c.windowSize)
+	lastWindowStart := max(0, last-c.swaMemorySize)
+	posWindowStart := max(0, pos-c.swaWindowSize)

 	return posWindowStart >= lastWindowStart
 }
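
Tracing CanResume with the values from TestCanResumeSWAMem below
(windowSize = 4, memSize = 5, last cached position = 7):

// lastWindowStart = max(0, 7-5) = 2  (oldest position still in memory)
// pos = 5: posWindowStart = max(0, 5-4) = 1 < 2 -> cannot resume
// pos = 6: posWindowStart = max(0, 6-4) = 2 >= 2 -> can resume
// pos = 7: posWindowStart = max(0, 7-4) = 3 >= 2 -> can resume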

kvcache/causal_test.go

@@ -60,6 +60,8 @@ func TestSWA(t *testing.T) {
 	cache.Init(backend, ml.DTypeF16, 1, 16, 16)

+	x := float32(math.Inf(-1))
+
 	tests := []testCase{
 		{
 			name: "FirstBatch",
@@ -69,7 +71,12 @@ func TestSWA(t *testing.T) {
 			pos:           []int32{0, 1, 2, 3},
 			expected:      []float32{1, 2, 3, 4},
 			expectedShape: []int{1, 1, 4},
-			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
+			expectedMask: []float32{
+				0, x, x, x,
+				0, 0, x, x,
+				x, 0, 0, x,
+				x, x, 0, 0,
+			},
 		},
 		{
 			name: "SecondBatch",
@@ -79,7 +86,53 @@ func TestSWA(t *testing.T) {
 			pos:           []int32{4, 5},
 			expected:      []float32{5, 6, 3, 4},
 			expectedShape: []int{1, 1, 4},
-			expectedMask:  []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
+			expectedMask: []float32{
+				0, x, x, 0,
+				0, 0, x, x,
+			},
 		},
 	}

 	testCache(t, backend, cache, tests)
 }
+
+func TestSWAMem(t *testing.T) {
+	backend := &testBackend{}
+	cache := NewSWAMemCache(1, 3, nil)
+	defer cache.Close()
+
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)
+
+	x := float32(math.Inf(-1))
+
+	tests := []testCase{
+		{
+			name:          "FirstBatch",
+			in:            []float32{1, 2, 3, 4},
+			inShape:       []int{1, 1, 4},
+			seqs:          []int{0, 0, 0, 0},
+			pos:           []int32{0, 1, 2, 3},
+			expected:      []float32{1, 2, 3, 4},
+			expectedShape: []int{1, 1, 4},
+			expectedMask: []float32{
+				0, x, x, x,
+				0, 0, x, x,
+				x, 0, 0, x,
+				x, x, 0, 0,
+			},
+		},
+		{
+			name:          "SecondBatch",
+			in:            []float32{5, 6},
+			inShape:       []int{1, 1, 2},
+			seqs:          []int{0, 0},
+			pos:           []int32{4, 5},
+			expected:      []float32{4, 5, 6},
+			expectedShape: []int{1, 1, 3},
+			expectedMask: []float32{
+				0, 0, x,
+				x, 0, 0,
+			},
+		},
+	}
+
+	testCache(t, backend, cache, tests)
+}
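
The SecondBatch expectations follow from window = 1, memory = 3 with batch
positions {4, 5}:

// eviction cutoff: pos < 4-3 = 1, so position 0 is dropped from memory
// attention range: queries need pos >= 4-1 = 3, so the returned view covers
//   positions {3, 4, 5}, hence expected values {4, 5, 6} and shape {1, 1, 3}
// mask row for the query at pos 4: sees {3, 4} -> 0, 0, x
// mask row for the query at pos 5: sees {4, 5} -> x, 0, 0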
@@ -437,6 +490,70 @@ func TestCanResume(t *testing.T) {
 	}
 }

+func TestCanResumeSWAMem(t *testing.T) {
+	backend := &testBackend{}
+	windowSize := int32(4)
+	memSize := int32(5)
+	cache := NewSWAMemCache(windowSize, memSize, nil)
+	defer cache.Close()
+
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)
+
+	context := backend.NewContext()
+	defer context.Close()
+
+	err := cache.StartForward(context, input.Batch{
+		Positions: []int32{0, 1, 2, 3, 4, 5},
+		Sequences: []int{0, 0, 0, 0, 0, 0},
+	}, false)
+	if err != nil {
+		t.Fatalf("StartForward failed: %v", err)
+	}
+
+	cache.SetLayer(0)
+	tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6}, 1, 1, 6)
+	cache.Put(context, tensor, tensor)
+
+	// shift window by adding position 6
+	err = cache.StartForward(context, input.Batch{
+		Positions: []int32{6, 7},
+		Sequences: []int{0, 0},
+	}, false)
+	if err != nil {
+		t.Fatalf("StartForward failed: %v", err)
+	}
+
+	cache.SetLayer(0)
+	tensor = context.FromFloatSlice([]float32{7, 8}, 1, 1, 2)
+	cache.Put(context, tensor, tensor)
+
+	// only the latest position has overlapping windows
+	if cache.CanResume(0, 0) {
+		t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
+	}
+	if cache.CanResume(0, 1) {
+		t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
+	}
+	if cache.CanResume(0, 2) {
+		t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
+	}
+	if cache.CanResume(0, 3) {
+		t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
+	}
+	if cache.CanResume(0, 4) {
+		t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
+	}
+	if cache.CanResume(0, 5) {
+		t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
+	}
+	if !cache.CanResume(0, 6) {
+		t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
+	}
+	if !cache.CanResume(0, 7) {
+		t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
+	}
+}
+
 type testBackend struct {
 	ml.Backend
 }