diff --git a/kvcache/causal.go b/kvcache/causal.go index ced409c33..aacaf540f 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -118,7 +118,12 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity c.config.MaskDType = ml.DTypeF32 } - cacheSize := maxSequences * capacity + var cacheSize int + if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch { + cacheSize = maxSequences * capacity + } else { + cacheSize = maxSequences * (int(c.windowSize) + maxBatch) + } cacheSize = roundUp(cacheSize, c.config.CachePadding) c.cells = make([]cacheCell, cacheSize) @@ -147,6 +152,8 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error { c.curPositions = batch.Positions c.opts.Except = nil + c.updateSlidingWindow() + var err error c.curLoc, err = c.findStartLoc() if errors.Is(err, ErrKvCacheFull) { @@ -214,6 +221,50 @@ func (c *Causal) findStartLoc() (int, error) { return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells)) } +func (c *Causal) updateSlidingWindow() { + if c.windowSize == math.MaxInt32 { + return + } + + // create a map of unique sequences to the lowest position in that sequence + lowestPos := make(map[int]int32) + for i := range c.curPositions { + seq := c.curSequences[i] + + pos, ok := lowestPos[seq] + if !ok { + pos = c.curPositions[i] + } else if c.curPositions[i] < pos { + pos = c.curPositions[i] + } + + lowestPos[seq] = pos + } + + // delete any entries that are beyond the window of the oldest position in the sequence + for seq, pos := range lowestPos { + oldRange, ok := c.cellRanges[seq] + if !ok { + continue + } + + newRange := newRange() + + for i := oldRange.min; i <= oldRange.max; i++ { + if slices.Contains(c.cells[i].sequences, seq) { + if c.cells[i].pos < pos-c.windowSize { + c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq }) + } else { + newRange.min = min(newRange.min, i) + newRange.max = max(newRange.max, i) + } + } + } + + c.cellRanges[seq] = newRange + } +} + func roundDown(length, pad int) int { return (length / pad) * pad } diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 66a2e835a..617f53635 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -58,11 +58,11 @@ func TestSWA(t *testing.T) { cache := NewSWACache(1, nil) defer cache.Close() - cache.Init(backend, ml.DTypeF32, 1, 16, 16) + cache.Init(backend, ml.DTypeF16, 1, 16, 16) tests := []testCase{ { - name: "SlidingWindow", + name: "FirstBatch", in: []float32{1, 2, 3, 4}, inShape: []int{1, 1, 4}, seqs: []int{0, 0, 0, 0}, @@ -71,6 +71,16 @@ func TestSWA(t *testing.T) { expectedShape: []int{1, 1, 4}, expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, }, + { + name: "SecondBatch", + in: []float32{5, 6}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 0}, + pos: []int32{4, 5}, + expected: []float32{5, 6, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))}, + }, } testCache(t, backend, cache, tests)