mirror of
https://github.com/ollama/ollama.git
synced 2025-08-24 07:51:05 +02:00
kvcache: Enable SWA to retain additional entries
Models that use sliding window attention can only resume a sequence from the cache if it falls within the saved windows. This works well if the next message picks up where the old one left off. However, it generally prevents a partial prefix match unless the entire conversation falls within the sliding window. This can be a problem with reasoning models where the traces are supposed to be removed from future messages, forcing the entire history to be re-evaluated. This change allows models to specify that a larger amount of the history be retained in memory, to allow more partial resumption. It still respects the window that the model was trained on for token generation.
This commit is contained in:
@@ -19,9 +19,16 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
|
||||
// The tensors are of shape embed dim, kv heads, batch size
|
||||
// The mask is of shape history size, batch size
|
||||
type Causal struct {
|
||||
DType ml.DType
|
||||
windowSize int32
|
||||
chunkSize int32
|
||||
DType ml.DType
|
||||
|
||||
// swaWindowSize is the number of tokens that will be included in the mask
|
||||
// during attention operations. swaMemorySize is the number of tokens that
|
||||
// will be retained in memory for partial prefix caching. Set to math.MaxInt32
|
||||
// for unlimited or if sliding window attention is not being used.
|
||||
swaWindowSize int32
|
||||
swaMemorySize int32
|
||||
|
||||
chunkSize int32
|
||||
|
||||
opts CausalOptions
|
||||
|
||||
@@ -88,32 +95,41 @@ type cellRange struct {
|
||||
|
||||
func NewCausalCache(shift shiftFn) *Causal {
|
||||
return &Causal{
|
||||
windowSize: math.MaxInt32,
|
||||
shiftFn: shift,
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
shiftFn: shift,
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
}
|
||||
}
|
||||
|
||||
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||
return &Causal{
|
||||
windowSize: windowSize,
|
||||
shiftFn: shift,
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
swaWindowSize: windowSize,
|
||||
shiftFn: shift,
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
}
|
||||
}
|
||||
|
||||
// NewSWAMemCache creates a sliding window attention cache that can retain
// more history than the attention window itself. windowSize is stored as
// swaWindowSize (the number of tokens included in the mask during attention)
// and memorySize is stored as swaMemorySize (the number of tokens retained
// in memory for partial prefix caching). Init panics if memorySize ends up
// smaller than windowSize; a memorySize of 0 makes Init fall back to
// windowSize.
func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal {
|
||||
return &Causal{
|
||||
swaWindowSize: windowSize,
|
||||
swaMemorySize: memorySize,
|
||||
shiftFn: shift,
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
}
|
||||
}
|
||||
|
||||
func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
|
||||
return &Causal{
|
||||
windowSize: math.MaxInt32,
|
||||
chunkSize: chunkSize,
|
||||
shiftFn: shift,
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
chunkSize: chunkSize,
|
||||
shiftFn: shift,
|
||||
ctxs: make(map[int]ml.Context),
|
||||
keys: make(map[int]ml.Tensor),
|
||||
values: make(map[int]ml.Tensor),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -138,11 +154,25 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
|
||||
c.config.MaskDType = ml.DTypeF32
|
||||
}
|
||||
|
||||
if c.swaWindowSize == 0 {
|
||||
c.swaWindowSize = math.MaxInt32
|
||||
}
|
||||
if c.swaMemorySize == 0 {
|
||||
c.swaMemorySize = c.swaWindowSize
|
||||
}
|
||||
if int(c.swaMemorySize) > capacity {
|
||||
c.swaMemorySize = math.MaxInt32
|
||||
}
|
||||
|
||||
if c.swaMemorySize < c.swaWindowSize {
|
||||
panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize))
|
||||
}
|
||||
|
||||
var cacheSize int
|
||||
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) {
|
||||
if c.swaMemorySize == math.MaxInt32 {
|
||||
cacheSize = maxSequences * capacity
|
||||
} else {
|
||||
cacheSize = (maxSequences * int(c.windowSize)) + maxBatch
|
||||
cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch
|
||||
}
|
||||
cacheSize = roundUp(cacheSize, c.config.CachePadding)
|
||||
c.cells = make([]cacheCell, cacheSize)
|
||||
@@ -187,7 +217,6 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
|
||||
return err
|
||||
}
|
||||
|
||||
c.curCellRange = newRange()
|
||||
for i, pos := range batch.Positions {
|
||||
seq := batch.Sequences[i]
|
||||
|
||||
@@ -198,19 +227,12 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
|
||||
seqRange = newRange()
|
||||
}
|
||||
|
||||
if c.curLoc+i > seqRange.max {
|
||||
seqRange.max = c.curLoc + i
|
||||
}
|
||||
if seqRange.max > c.curCellRange.max {
|
||||
c.curCellRange.max = seqRange.max
|
||||
}
|
||||
seqRange.min = min(seqRange.min, c.curLoc+i)
|
||||
c.curCellRange.min = min(c.curCellRange.min, c.curLoc+i)
|
||||
|
||||
seqRange.max = max(seqRange.max, c.curLoc+i)
|
||||
c.curCellRange.max = max(c.curCellRange.max, c.curLoc+i)
|
||||
|
||||
if c.curLoc+i < seqRange.min {
|
||||
seqRange.min = c.curLoc + i
|
||||
}
|
||||
if seqRange.min < c.curCellRange.min {
|
||||
c.curCellRange.min = seqRange.min
|
||||
}
|
||||
c.cellRanges[seq] = seqRange
|
||||
}
|
||||
} else {
|
||||
@@ -252,7 +274,16 @@ func (c *Causal) findStartLoc() (int, error) {
|
||||
}
|
||||
|
||||
func (c *Causal) updateSlidingWindow() {
|
||||
if c.windowSize == math.MaxInt32 {
|
||||
c.curCellRange = newRange()
|
||||
|
||||
if c.swaMemorySize == math.MaxInt32 {
|
||||
for _, seq := range c.curSequences {
|
||||
if seqRange, ok := c.cellRanges[seq]; ok {
|
||||
c.curCellRange.min = min(c.curCellRange.min, seqRange.min)
|
||||
c.curCellRange.max = max(c.curCellRange.max, seqRange.max)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -282,12 +313,16 @@ func (c *Causal) updateSlidingWindow() {
|
||||
|
||||
for i := oldRange.min; i <= oldRange.max; i++ {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
if c.cells[i].pos < pos-c.windowSize {
|
||||
if c.cells[i].pos < pos-c.swaMemorySize {
|
||||
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||
} else {
|
||||
newRange.min = min(newRange.min, i)
|
||||
newRange.max = max(newRange.max, i)
|
||||
}
|
||||
if c.cells[i].pos >= pos-c.swaWindowSize {
|
||||
c.curCellRange.min = min(c.curCellRange.min, i)
|
||||
c.curCellRange.max = max(c.curCellRange.max, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -327,7 +362,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
||||
(enabled && c.cells[j].pos > c.curPositions[i]) ||
|
||||
c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
|
||||
c.cells[j].pos < c.curPositions[i]-c.windowSize {
|
||||
c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
|
||||
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||
}
|
||||
}
|
||||
@@ -485,6 +520,8 @@ func (c *Causal) defrag() {
|
||||
|
||||
c.cellRanges[seq] = seqRange
|
||||
}
|
||||
|
||||
c.updateSlidingWindow()
|
||||
}
|
||||
|
||||
func (c *Causal) SetLayer(layer int) {
|
||||
@@ -610,7 +647,7 @@ func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
}
|
||||
|
||||
func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
if c.windowSize == math.MaxInt32 {
|
||||
if c.swaMemorySize == math.MaxInt32 {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -632,8 +669,8 @@ func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
lastWindowStart := max(0, last-c.windowSize)
|
||||
posWindowStart := max(0, pos-c.windowSize)
|
||||
lastWindowStart := max(0, last-c.swaMemorySize)
|
||||
posWindowStart := max(0, pos-c.swaWindowSize)
|
||||
|
||||
return posWindowStart >= lastWindowStart
|
||||
}
|
||||
|
@@ -60,6 +60,8 @@ func TestSWA(t *testing.T) {
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
x := float32(math.Inf(-1))
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "FirstBatch",
|
||||
@@ -69,7 +71,12 @@ func TestSWA(t *testing.T) {
|
||||
pos: []int32{0, 1, 2, 3},
|
||||
expected: []float32{1, 2, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
expectedMask: []float32{
|
||||
0, x, x, x,
|
||||
0, 0, x, x,
|
||||
x, 0, 0, x,
|
||||
x, x, 0, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SecondBatch",
|
||||
@@ -79,7 +86,53 @@ func TestSWA(t *testing.T) {
|
||||
pos: []int32{4, 5},
|
||||
expected: []float32{5, 6, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
|
||||
expectedMask: []float32{
|
||||
0, x, x, 0,
|
||||
0, 0, x, x,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
}
|
||||
|
||||
func TestSWAMem(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
cache := NewSWAMemCache(1, 3, nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
x := float32(math.Inf(-1))
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "FirstBatch",
|
||||
in: []float32{1, 2, 3, 4},
|
||||
inShape: []int{1, 1, 4},
|
||||
seqs: []int{0, 0, 0, 0},
|
||||
pos: []int32{0, 1, 2, 3},
|
||||
expected: []float32{1, 2, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{
|
||||
0, x, x, x,
|
||||
0, 0, x, x,
|
||||
x, 0, 0, x,
|
||||
x, x, 0, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SecondBatch",
|
||||
in: []float32{5, 6},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 0},
|
||||
pos: []int32{4, 5},
|
||||
expected: []float32{4, 5, 6},
|
||||
expectedShape: []int{1, 1, 3},
|
||||
expectedMask: []float32{
|
||||
0, 0, x,
|
||||
x, 0, 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -437,6 +490,70 @@ func TestCanResume(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanResumeSWAMem(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
windowSize := int32(4)
|
||||
memSize := int32(5)
|
||||
cache := NewSWAMemCache(windowSize, memSize, nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
context := backend.NewContext()
|
||||
defer context.Close()
|
||||
|
||||
err := cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{0, 1, 2, 3, 4, 5},
|
||||
Sequences: []int{0, 0, 0, 0, 0, 0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6}, 1, 1, 6)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// shift the window by adding positions 6 and 7
|
||||
err = cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{6, 7},
|
||||
Sequences: []int{0, 0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor = context.FromFloatSlice([]float32{7, 8}, 1, 1, 2)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// positions 0-5 can no longer be resumed; only positions 6 and 7 can
|
||||
if cache.CanResume(0, 0) {
|
||||
t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 1) {
|
||||
t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 2) {
|
||||
t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 3) {
|
||||
t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 4) {
|
||||
t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 5) {
|
||||
t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
|
||||
}
|
||||
if !cache.CanResume(0, 6) {
|
||||
t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
|
||||
}
|
||||
if !cache.CanResume(0, 7) {
|
||||
t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
|
||||
}
|
||||
}
|
||||
|
||||
// testBackend is the backend handed to cache.Init and used to create
// contexts in these tests. It embeds ml.Backend.
type testBackend struct {
|
||||
ml.Backend
|
||||
}
|
||||
|
Reference in New Issue
Block a user