From 764be7480f19f1749c518b21cead7c3a44c04b1d Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Fri, 25 Jul 2025 14:50:05 -0700
Subject: [PATCH] kvcache: Group shift operations into batches

Currently, when we need to do a shift on the cache, it is one RoPE
operation on the entire size of the cache (per layer). In some cases,
this can create a compute graph that is larger than the forward pass,
since the forward pass works in batches. Since we don't consider
shifting in our memory estimates, it's possible for this to cause a
crash if we run out of memory.

By limiting the size of the RoPE calls to batch-size chunks, we ensure
that the shift will never exceed the size of the forward pass, since
the forward pass will also contain a RoPE of the same size. This does
not have a significant impact on performance, since RoPE is a math
operation whose cost is mostly proportional to the size of its inputs.

In theory, defrag could have the same issue, since it also creates a
compute graph outside of the forward pass; however, since it performs
only copies, it does not require any working space.
---
 kvcache/causal.go | 79 ++++++++++++++++++++++++++---------------------
 1 file changed, 43 insertions(+), 36 deletions(-)

diff --git a/kvcache/causal.go b/kvcache/causal.go
index b594d0b418..8b101a817b 100644
--- a/kvcache/causal.go
+++ b/kvcache/causal.go
@@ -25,6 +25,9 @@ type Causal struct {
 
 	opts CausalOptions
 
+	// maxBatch is the largest batch that we might receive
+	maxBatch int
+
 	// config controls mostly backend-specific optimizations
 	config *ml.CacheConfig
 
@@ -147,6 +150,7 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
 	c.DType = dtype
 	c.cellRanges = make(map[int]cellRange)
 	c.backend = backend
+	c.maxBatch = maxBatch
 }
 
 func (c *Causal) SetConfig(config ml.CacheConfig) {
@@ -639,48 +643,51 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
 		return ErrNotSupported
 	}
 
-	ctx := c.backend.NewContext()
-	defer ctx.Close()
-
 	seqRange := c.cellRanges[seq]
 
-	size := seqRange.max - seqRange.min + 1
-	offsets := make([]int32, size)
-	for i := range offsets {
-		cell := c.cells[seqRange.min+i]
+	for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
+		ctx := c.backend.NewContext()
 
-		if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
-			offsets[i] = offset
+		size := min(seqRange.max-start+1, c.maxBatch)
+		offsets := make([]int32, size)
+		for i := range offsets {
+			cell := c.cells[start+i]
+
+			if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
+				offsets[i] = offset
+			}
 		}
+
+		kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
+
+		for i, key := range c.keys {
+			if key == nil {
+				continue
+			}
+
+			kHeadDim := key.Dim(0)
+			numKVHeads := key.Dim(1)
+			rowSize := key.Stride(2)
+
+			key = key.View(ctx, rowSize*start,
+				kHeadDim, key.Stride(1),
+				numKVHeads, key.Stride(2),
+				size,
+			)
+
+			roped, err := c.shiftFn(ctx, i, key, kShift)
+			if err != nil {
+				ctx.Close()
+				return err
+			}
+
+			ctx.Forward(roped.Copy(ctx, key))
+		}
+
+		ctx.Compute()
+		ctx.Close()
 	}
 
-	kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
-
-	for i, key := range c.keys {
-		if key == nil {
-			continue
-		}
-
-		kHeadDim := key.Dim(0)
-		numKVHeads := key.Dim(1)
-		rowSize := key.Stride(2)
-
-		key = key.View(ctx, rowSize*seqRange.min,
-			kHeadDim, key.Stride(1),
-			numKVHeads, key.Stride(2),
-			size,
-		)
-
-		roped, err := c.shiftFn(ctx, i, key, kShift)
-		if err != nil {
-			return err
-		}
-
-		ctx.Forward(roped.Copy(ctx, key))
-	}
-
-	ctx.Compute()
-
 	return nil
 }
 
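As a quick illustration of the chunking loop above, here is a minimal,
standalone sketch (not part of the patch; chunks, lo, and hi are
hypothetical names standing in for the inclusive cell range in seqRange
and for c.maxBatch) showing how the (start, size) pairs are derived so
that no single RoPE call exceeds the batch size:

	package main

	import "fmt"

	// chunks walks an inclusive range [lo, hi] in steps of at most
	// maxBatch, mirroring the loop bounds in shift(). Each entry is a
	// (start, size) pair; min is the Go 1.21 builtin.
	func chunks(lo, hi, maxBatch int) [][2]int {
		var out [][2]int
		for start := lo; start <= hi; start += maxBatch {
			size := min(hi-start+1, maxBatch)
			out = append(out, [2]int{start, size})
		}
		return out
	}

	func main() {
		// A 10-cell range with maxBatch=4 is processed as chunks of
		// 4, 4, and 2 cells.
		fmt.Println(chunks(0, 9, 4)) // [[0 4] [4 4] [8 2]]
	}

The final chunk is simply smaller (size = seqRange.max-start+1) rather
than padded, which matches how the patch sizes its last offsets slice.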