From b42970063d8f05c47dd6d9a6b71f1e14cc4805c9 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Sun, 30 Mar 2025 16:05:40 -0700 Subject: [PATCH] kvcache: Add check for values that fall out of sliding window cache The sliding window cache trims entries that are outside the window for the latest token. This works when we are extending the cache, such as when the conversation continues. However, if we have a partial overlap in conversation (including the BOS tokens), then we resume from a past point in the conversation and the needed tokens are no longer stored in memory. This verifies that the new window overlaps with the old one before reusing the cache. Co-authored-by: Jesse Gross --- kvcache/cache.go | 5 +++ kvcache/causal.go | 38 ++++++++++++++++- kvcache/causal_test.go | 71 +++++++++++++++++++++++++++++++ kvcache/encoder.go | 4 ++ kvcache/wrapper.go | 10 +++++ runner/ollamarunner/cache.go | 4 ++ runner/ollamarunner/cache_test.go | 1 + 7 files changed, 131 insertions(+), 2 deletions(-) diff --git a/kvcache/cache.go b/kvcache/cache.go index 18aec8003..07015b9e0 100644 --- a/kvcache/cache.go +++ b/kvcache/cache.go @@ -62,6 +62,11 @@ type Cache interface { // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq CopyPrefix(srcSeq, dstSeq int, len int32) + // CanResume returns true if the cache can continue with the next token at + // the given position and sequence. Assumes that the caller has already + // verified the contents of the cache. + CanResume(seq int, pos int32) bool + // Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set // endIndex to math.MaxInt32 to remove everything starting at beginIndex. // diff --git a/kvcache/causal.go b/kvcache/causal.go index fb4f0f743..4fc18d88f 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -581,6 +581,35 @@ func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) { c.cellRanges[dstSeq] = seqRange } +func (c *Causal) CanResume(seq int, pos int32) bool { + if c.windowSize == math.MaxInt32 { + return true + } + + seqRange, ok := c.cellRanges[seq] + if !ok { + return false + } + + // for sliding window, check that the window of the new sequence is contained in + // the window of what we are storing + var last int32 = -1 + for i := seqRange.min; i <= seqRange.max; i++ { + if slices.Contains(c.cells[i].sequences, seq) { + last = max(last, c.cells[i].pos) + } + } + + if last == -1 { + return false + } + + lastWindowStart := max(0, last-c.windowSize) + posWindowStart := max(0, pos-c.windowSize) + + return posWindowStart >= lastWindowStart +} + func (c *Causal) shift(seq int, beginIndex, offset int32) error { if c.shiftFn == nil { return ErrNotSupported @@ -635,6 +664,12 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { } func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error { + // TODO(jessegross): We should check to see if removing the middle of the sequence will + // cause the sliding window to encompass tokens that we no longer have. If so, then we + // should return an error, which will trigger the runner to evaluate the full history and + // rebuild the window. However, if we have multimodal inputs in our history, this reuse + // results in use after free, so we don't do it for now. + var offset int32 if endIndex != math.MaxInt32 { offset = beginIndex - endIndex @@ -649,8 +684,7 @@ func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error { } else { if c.cells[i].pos >= endIndex { if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) { - // TODO(jessegross): Need to be careful about data shared between sequences - return errors.New("shifting on cells shared by multiple sequences not yet implemented") + return errors.New("shifting cells shared by multiple sequences not supported") } c.cells[i].pos += offset diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index b1dc7d779..bf98abef6 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -300,6 +300,77 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) } } +func TestCanResume(t *testing.T) { + backend := &testBackend{} + windowSize := int32(4) + cache := NewSWACache(windowSize, nil) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 1, 16, 16) + + context := backend.NewContext() + defer context.Close() + + err := cache.StartForward(context, input.Batch{ + Positions: []int32{0, 1, 2, 3}, + Sequences: []int{0, 0, 0, 0}, + }) + if err != nil { + t.Fatalf("StartForward failed: %v", err) + } + + cache.SetLayer(0) + tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4) + cache.Put(context, tensor, tensor) + + // with window size 4, nothing has slid out of the window yet + if !cache.CanResume(0, 0) { + t.Errorf("CanResume(0, 0) = false, want true (within window)") + } + if !cache.CanResume(0, 1) { + t.Errorf("CanResume(0, 1) = false, want true (within window)") + } + if !cache.CanResume(0, 2) { + t.Errorf("CanResume(0, 2) = false, want true (within window)") + } + if !cache.CanResume(0, 3) { + t.Errorf("CanResume(0, 3) = false, want true (latest position)") + } + + // shift window by adding position 4 + err = cache.StartForward(context, input.Batch{ + Positions: []int32{4, 5}, + Sequences: []int{0, 0}, + }) + if err != nil { + t.Fatalf("StartForward failed: %v", err) + } + + cache.SetLayer(0) + tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2) + cache.Put(context, tensor, tensor) + + // only the latest position has overlapping windows + if cache.CanResume(0, 0) { + t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)") + } + if cache.CanResume(0, 1) { + t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)") + } + if cache.CanResume(0, 2) { + t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)") + } + if cache.CanResume(0, 3) { + t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)") + } + if cache.CanResume(0, 4) { + t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)") + } + if !cache.CanResume(0, 5) { + t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)") + } +} + type testBackend struct{} func (b *testBackend) Config() ml.Config { diff --git a/kvcache/encoder.go b/kvcache/encoder.go index 07ff4291e..03d650a3f 100644 --- a/kvcache/encoder.go +++ b/kvcache/encoder.go @@ -134,6 +134,10 @@ func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) { panic("encoder cache does not support multiple sequences") } +func (c *EncoderCache) CanResume(seq int, pos int32) bool { + return true +} + func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error { if c.encoderPos >= beginIndex && c.encoderPos < endIndex { c.encoderCached = false diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go index 0e8ff1f32..926bc2d41 100644 --- a/kvcache/wrapper.go +++ b/kvcache/wrapper.go @@ -87,6 +87,16 @@ func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) { } } +func (c *WrapperCache) CanResume(seq int, pos int32) bool { + for _, cache := range c.caches { + if !cache.CanResume(seq, pos) { + return false + } + } + + return true +} + func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error { // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail for _, cache := range c.caches { diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index 30292f641..01f435e4b 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -118,6 +118,10 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp } if c.cache != nil { + if numPast > 0 && !c.cache.CanResume(slot.Id, numPast) { + numPast = 0 + } + err = c.cache.Remove(slot.Id, numPast, math.MaxInt32) if err != nil { // Some models don't support partial erasure diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index 6a8d8a6a9..543b4b2fa 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -451,6 +451,7 @@ func (m *mockCache) Close() func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil } func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {} func (m *mockCache) SetConfig(ml.CacheConfig) {} +func (m *mockCache) CanResume(seq int, pos int32) bool { return true } func TestShiftCacheSlot(t *testing.T) { tests := []struct {