kvcache: Clean up sliding window state with independent batches

Sliding windows models (e.g. gpt-oss, gemma3) remove tokens that
are out of the cache's window each time we start a new forward pass.

The cache storage needs to handle the window size for each sequence
plus the batch size, since the batch needs to attend to the full
window size. This means that we have greater than a window size
stored while processing the batch.

When the next batch comes, we are currently only looking at the
sequences in the incoming batch to slide the window forward.
However, we also need to clean up the other sequences that might
be occupying space in the batch processing buffer to ensure each
sequence is only using its window size of storage. Failure to do
this can result in "no kv cache slot found" errors.

Fixes: #10127
This commit is contained in:
Jesse Gross
2025-10-06 16:04:53 -07:00
committed by Jesse Gross
parent aa45f7ce27
commit 1fc35f1260
2 changed files with 139 additions and 28 deletions

View File

@@ -96,6 +96,86 @@ func TestSWA(t *testing.T) {
testCache(t, backend, cache, tests)
}
func TestSWASeparateBatches(t *testing.T) {
backend := &testBackend{}
cache := NewSWACache(1, nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 2, 16, 2)
x := float32(math.Inf(-1))
tests := []testCase{
{
name: "First seq 0",
in: []float32{1, 2},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{0, 1},
expected: []float32{1, 2},
expectedShape: []int{1, 1, 2},
expectedMask: []float32{
0, x,
0, 0,
},
},
{
name: "Second seq 0",
in: []float32{3, 4},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{2, 3},
expected: []float32{2, 3, 4},
expectedShape: []int{1, 1, 3},
expectedMask: []float32{
0, 0, x,
x, 0, 0,
},
},
{
name: "First seq 1",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{1, 1},
pos: []int32{0, 1},
expected: []float32{5, 6},
expectedShape: []int{1, 1, 2},
expectedMask: []float32{
0, x,
0, 0,
},
},
{
name: "Second seq 1",
in: []float32{7, 8},
inShape: []int{1, 1, 2},
seqs: []int{1, 1},
pos: []int32{2, 3},
expected: []float32{6, 3, 4, 7, 8},
expectedShape: []int{1, 1, 5},
expectedMask: []float32{
0, x, x, 0, x,
x, x, x, 0, 0,
},
},
{
name: "Third seq 0",
in: []float32{9, 10},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{4, 5},
expected: []float32{9, 10, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{
0, x, x, 0,
0, 0, x, x,
},
},
}
testCache(t, backend, cache, tests)
}
func TestSWAMem(t *testing.T) {
backend := &testBackend{}
cache := NewSWAMemCache(1, 3, nil)
@@ -431,15 +511,15 @@ func TestCanResume(t *testing.T) {
defer context.Close()
err := cache.StartForward(context, input.Batch{
Positions: []int32{0, 1, 2, 3},
Sequences: []int{0, 0, 0, 0},
Positions: []int32{0, 1, 2, 3, 4},
Sequences: []int{0, 0, 0, 0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
cache.Put(context, tensor, tensor)
// with window size 4, nothing has slid out of the window yet
@@ -455,18 +535,21 @@ func TestCanResume(t *testing.T) {
if !cache.CanResume(0, 3) {
t.Errorf("CanResume(0, 3) = false, want true (latest position)")
}
if !cache.CanResume(0, 4) {
t.Errorf("CanResume(0, 4) = false, want true (latest position)")
}
// shift window by adding position 4
// shift window by adding position 5
err = cache.StartForward(context, input.Batch{
Positions: []int32{4, 5},
Sequences: []int{0, 0},
Positions: []int32{5},
Sequences: []int{0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
tensor = context.FromFloatSlice([]float32{6}, 1, 1, 1)
cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows
@@ -503,28 +586,28 @@ func TestCanResumeSWAMem(t *testing.T) {
defer context.Close()
err := cache.StartForward(context, input.Batch{
Positions: []int32{0, 1, 2, 3, 4, 5},
Sequences: []int{0, 0, 0, 0, 0, 0},
Positions: []int32{0, 1, 2, 3, 4, 5, 6},
Sequences: []int{0, 0, 0, 0, 0, 0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6}, 1, 1, 6)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
cache.Put(context, tensor, tensor)
// shift window by adding position 6
// shift window by adding position 7
err = cache.StartForward(context, input.Batch{
Positions: []int32{6, 7},
Sequences: []int{0, 0},
Positions: []int32{7},
Sequences: []int{0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor = context.FromFloatSlice([]float32{7, 8}, 1, 1, 2)
tensor = context.FromFloatSlice([]float32{8}, 1, 1, 1)
cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows