diff --git a/kvcache/cache.go b/kvcache/cache.go
index aa0a20562..18aec8003 100644
--- a/kvcache/cache.go
+++ b/kvcache/cache.go
@@ -43,8 +43,13 @@ type Cache interface {
 
 	// ** cache management **
 
-	// Init sets up runtime parameters
-	Init(backend ml.Backend, dtype ml.DType, capacity int32)
+	// Init sets up runtime parameters.
+	// backend: Used to allocate cache data storage and execute management operations (such as defrag)
+	// dtype: The data type for storing cache entries
+	// maxSequences: The maximum number of sequences stored in the cache - across all batches
+	// capacity: The number of cache entries to store, per sequence
+	// maxBatch: The maximum number of tokens that can occur in a single batch
+	Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
 
 	// Close closes the cache and frees resources associated with it
 	Close()
diff --git a/kvcache/causal.go b/kvcache/causal.go
index e5216d588..ced409c33 100644
--- a/kvcache/causal.go
+++ b/kvcache/causal.go
@@ -20,7 +20,6 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
 // The mask is of shape history size, batch size
 type Causal struct {
 	DType      ml.DType
-	Capacity   int32
 	windowSize int32
 
 	opts CausalOptions
@@ -98,7 +97,7 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
 	}
 }
 
-func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
 	if c.config == nil {
 		var config ml.CacheConfig
 		if cc, ok := backend.(ml.BackendCacheConfig); ok {
@@ -119,9 +118,11 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
 		c.config.MaskDType = ml.DTypeF32
 	}
 
+	cacheSize := maxSequences * capacity
+	cacheSize = roundUp(cacheSize, c.config.CachePadding)
+	c.cells = make([]cacheCell, cacheSize)
+
 	c.DType = dtype
-	c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
-	c.cells = make([]cacheCell, c.Capacity)
 	c.cellRanges = make(map[int]cellRange)
 	c.backend = backend
 }
@@ -210,7 +211,7 @@ func (c *Causal) findStartLoc() (int, error) {
 		}
 	}
 
-	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
+	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
 }
 
 func roundDown(length, pad int) int {
@@ -265,7 +266,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
 	return maskTensor, nil
 }
 
-func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
+func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
 	for i, key := range c.keys {
 		if key == nil {
 			continue
@@ -275,8 +276,8 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
 		numKVHeads := key.Dim(1)
 		rowSize := key.Stride(2)
 
-		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
-		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
+		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
+		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
 
 		value := c.values[i]
 		var vSrcView, vDstView ml.Tensor
@@ -284,14 +285,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
 			vHeadDim := value.Dim(1)
 			elemSize := value.Stride(0)
 
-			vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
-			vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
+			vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
+			vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
 		} else {
 			vHeadDim := value.Dim(0)
 			rowSize := value.Stride(2)
 
-			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
-			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
+			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
+			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
 		}
 
 		ctx.Forward(
@@ -480,14 +481,14 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
 	}
 
 	if _, ok := c.keys[c.curLayer]; !ok {
-		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
+		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
 	}
 
 	if _, ok := c.values[c.curLayer]; !ok {
 		if c.config.PermutedV {
-			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
+			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
 		} else {
-			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
+			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
 		}
 	}
 
@@ -498,7 +499,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
 		elemSize := c.values[c.curLayer].Stride(0)
 
 		value = value.Permute(ctx, 1, 2, 0, 3)
-		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
+		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
 	} else {
 		rowSize := c.values[c.curLayer].Stride(2)
 
diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go
index 0f2385db7..66a2e835a 100644
--- a/kvcache/causal_test.go
+++ b/kvcache/causal_test.go
@@ -25,7 +25,7 @@ func TestStore(t *testing.T) {
 	cache := NewCausalCache(nil)
 	defer cache.Close()
 
-	cache.Init(backend, ml.DTypeF16, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)
 
 	tests := []testCase{
 		{
@@ -58,7 +58,7 @@ func TestSWA(t *testing.T) {
 	cache := NewSWACache(1, nil)
 	defer cache.Close()
 
-	cache.Init(backend, ml.DTypeF32, 16)
+	cache.Init(backend, ml.DTypeF32, 1, 16, 16)
 
 	tests := []testCase{
 		{
@@ -81,7 +81,7 @@ func TestSequences(t *testing.T) {
 	cache := NewCausalCache(nil)
 	defer cache.Close()
 
-	cache.Init(backend, ml.DTypeF16, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)
 
 	tests := []testCase{
 		{
@@ -116,7 +116,7 @@ func TestRemove(t *testing.T) {
 	})
 	defer cache.Close()
 
-	cache.Init(backend, ml.DTypeF16, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)
 
 	tests := []testCase{
 		{
@@ -181,7 +181,7 @@ func TestDefrag(t *testing.T) {
 	})
 	defer cache.Close()
 
-	cache.Init(backend, ml.DTypeF16, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)
 
 	tests := []testCase{
 		{
@@ -229,7 +229,7 @@ func TestCopy(t *testing.T) {
 	cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
 	defer cache.Close()
 
-	cache.Init(backend, ml.DTypeF16, 16)
+	cache.Init(backend, ml.DTypeF16, 1, 16, 16)
 
 	tests := []testCase{
 		{
diff --git a/kvcache/encoder.go b/kvcache/encoder.go
index 94c5d99c3..07ff4291e 100644
--- a/kvcache/encoder.go
+++ b/kvcache/encoder.go
@@ -49,7 +49,7 @@ func NewEncoderCache() *EncoderCache {
 	}
 }
 
-func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
 	if c.config == nil {
 		var config ml.CacheConfig
 		if cc, ok := backend.(ml.BackendCacheConfig); ok {
@@ -58,6 +58,10 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
 		c.config = &config
 	}
 
+	if maxSequences > 1 {
+		panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
+	}
+
 	if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
 		panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
 	}
diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go
index c85807a04..0e8ff1f32 100644
--- a/kvcache/wrapper.go
+++ b/kvcache/wrapper.go
@@ -23,9 +23,9 @@ func NewWrapperCache(caches ...Cache) *WrapperCache {
 	}
 }
 
-func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
 	for _, cache := range c.caches {
-		cache.Init(backend, dtype, capacity)
+		cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
 	}
 }
 
diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go
index cf5e6b911..aa56c9822 100644
--- a/runner/ollamarunner/cache.go
+++ b/runner/ollamarunner/cache.go
@@ -31,8 +31,10 @@ type InputCache struct {
 	cache kvcache.Cache
 }
 
-func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) {
-	if kvSize/int32(numSlots) < 1 {
+func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
+	numCtx := kvSize / int32(numSlots)
+
+	if numCtx < 1 {
 		return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
 	}
 
@@ -44,11 +46,11 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
 
 	cache := model.Config().Cache
 	if cache != nil {
-		cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize)
+		cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize)
 	}
 
 	return &InputCache{
-		numCtx:         kvSize / int32(numSlots),
+		numCtx:         numCtx,
 		enabled:        cache != nil,
 		slots:          slots,
 		multiUserCache: multiUserCache,
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 90eb0de62..67d9a1b02 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -699,7 +699,7 @@ func (s *Server) loadModel(
 		panic("loras are not yet implemented")
 	}
 
-	s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache)
+	s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
 	if err != nil {
 		panic(err)
 	}
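
Usage note (not part of the patch): a minimal sketch of how a caller might map runner-style sizing onto the new five-argument Init. The import paths are assumed from the repository layout, and the helper name and all numeric values are hypothetical, chosen only to illustrate the relationship between total cache size, per-sequence capacity, and batch size.

	package main

	import (
		"github.com/ollama/ollama/kvcache"
		"github.com/ollama/ollama/ml"
	)

	// initCache is a hypothetical helper demonstrating the new Init signature.
	func initCache(backend ml.Backend) kvcache.Cache {
		// Assumed sizing: an 8192-entry KV cache shared by 4 parallel
		// sequences, with at most 512 tokens per batch.
		kvSize, numSlots, batchSize := 8192, 4, 512
		numCtx := kvSize / numSlots // 2048 cache entries per sequence

		cache := kvcache.NewCausalCache(nil)

		// Per this change, the causal cache allocates maxSequences*capacity
		// cells internally (numSlots*numCtx here), rounded up to the
		// backend's CachePadding.
		cache.Init(backend, ml.DTypeF16, numSlots, numCtx, batchSize)
		return cache
	}

This mirrors how runner/ollamarunner/cache.go now derives numCtx from kvSize/numSlots and forwards numSlots and the batch size to the cache, rather than handing the cache a single undivided capacity.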