kvcache: Run tests both with and without PermutedV

The causal cache can store data differently depending on what is
best for the backend. We should run tests both ways.
This commit is contained in:
Jesse Gross
2025-11-19 10:44:38 -08:00
committed by Jesse Gross
parent b2af50960f
commit cb485b2019

View File

@@ -1,6 +1,7 @@
package kvcache
import (
"fmt"
"math"
"slices"
"testing"
@@ -20,8 +21,17 @@ type testCase struct {
expectedMask []float32
}
// runPermutedVariants runs fn twice as named subtests: once against a
// backend whose cache config reports PermutedV=false and once with
// PermutedV=true, so every cache test exercises both V-tensor layouts.
func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
	t.Helper()
	for _, usePermuted := range []bool{false, true} {
		name := fmt.Sprintf("PermutedV=%t", usePermuted)
		t.Run(name, func(t *testing.T) {
			fn(t, &testBackend{permutedV: usePermuted})
		})
	}
}
func TestStore(t *testing.T) {
backend := &testBackend{}
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
cache := NewCausalCache(nil)
defer cache.Close()
@@ -51,10 +61,11 @@ func TestStore(t *testing.T) {
}
testCache(t, backend, cache, tests)
})
}
func TestSWA(t *testing.T) {
backend := &testBackend{}
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
cache := NewSWACache(1, nil)
defer cache.Close()
@@ -94,10 +105,11 @@ func TestSWA(t *testing.T) {
}
testCache(t, backend, cache, tests)
})
}
func TestSWASeparateBatches(t *testing.T) {
backend := &testBackend{}
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
cache := NewSWACache(1, nil)
defer cache.Close()
@@ -174,10 +186,11 @@ func TestSWASeparateBatches(t *testing.T) {
}
testCache(t, backend, cache, tests)
})
}
func TestSWAMem(t *testing.T) {
backend := &testBackend{}
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
cache := NewSWAMemCache(1, 3, nil)
defer cache.Close()
@@ -217,19 +230,20 @@ func TestSWAMem(t *testing.T) {
}
testCache(t, backend, cache, tests)
})
}
func TestChunkedAttention(t *testing.T) {
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
cache := NewChunkedAttentionCache(2, nil)
defer cache.Close()
var b testBackend
cache.Init(&b, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
x := float32(math.Inf(-1))
testCache(
t, &b, cache,
t, backend, cache,
[]testCase{
{
name: "FirstBatch",
@@ -275,10 +289,11 @@ func TestChunkedAttention(t *testing.T) {
},
},
)
})
}
func TestSequences(t *testing.T) {
backend := &testBackend{}
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
cache := NewCausalCache(nil)
defer cache.Close()
@@ -308,10 +323,11 @@ func TestSequences(t *testing.T) {
}
testCache(t, backend, cache, tests)
})
}
func TestRemove(t *testing.T) {
backend := &testBackend{}
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.Add(ctx, shift), nil
})
@@ -386,10 +402,11 @@ func TestRemove(t *testing.T) {
}
testCache(t, backend, cache, tests)
})
}
func TestCopy(t *testing.T) {
backend := &testBackend{}
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
defer cache.Close()
@@ -426,6 +443,7 @@ func TestCopy(t *testing.T) {
}
testCache(t, backend, cache, tests)
})
}
func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
@@ -463,7 +481,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
}
func TestCanResume(t *testing.T) {
backend := &testBackend{}
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
windowSize := int32(4)
cache := NewSWACache(windowSize, nil)
defer cache.Close()
@@ -534,10 +552,11 @@ func TestCanResume(t *testing.T) {
if !cache.CanResume(0, 5) {
t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
}
})
}
func TestCanResumeSWAMem(t *testing.T) {
backend := &testBackend{}
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
windowSize := int32(4)
memSize := int32(5)
cache := NewSWAMemCache(windowSize, memSize, nil)
@@ -598,10 +617,12 @@ func TestCanResumeSWAMem(t *testing.T) {
if !cache.CanResume(0, 7) {
t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
}
})
}
// testBackend is a minimal ml.Backend used by the cache tests. The embedded
// interface satisfies the rest of ml.Backend; only the methods the tests
// need are overridden below.
type testBackend struct {
	ml.Backend
	// permutedV controls the PermutedV flag returned by CacheConfig so
	// tests can run against both value-tensor storage layouts.
	permutedV bool
}
func (b *testBackend) NewContext() ml.Context {
@@ -612,6 +633,10 @@ func (b *testBackend) NewContextSize(int) ml.Context {
return &testContext{}
}
// CacheConfig reports this backend's cache preferences; only PermutedV is
// meaningful here, echoing the value the test variant was built with.
func (b *testBackend) CacheConfig() ml.CacheConfig {
	var cfg ml.CacheConfig
	cfg.PermutedV = b.permutedV
	return cfg
}
// testContext is a minimal ml.Context for the cache tests. The embedded
// interface satisfies ml.Context; any method the tests do not override
// would fail at runtime on the nil embedded value.
type testContext struct {
	ml.Context
}
@@ -766,6 +791,102 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
return view
}
// Permute rearranges the tensor's axes using ggml_permute semantics:
// source axis i moves to destination axis order[i]. Up to four axes are
// supported; when fewer than four are given, the remaining axes map to
// themselves. Unlike ggml's view, the result is materialized: the data is
// copied into a freshly laid-out tensor and trailing size-1 dimensions are
// dropped from the reported shape.
func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
	if len(t.shape) > 4 || len(order) > 4 {
		panic("permute only supports up to 4 dimensions")
	}
	if len(order) != len(t.shape) && len(order) != 4 {
		panic("invalid number of dimensions for permute")
	}

	// ggml_permute expects 4 axes; pad the mapping with identity entries.
	var mapping [4]int
	for i := range mapping {
		if i < len(order) {
			mapping[i] = order[i]
		} else {
			mapping[i] = i
		}
	}

	// Widen the source shape to four size-1-padded dimensions and derive
	// the destination shape, validating the axis mapping as we go.
	srcDims := [4]int{1, 1, 1, 1}
	for i, n := range t.shape {
		srcDims[i] = n
	}
	dstDims := [4]int{1, 1, 1, 1}
	var used [4]bool
	for src, dst := range mapping {
		if dst < 0 || dst >= 4 {
			panic("invalid axis for permute")
		}
		if used[dst] {
			panic("duplicate axis for permute")
		}
		used[dst] = true
		dstDims[dst] = srcDims[src]
	}

	elems := len(t.data)
	out := make([]float32, elems)
	if elems > 0 {
		// Row-major-over-axis-0 strides (axis 0 varies fastest) for both
		// the source and destination layouts.
		var srcStride, dstStride [4]int
		srcStride[0], dstStride[0] = 1, 1
		for i := 1; i < 4; i++ {
			srcStride[i] = srcStride[i-1] * srcDims[i-1]
			dstStride[i] = dstStride[i-1] * dstDims[i-1]
		}
		for idx := range elems {
			// Decompose the flat source index into per-axis coordinates,
			// scattering each coordinate to its destination axis as we go.
			rem := idx
			dstIndex := 0
			for axis := range 4 {
				n := srcDims[axis]
				var coord int
				if n > 0 { // guard degenerate zero-sized dims
					coord = rem % n
					rem /= n
				}
				dstIndex += coord * dstStride[mapping[axis]]
			}
			out[dstIndex] = t.data[idx]
		}
	}

	// Trim trailing dimensions of size <= 1 from the reported shape.
	rank := 4
	for rank > 1 && dstDims[rank-1] <= 1 {
		rank--
	}
	shape := make([]int, rank)
	copy(shape, dstDims[:rank])

	return &testTensor{
		dtype:       t.dtype,
		elementSize: t.elementSize,
		data:        out,
		shape:       shape,
	}
}
func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
dst := t
srcTensor := src.(*testTensor)