From 3ed7ad3ab32b458aa2fdb8d0144c546efdb26a72 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 18 Mar 2025 14:31:52 -0700 Subject: [PATCH] kvcache: Pass granular cache size into implementations Currently the runner computes the kv size needed and creates a cache of that size. This is the context size times number of parallel sequences. Cache implementations can make better decisions about their memory usage, so instead pass in the required capacity, number of sequences and maximum batch size. For now, the causal cache just uses this to compute the size in the same way as before. --- kvcache/cache.go | 9 +++++++-- kvcache/causal.go | 33 +++++++++++++++++---------------- kvcache/causal_test.go | 12 ++++++------ kvcache/encoder.go | 6 +++++- kvcache/wrapper.go | 4 ++-- runner/ollamarunner/cache.go | 10 ++++++---- runner/ollamarunner/runner.go | 2 +- 7 files changed, 44 insertions(+), 32 deletions(-) diff --git a/kvcache/cache.go b/kvcache/cache.go index aa0a20562..18aec8003 100644 --- a/kvcache/cache.go +++ b/kvcache/cache.go @@ -43,8 +43,13 @@ type Cache interface { // ** cache management ** - // Init sets up runtime parameters - Init(backend ml.Backend, dtype ml.DType, capacity int32) + // Init sets up runtime parameters. + // backend: Used to allocate cache data storage and execute management operations (such as defrag) + // dtype: The data type for storing cache entries + // maxSequences: The maximum number of sequences stored in the cache - across all batches + // capacity: The number of cache entries to store, per sequence + // maxBatch: The maximum number of tokens that can occur in a single batch + Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) // Close closes the cache and frees resources associated with it Close() diff --git a/kvcache/causal.go b/kvcache/causal.go index e5216d588..ced409c33 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -20,7 +20,6 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e // The mask is of shape history size, batch size type Causal struct { DType ml.DType - Capacity int32 windowSize int32 opts CausalOptions @@ -98,7 +97,7 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal { } } -func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) { +func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { if c.config == nil { var config ml.CacheConfig if cc, ok := backend.(ml.BackendCacheConfig); ok { @@ -119,9 +118,11 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) { c.config.MaskDType = ml.DTypeF32 } + cacheSize := maxSequences * capacity + cacheSize = roundUp(cacheSize, c.config.CachePadding) + c.cells = make([]cacheCell, cacheSize) + c.DType = dtype - c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding)) - c.cells = make([]cacheCell, c.Capacity) c.cellRanges = make(map[int]cellRange) c.backend = backend } @@ -210,7 +211,7 @@ func (c *Causal) findStartLoc() (int, error) { } } - return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity) + return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells)) } func roundDown(length, pad int) int { @@ -265,7 +266,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { return maskTensor, nil } -func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) { +func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) { for i, key := range c.keys { if key == nil { continue @@ -275,8 +276,8 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) { numKVHeads := key.Dim(1) rowSize := key.Stride(2) - kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len) - kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len) + kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length) + kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length) value := c.values[i] var vSrcView, vDstView ml.Tensor @@ -284,14 +285,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) { vHeadDim := value.Dim(1) elemSize := value.Stride(0) - vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads) - vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads) + vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads) + vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads) } else { vHeadDim := value.Dim(0) rowSize := value.Stride(2) - vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len) - vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len) + vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length) + vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length) } ctx.Forward( @@ -480,14 +481,14 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { } if _, ok := c.keys[c.curLayer]; !ok { - c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity)) + c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells)) } if _, ok := c.values[c.curLayer]; !ok { if c.config.PermutedV { - c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads) + c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads) } else { - c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity)) + c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells)) } } @@ -498,7 +499,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { elemSize := c.values[c.curLayer].Stride(0) value = value.Permute(ctx, 1, 2, 0, 3) - ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads))) + ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads))) } else { rowSize := c.values[c.curLayer].Stride(2) diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 0f2385db7..66a2e835a 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -25,7 +25,7 @@ func TestStore(t *testing.T) { cache := NewCausalCache(nil) defer cache.Close() - cache.Init(backend, ml.DTypeF16, 16) + cache.Init(backend, ml.DTypeF16, 1, 16, 16) tests := []testCase{ { @@ -58,7 +58,7 @@ func TestSWA(t *testing.T) { cache := NewSWACache(1, nil) defer cache.Close() - cache.Init(backend, ml.DTypeF32, 16) + cache.Init(backend, ml.DTypeF32, 1, 16, 16) tests := []testCase{ { @@ -81,7 +81,7 @@ func TestSequences(t *testing.T) { cache := NewCausalCache(nil) defer cache.Close() - cache.Init(backend, ml.DTypeF16, 16) + cache.Init(backend, ml.DTypeF16, 1, 16, 16) tests := []testCase{ { @@ -116,7 +116,7 @@ func TestRemove(t *testing.T) { }) defer cache.Close() - cache.Init(backend, ml.DTypeF16, 16) + cache.Init(backend, ml.DTypeF16, 1, 16, 16) tests := []testCase{ { @@ -181,7 +181,7 @@ func TestDefrag(t *testing.T) { }) defer cache.Close() - cache.Init(backend, ml.DTypeF16, 16) + cache.Init(backend, ml.DTypeF16, 1, 16, 16) tests := []testCase{ { @@ -229,7 +229,7 @@ func TestCopy(t *testing.T) { cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil }) defer cache.Close() - cache.Init(backend, ml.DTypeF16, 16) + cache.Init(backend, ml.DTypeF16, 1, 16, 16) tests := []testCase{ { diff --git a/kvcache/encoder.go b/kvcache/encoder.go index 94c5d99c3..07ff4291e 100644 --- a/kvcache/encoder.go +++ b/kvcache/encoder.go @@ -49,7 +49,7 @@ func NewEncoderCache() *EncoderCache { } } -func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) { +func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { if c.config == nil { var config ml.CacheConfig if cc, ok := backend.(ml.BackendCacheConfig); ok { @@ -58,6 +58,10 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) c.config = &config } + if maxSequences > 1 { + panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences)) + } + if c.config.CachePadding != 0 && c.config.CachePadding != 1 { panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding)) } diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go index c85807a04..0e8ff1f32 100644 --- a/kvcache/wrapper.go +++ b/kvcache/wrapper.go @@ -23,9 +23,9 @@ func NewWrapperCache(caches ...Cache) *WrapperCache { } } -func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) { +func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { for _, cache := range c.caches { - cache.Init(backend, dtype, capacity) + cache.Init(backend, dtype, maxSequences, capacity, maxBatch) } } diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index cf5e6b911..aa56c9822 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -31,8 +31,10 @@ type InputCache struct { cache kvcache.Cache } -func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) { - if kvSize/int32(numSlots) < 1 { +func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) { + numCtx := kvSize / int32(numSlots) + + if numCtx < 1 { return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots) } @@ -44,11 +46,11 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots cache := model.Config().Cache if cache != nil { - cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize) + cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize) } return &InputCache{ - numCtx: kvSize / int32(numSlots), + numCtx: numCtx, enabled: cache != nil, slots: slots, multiUserCache: multiUserCache, diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 90eb0de62..67d9a1b02 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -699,7 +699,7 @@ func (s *Server) loadModel( panic("loras are not yet implemented") } - s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache) + s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache) if err != nil { panic(err) }