diff --git a/kvcache/causal.go b/kvcache/causal.go index 31f5523310..543a65a60c 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -160,7 +160,15 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity if c.swaMemorySize == 0 { c.swaMemorySize = c.swaWindowSize } - if int(c.swaMemorySize) > capacity { + // We will allocate space in the cache for the stop token, which won't be part of a follow on + // sequence, so allocate an extra token of storage to ensure that we can jump back without + // causing a cache break. As an optimization, only do this when we have parallel sequences + // because the extra token will live in the batch buffer and won't get overwritten if we + // only have a single sequence. + if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 { + c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1) + } + if int(c.swaMemorySize) >= capacity { c.swaMemorySize = math.MaxInt32 } @@ -214,7 +222,6 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e c.curLoc, err = c.findStartLoc() } if err != nil { - slog.Warn("unable to find a kv cache slot", "cache", c) return err } @@ -288,23 +295,44 @@ func (c *Causal) updateSlidingWindow() { return } + type lowestPosition struct { + pos int32 + curBatch bool + } + // create a map of unique sequences to the lowest position in that sequence - lowestPos := make(map[int]int32) + lowestPos := make(map[int]lowestPosition) for i := range c.curPositions { seq := c.curSequences[i] - pos, ok := lowestPos[seq] + lowest, ok := lowestPos[seq] if !ok { - pos = c.curPositions[i] - } else if c.curPositions[i] < pos { - pos = c.curPositions[i] + lowest = lowestPosition{pos: c.curPositions[i], curBatch: true} + } else if c.curPositions[i] < lowest.pos { + lowest.pos = c.curPositions[i] } - lowestPos[seq] = pos + lowestPos[seq] = lowest + } + + // for any sequences are not part of this batch, clean up any tokens + // that are no longer needed after the processing of the previous + // batch + for seq, seqRange := range c.cellRanges { + if _, ok := lowestPos[seq]; !ok { + var last int32 + for i := seqRange.min; i <= seqRange.max; i++ { + if slices.Contains(c.cells[i].sequences, seq) { + last = max(last, c.cells[i].pos) + } + } + + lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false} + } } // delete any entries that are beyond the window of the oldest position in the sequence - for seq, pos := range lowestPos { + for seq, lowest := range lowestPos { oldRange, ok := c.cellRanges[seq] if !ok { continue @@ -314,13 +342,13 @@ func (c *Causal) updateSlidingWindow() { for i := oldRange.min; i <= oldRange.max; i++ { if slices.Contains(c.cells[i].sequences, seq) { - if c.cells[i].pos < pos-c.swaMemorySize { + if c.cells[i].pos < lowest.pos-c.swaMemorySize { c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq }) } else { newRange.min = min(newRange.min, i) newRange.max = max(newRange.max, i) } - if c.cells[i].pos >= pos-c.swaWindowSize { + if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize { c.curCellRange.min = min(c.curCellRange.min, i) c.curCellRange.max = max(c.curCellRange.max, i) } @@ -657,9 +685,11 @@ func (c *Causal) CanResume(seq int, pos int32) bool { // for sliding window, check that the window of the new sequence is contained in // the window of what we are storing + var first int32 = math.MaxInt32 var last int32 = -1 for i := seqRange.min; i <= seqRange.max; i++ { if slices.Contains(c.cells[i].sequences, seq) { + first = min(first, c.cells[i].pos) last = max(last, c.cells[i].pos) } } @@ -668,10 +698,8 @@ func (c *Causal) CanResume(seq int, pos int32) bool { return false } - lastWindowStart := max(0, last-c.swaMemorySize) posWindowStart := max(0, pos-c.swaWindowSize) - - return posWindowStart >= lastWindowStart + return posWindowStart >= first && pos <= last+1 } func (c *Causal) shift(seq int, beginIndex, offset int32) error { diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 0d8cea79f7..7e4fc3b109 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -96,6 +96,86 @@ func TestSWA(t *testing.T) { testCache(t, backend, cache, tests) } +func TestSWASeparateBatches(t *testing.T) { + backend := &testBackend{} + cache := NewSWACache(1, nil) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 2, 16, 2) + + x := float32(math.Inf(-1)) + + tests := []testCase{ + { + name: "First seq 0", + in: []float32{1, 2}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 0}, + pos: []int32{0, 1}, + expected: []float32{1, 2}, + expectedShape: []int{1, 1, 2}, + expectedMask: []float32{ + 0, x, + 0, 0, + }, + }, + { + name: "Second seq 0", + in: []float32{3, 4}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 0}, + pos: []int32{2, 3}, + expected: []float32{2, 3, 4}, + expectedShape: []int{1, 1, 3}, + expectedMask: []float32{ + 0, 0, x, + x, 0, 0, + }, + }, + { + name: "First seq 1", + in: []float32{5, 6}, + inShape: []int{1, 1, 2}, + seqs: []int{1, 1}, + pos: []int32{0, 1}, + expected: []float32{5, 6}, + expectedShape: []int{1, 1, 2}, + expectedMask: []float32{ + 0, x, + 0, 0, + }, + }, + { + name: "Second seq 1", + in: []float32{7, 8}, + inShape: []int{1, 1, 2}, + seqs: []int{1, 1}, + pos: []int32{2, 3}, + expected: []float32{6, 3, 4, 7, 8}, + expectedShape: []int{1, 1, 5}, + expectedMask: []float32{ + 0, x, x, 0, x, + x, x, x, 0, 0, + }, + }, + { + name: "Third seq 0", + in: []float32{9, 10}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 0}, + pos: []int32{4, 5}, + expected: []float32{9, 10, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{ + 0, x, x, 0, + 0, 0, x, x, + }, + }, + } + + testCache(t, backend, cache, tests) +} + func TestSWAMem(t *testing.T) { backend := &testBackend{} cache := NewSWAMemCache(1, 3, nil) @@ -431,15 +511,15 @@ func TestCanResume(t *testing.T) { defer context.Close() err := cache.StartForward(context, input.Batch{ - Positions: []int32{0, 1, 2, 3}, - Sequences: []int{0, 0, 0, 0}, + Positions: []int32{0, 1, 2, 3, 4}, + Sequences: []int{0, 0, 0, 0, 0}, }, false) if err != nil { t.Fatalf("StartForward failed: %v", err) } cache.SetLayer(0) - tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4) + tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5}, 1, 1, 5) cache.Put(context, tensor, tensor) // with window size 4, nothing has slid out of the window yet @@ -455,18 +535,21 @@ func TestCanResume(t *testing.T) { if !cache.CanResume(0, 3) { t.Errorf("CanResume(0, 3) = false, want true (latest position)") } + if !cache.CanResume(0, 4) { + t.Errorf("CanResume(0, 4) = false, want true (latest position)") + } - // shift window by adding position 4 + // shift window by adding position 5 err = cache.StartForward(context, input.Batch{ - Positions: []int32{4, 5}, - Sequences: []int{0, 0}, + Positions: []int32{5}, + Sequences: []int{0}, }, false) if err != nil { t.Fatalf("StartForward failed: %v", err) } cache.SetLayer(0) - tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2) + tensor = context.FromFloatSlice([]float32{6}, 1, 1, 1) cache.Put(context, tensor, tensor) // only the latest position has overlapping windows @@ -503,28 +586,28 @@ func TestCanResumeSWAMem(t *testing.T) { defer context.Close() err := cache.StartForward(context, input.Batch{ - Positions: []int32{0, 1, 2, 3, 4, 5}, - Sequences: []int{0, 0, 0, 0, 0, 0}, + Positions: []int32{0, 1, 2, 3, 4, 5, 6}, + Sequences: []int{0, 0, 0, 0, 0, 0, 0}, }, false) if err != nil { t.Fatalf("StartForward failed: %v", err) } cache.SetLayer(0) - tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6}, 1, 1, 6) + tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7) cache.Put(context, tensor, tensor) - // shift window by adding position 6 + // shift window by adding position 7 err = cache.StartForward(context, input.Batch{ - Positions: []int32{6, 7}, - Sequences: []int{0, 0}, + Positions: []int32{7}, + Sequences: []int{0}, }, false) if err != nil { t.Fatalf("StartForward failed: %v", err) } cache.SetLayer(0) - tensor = context.FromFloatSlice([]float32{7, 8}, 1, 1, 2) + tensor = context.FromFloatSlice([]float32{8}, 1, 1, 1) cache.Put(context, tensor, tensor) // only the latest position has overlapping windows