kvcache: Clean up sliding window state with independent batches

Sliding windows models (e.g. gpt-oss, gemma3) remove tokens that
are out of the cache's window each time we start a new forward pass.

The cache storage needs to handle the window size for each sequence
plus the batch size, since the batch needs to attend to the full
window size. This means that we have greater than a window size
stored while processing the batch.

When the next batch comes, we are currently only looking at the
sequences in the incoming batch to slide the window forward.
However, we also need to clean up the other sequences that might
be occupying space in the batch processing buffer to ensure each
sequence is only using its window size of storage. Failure to do
this can result in "no kv cache slot found" errors.

Fixes: #10127
This commit is contained in:
Jesse Gross
2025-10-06 16:04:53 -07:00
committed by Jesse Gross
parent aa45f7ce27
commit 1fc35f1260
2 changed files with 139 additions and 28 deletions

View File

@@ -160,7 +160,15 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
if c.swaMemorySize == 0 {
c.swaMemorySize = c.swaWindowSize
}
if int(c.swaMemorySize) > capacity {
// We will allocate space in the cache for the stop token, which won't be part of a follow on
// sequence, so allocate an extra token of storage to ensure that we can jump back without
// causing a cache break. As an optimization, only do this when we have parallel sequences
// because the extra token will live in the batch buffer and won't get overwritten if we
// only have a single sequence.
if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
}
if int(c.swaMemorySize) >= capacity {
c.swaMemorySize = math.MaxInt32
}
@@ -214,7 +222,6 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
c.curLoc, err = c.findStartLoc()
}
if err != nil {
slog.Warn("unable to find a kv cache slot", "cache", c)
return err
}
@@ -288,23 +295,44 @@ func (c *Causal) updateSlidingWindow() {
return
}
type lowestPosition struct {
pos int32
curBatch bool
}
// create a map of unique sequences to the lowest position in that sequence
lowestPos := make(map[int]int32)
lowestPos := make(map[int]lowestPosition)
for i := range c.curPositions {
seq := c.curSequences[i]
pos, ok := lowestPos[seq]
lowest, ok := lowestPos[seq]
if !ok {
pos = c.curPositions[i]
} else if c.curPositions[i] < pos {
pos = c.curPositions[i]
lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
} else if c.curPositions[i] < lowest.pos {
lowest.pos = c.curPositions[i]
}
lowestPos[seq] = pos
lowestPos[seq] = lowest
}
// for any sequences are not part of this batch, clean up any tokens
// that are no longer needed after the processing of the previous
// batch
for seq, seqRange := range c.cellRanges {
if _, ok := lowestPos[seq]; !ok {
var last int32
for i := seqRange.min; i <= seqRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
last = max(last, c.cells[i].pos)
}
}
lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
}
}
// delete any entries that are beyond the window of the oldest position in the sequence
for seq, pos := range lowestPos {
for seq, lowest := range lowestPos {
oldRange, ok := c.cellRanges[seq]
if !ok {
continue
@@ -314,13 +342,13 @@ func (c *Causal) updateSlidingWindow() {
for i := oldRange.min; i <= oldRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
if c.cells[i].pos < pos-c.swaMemorySize {
if c.cells[i].pos < lowest.pos-c.swaMemorySize {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
} else {
newRange.min = min(newRange.min, i)
newRange.max = max(newRange.max, i)
}
if c.cells[i].pos >= pos-c.swaWindowSize {
if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
c.curCellRange.min = min(c.curCellRange.min, i)
c.curCellRange.max = max(c.curCellRange.max, i)
}
@@ -657,9 +685,11 @@ func (c *Causal) CanResume(seq int, pos int32) bool {
// for sliding window, check that the window of the new sequence is contained in
// the window of what we are storing
var first int32 = math.MaxInt32
var last int32 = -1
for i := seqRange.min; i <= seqRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
first = min(first, c.cells[i].pos)
last = max(last, c.cells[i].pos)
}
}
@@ -668,10 +698,8 @@ func (c *Causal) CanResume(seq int, pos int32) bool {
return false
}
lastWindowStart := max(0, last-c.swaMemorySize)
posWindowStart := max(0, pos-c.swaWindowSize)
return posWindowStart >= lastWindowStart
return posWindowStart >= first && pos <= last+1
}
func (c *Causal) shift(seq int, beginIndex, offset int32) error {

View File

@@ -96,6 +96,86 @@ func TestSWA(t *testing.T) {
testCache(t, backend, cache, tests)
}
func TestSWASeparateBatches(t *testing.T) {
backend := &testBackend{}
cache := NewSWACache(1, nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 2, 16, 2)
x := float32(math.Inf(-1))
tests := []testCase{
{
name: "First seq 0",
in: []float32{1, 2},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{0, 1},
expected: []float32{1, 2},
expectedShape: []int{1, 1, 2},
expectedMask: []float32{
0, x,
0, 0,
},
},
{
name: "Second seq 0",
in: []float32{3, 4},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{2, 3},
expected: []float32{2, 3, 4},
expectedShape: []int{1, 1, 3},
expectedMask: []float32{
0, 0, x,
x, 0, 0,
},
},
{
name: "First seq 1",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{1, 1},
pos: []int32{0, 1},
expected: []float32{5, 6},
expectedShape: []int{1, 1, 2},
expectedMask: []float32{
0, x,
0, 0,
},
},
{
name: "Second seq 1",
in: []float32{7, 8},
inShape: []int{1, 1, 2},
seqs: []int{1, 1},
pos: []int32{2, 3},
expected: []float32{6, 3, 4, 7, 8},
expectedShape: []int{1, 1, 5},
expectedMask: []float32{
0, x, x, 0, x,
x, x, x, 0, 0,
},
},
{
name: "Third seq 0",
in: []float32{9, 10},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{4, 5},
expected: []float32{9, 10, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{
0, x, x, 0,
0, 0, x, x,
},
},
}
testCache(t, backend, cache, tests)
}
func TestSWAMem(t *testing.T) {
backend := &testBackend{}
cache := NewSWAMemCache(1, 3, nil)
@@ -431,15 +511,15 @@ func TestCanResume(t *testing.T) {
defer context.Close()
err := cache.StartForward(context, input.Batch{
Positions: []int32{0, 1, 2, 3},
Sequences: []int{0, 0, 0, 0},
Positions: []int32{0, 1, 2, 3, 4},
Sequences: []int{0, 0, 0, 0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
cache.Put(context, tensor, tensor)
// with window size 4, nothing has slid out of the window yet
@@ -455,18 +535,21 @@ func TestCanResume(t *testing.T) {
if !cache.CanResume(0, 3) {
t.Errorf("CanResume(0, 3) = false, want true (latest position)")
}
if !cache.CanResume(0, 4) {
t.Errorf("CanResume(0, 4) = false, want true (latest position)")
}
// shift window by adding position 4
// shift window by adding position 5
err = cache.StartForward(context, input.Batch{
Positions: []int32{4, 5},
Sequences: []int{0, 0},
Positions: []int32{5},
Sequences: []int{0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
tensor = context.FromFloatSlice([]float32{6}, 1, 1, 1)
cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows
@@ -503,28 +586,28 @@ func TestCanResumeSWAMem(t *testing.T) {
defer context.Close()
err := cache.StartForward(context, input.Batch{
Positions: []int32{0, 1, 2, 3, 4, 5},
Sequences: []int{0, 0, 0, 0, 0, 0},
Positions: []int32{0, 1, 2, 3, 4, 5, 6},
Sequences: []int{0, 0, 0, 0, 0, 0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6}, 1, 1, 6)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
cache.Put(context, tensor, tensor)
// shift window by adding position 6
// shift window by adding position 7
err = cache.StartForward(context, input.Batch{
Positions: []int32{6, 7},
Sequences: []int{0, 0},
Positions: []int32{7},
Sequences: []int{0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor = context.FromFloatSlice([]float32{7, 8}, 1, 1, 2)
tensor = context.FromFloatSlice([]float32{8}, 1, 1, 1)
cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows