mirror of
https://github.com/ollama/ollama.git
synced 2025-04-03 09:29:49 +02:00
kvcache: Optimize sliding window attention
Currently sliding window attention allocates and uses the full context size and just masks out any tokens that are outside of the window. However, we really only need (roughly) the sliding window size. At large context sizes this improves two things: - Memory allocated - since the fully context size is allocated up front, memory requirements drop substantially. On Gemma3:4b with a 32k context window, total memory usage (including weights and non-sliding layers) drops from ~20GB to ~8GB. - Computation - ranges that are completely outside of the sliding window are now removed from the tensors that are returned from the cache rather than simply being masked out. This results in more efficient processing, scaling with the size of the context that has actually been used. Notable, this does not update the scheduler for any model to be aware of the smaller memory requirements. This is difficult for Gemma3 because the layers are heterogeneous between sliding and non-sliding attention. As a result, while actual memory consumption will be reduced, the scheduler will over-estimate the requirements of the model. This means that splitting between GPUs or GPUs and CPUs will still be suboptimal. Bug #9730
This commit is contained in:
parent
3ed7ad3ab3
commit
2d6eac9084
@ -118,7 +118,12 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
|
||||
c.config.MaskDType = ml.DTypeF32
|
||||
}
|
||||
|
||||
cacheSize := maxSequences * capacity
|
||||
var cacheSize int
|
||||
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
|
||||
cacheSize = maxSequences * capacity
|
||||
} else {
|
||||
cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
|
||||
}
|
||||
cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||
c.cells = make([]cacheCell, cacheSize)
|
||||
|
||||
@ -147,6 +152,8 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
|
||||
c.curPositions = batch.Positions
|
||||
c.opts.Except = nil
|
||||
|
||||
c.updateSlidingWindow()
|
||||
|
||||
var err error
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
if errors.Is(err, ErrKvCacheFull) {
|
||||
@ -214,6 +221,50 @@ func (c *Causal) findStartLoc() (int, error) {
|
||||
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
|
||||
}
|
||||
|
||||
func (c *Causal) updateSlidingWindow() {
|
||||
if c.windowSize == math.MaxInt32 {
|
||||
return
|
||||
}
|
||||
|
||||
// create a map of unique sequences to the lowest position in that sequence
|
||||
lowestPos := make(map[int]int32)
|
||||
for i := range c.curPositions {
|
||||
seq := c.curSequences[i]
|
||||
|
||||
pos, ok := lowestPos[seq]
|
||||
if !ok {
|
||||
pos = c.curPositions[i]
|
||||
} else if c.curPositions[i] < pos {
|
||||
pos = c.curPositions[i]
|
||||
}
|
||||
|
||||
lowestPos[seq] = pos
|
||||
}
|
||||
|
||||
// delete any entries that are beyond the window of the oldest position in the sequence
|
||||
for seq, pos := range lowestPos {
|
||||
oldRange, ok := c.cellRanges[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
newRange := newRange()
|
||||
|
||||
for i := oldRange.min; i <= oldRange.max; i++ {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
if c.cells[i].pos < pos-c.windowSize {
|
||||
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||
} else {
|
||||
newRange.min = min(newRange.min, i)
|
||||
newRange.max = max(newRange.max, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.cellRanges[seq] = newRange
|
||||
}
|
||||
}
|
||||
|
||||
func roundDown(length, pad int) int {
|
||||
return (length / pad) * pad
|
||||
}
|
||||
|
@ -58,11 +58,11 @@ func TestSWA(t *testing.T) {
|
||||
cache := NewSWACache(1, nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF32, 1, 16, 16)
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "SlidingWindow",
|
||||
name: "FirstBatch",
|
||||
in: []float32{1, 2, 3, 4},
|
||||
inShape: []int{1, 1, 4},
|
||||
seqs: []int{0, 0, 0, 0},
|
||||
@ -71,6 +71,16 @@ func TestSWA(t *testing.T) {
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
},
|
||||
{
|
||||
name: "SecondBatch",
|
||||
in: []float32{5, 6},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 0},
|
||||
pos: []int32{4, 5},
|
||||
expected: []float32{5, 6, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
|
Loading…
x
Reference in New Issue
Block a user