mirror of
https://github.com/ollama/ollama.git
synced 2025-11-10 21:57:30 +01:00
kvcache: Clean up sliding window state with independent batches
Sliding windows models (e.g. gpt-oss, gemma3) remove tokens that are out of the cache's window each time we start a new forward pass. The cache storage needs to handle the window size for each sequence plus the batch size, since the batch needs to attend to the full window size. This means that we have greater than a window size stored while processing the batch. When the next batch comes, we are currently only looking at the sequences in the incoming batch to slide the window forward. However, we also need to clean up the other sequences that might be occupying space in the batch processing buffer to ensure each sequence is only using its window size of storage. Failure to do this can result in "no kv cache slot found" errors. Fixes: #10127
This commit is contained in:
@@ -160,7 +160,15 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
|
||||
if c.swaMemorySize == 0 {
|
||||
c.swaMemorySize = c.swaWindowSize
|
||||
}
|
||||
if int(c.swaMemorySize) > capacity {
|
||||
// We will allocate space in the cache for the stop token, which won't be part of a follow on
|
||||
// sequence, so allocate an extra token of storage to ensure that we can jump back without
|
||||
// causing a cache break. As an optimization, only do this when we have parallel sequences
|
||||
// because the extra token will live in the batch buffer and won't get overwritten if we
|
||||
// only have a single sequence.
|
||||
if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
|
||||
c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
|
||||
}
|
||||
if int(c.swaMemorySize) >= capacity {
|
||||
c.swaMemorySize = math.MaxInt32
|
||||
}
|
||||
|
||||
@@ -214,7 +222,6 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
}
|
||||
if err != nil {
|
||||
slog.Warn("unable to find a kv cache slot", "cache", c)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -288,23 +295,44 @@ func (c *Causal) updateSlidingWindow() {
|
||||
return
|
||||
}
|
||||
|
||||
type lowestPosition struct {
|
||||
pos int32
|
||||
curBatch bool
|
||||
}
|
||||
|
||||
// create a map of unique sequences to the lowest position in that sequence
|
||||
lowestPos := make(map[int]int32)
|
||||
lowestPos := make(map[int]lowestPosition)
|
||||
for i := range c.curPositions {
|
||||
seq := c.curSequences[i]
|
||||
|
||||
pos, ok := lowestPos[seq]
|
||||
lowest, ok := lowestPos[seq]
|
||||
if !ok {
|
||||
pos = c.curPositions[i]
|
||||
} else if c.curPositions[i] < pos {
|
||||
pos = c.curPositions[i]
|
||||
lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
|
||||
} else if c.curPositions[i] < lowest.pos {
|
||||
lowest.pos = c.curPositions[i]
|
||||
}
|
||||
|
||||
lowestPos[seq] = pos
|
||||
lowestPos[seq] = lowest
|
||||
}
|
||||
|
||||
// for any sequences are not part of this batch, clean up any tokens
|
||||
// that are no longer needed after the processing of the previous
|
||||
// batch
|
||||
for seq, seqRange := range c.cellRanges {
|
||||
if _, ok := lowestPos[seq]; !ok {
|
||||
var last int32
|
||||
for i := seqRange.min; i <= seqRange.max; i++ {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
last = max(last, c.cells[i].pos)
|
||||
}
|
||||
}
|
||||
|
||||
lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
|
||||
}
|
||||
}
|
||||
|
||||
// delete any entries that are beyond the window of the oldest position in the sequence
|
||||
for seq, pos := range lowestPos {
|
||||
for seq, lowest := range lowestPos {
|
||||
oldRange, ok := c.cellRanges[seq]
|
||||
if !ok {
|
||||
continue
|
||||
@@ -314,13 +342,13 @@ func (c *Causal) updateSlidingWindow() {
|
||||
|
||||
for i := oldRange.min; i <= oldRange.max; i++ {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
if c.cells[i].pos < pos-c.swaMemorySize {
|
||||
if c.cells[i].pos < lowest.pos-c.swaMemorySize {
|
||||
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||
} else {
|
||||
newRange.min = min(newRange.min, i)
|
||||
newRange.max = max(newRange.max, i)
|
||||
}
|
||||
if c.cells[i].pos >= pos-c.swaWindowSize {
|
||||
if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
|
||||
c.curCellRange.min = min(c.curCellRange.min, i)
|
||||
c.curCellRange.max = max(c.curCellRange.max, i)
|
||||
}
|
||||
@@ -657,9 +685,11 @@ func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
|
||||
// for sliding window, check that the window of the new sequence is contained in
|
||||
// the window of what we are storing
|
||||
var first int32 = math.MaxInt32
|
||||
var last int32 = -1
|
||||
for i := seqRange.min; i <= seqRange.max; i++ {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
first = min(first, c.cells[i].pos)
|
||||
last = max(last, c.cells[i].pos)
|
||||
}
|
||||
}
|
||||
@@ -668,10 +698,8 @@ func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
lastWindowStart := max(0, last-c.swaMemorySize)
|
||||
posWindowStart := max(0, pos-c.swaWindowSize)
|
||||
|
||||
return posWindowStart >= lastWindowStart
|
||||
return posWindowStart >= first && pos <= last+1
|
||||
}
|
||||
|
||||
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
|
||||
@@ -96,6 +96,86 @@ func TestSWA(t *testing.T) {
|
||||
testCache(t, backend, cache, tests)
|
||||
}
|
||||
|
||||
func TestSWASeparateBatches(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
cache := NewSWACache(1, nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 2, 16, 2)
|
||||
|
||||
x := float32(math.Inf(-1))
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "First seq 0",
|
||||
in: []float32{1, 2},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 0},
|
||||
pos: []int32{0, 1},
|
||||
expected: []float32{1, 2},
|
||||
expectedShape: []int{1, 1, 2},
|
||||
expectedMask: []float32{
|
||||
0, x,
|
||||
0, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Second seq 0",
|
||||
in: []float32{3, 4},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 0},
|
||||
pos: []int32{2, 3},
|
||||
expected: []float32{2, 3, 4},
|
||||
expectedShape: []int{1, 1, 3},
|
||||
expectedMask: []float32{
|
||||
0, 0, x,
|
||||
x, 0, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "First seq 1",
|
||||
in: []float32{5, 6},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{1, 1},
|
||||
pos: []int32{0, 1},
|
||||
expected: []float32{5, 6},
|
||||
expectedShape: []int{1, 1, 2},
|
||||
expectedMask: []float32{
|
||||
0, x,
|
||||
0, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Second seq 1",
|
||||
in: []float32{7, 8},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{1, 1},
|
||||
pos: []int32{2, 3},
|
||||
expected: []float32{6, 3, 4, 7, 8},
|
||||
expectedShape: []int{1, 1, 5},
|
||||
expectedMask: []float32{
|
||||
0, x, x, 0, x,
|
||||
x, x, x, 0, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Third seq 0",
|
||||
in: []float32{9, 10},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 0},
|
||||
pos: []int32{4, 5},
|
||||
expected: []float32{9, 10, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{
|
||||
0, x, x, 0,
|
||||
0, 0, x, x,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
}
|
||||
|
||||
func TestSWAMem(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
cache := NewSWAMemCache(1, 3, nil)
|
||||
@@ -431,15 +511,15 @@ func TestCanResume(t *testing.T) {
|
||||
defer context.Close()
|
||||
|
||||
err := cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{0, 1, 2, 3},
|
||||
Sequences: []int{0, 0, 0, 0},
|
||||
Positions: []int32{0, 1, 2, 3, 4},
|
||||
Sequences: []int{0, 0, 0, 0, 0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
|
||||
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// with window size 4, nothing has slid out of the window yet
|
||||
@@ -455,18 +535,21 @@ func TestCanResume(t *testing.T) {
|
||||
if !cache.CanResume(0, 3) {
|
||||
t.Errorf("CanResume(0, 3) = false, want true (latest position)")
|
||||
}
|
||||
if !cache.CanResume(0, 4) {
|
||||
t.Errorf("CanResume(0, 4) = false, want true (latest position)")
|
||||
}
|
||||
|
||||
// shift window by adding position 4
|
||||
// shift window by adding position 5
|
||||
err = cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{4, 5},
|
||||
Sequences: []int{0, 0},
|
||||
Positions: []int32{5},
|
||||
Sequences: []int{0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
|
||||
tensor = context.FromFloatSlice([]float32{6}, 1, 1, 1)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// only the latest position has overlapping windows
|
||||
@@ -503,28 +586,28 @@ func TestCanResumeSWAMem(t *testing.T) {
|
||||
defer context.Close()
|
||||
|
||||
err := cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{0, 1, 2, 3, 4, 5},
|
||||
Sequences: []int{0, 0, 0, 0, 0, 0},
|
||||
Positions: []int32{0, 1, 2, 3, 4, 5, 6},
|
||||
Sequences: []int{0, 0, 0, 0, 0, 0, 0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6}, 1, 1, 6)
|
||||
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// shift window by adding position 6
|
||||
// shift window by adding position 7
|
||||
err = cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{6, 7},
|
||||
Sequences: []int{0, 0},
|
||||
Positions: []int32{7},
|
||||
Sequences: []int{0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor = context.FromFloatSlice([]float32{7, 8}, 1, 1, 2)
|
||||
tensor = context.FromFloatSlice([]float32{8}, 1, 1, 1)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// only the latest position has overlapping windows
|
||||
|
||||
Reference in New Issue
Block a user