mirror of
https://github.com/ollama/ollama.git
synced 2025-12-04 00:42:18 +01:00
kvcache: Run tests both with and without PermutedV
The causal cache can store data differently depending on what is best for the backend. We should run tests both ways.
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package kvcache
|
package kvcache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"slices"
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -20,8 +21,17 @@ type testCase struct {
|
|||||||
expectedMask []float32
|
expectedMask []float32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
|
||||||
|
t.Helper()
|
||||||
|
for _, permuted := range []bool{false, true} {
|
||||||
|
t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) {
|
||||||
|
fn(t, &testBackend{permutedV: permuted})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStore(t *testing.T) {
|
func TestStore(t *testing.T) {
|
||||||
backend := &testBackend{}
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
cache := NewCausalCache(nil)
|
cache := NewCausalCache(nil)
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
@@ -51,10 +61,11 @@ func TestStore(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
testCache(t, backend, cache, tests)
|
testCache(t, backend, cache, tests)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSWA(t *testing.T) {
|
func TestSWA(t *testing.T) {
|
||||||
backend := &testBackend{}
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
cache := NewSWACache(1, nil)
|
cache := NewSWACache(1, nil)
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
@@ -94,10 +105,11 @@ func TestSWA(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
testCache(t, backend, cache, tests)
|
testCache(t, backend, cache, tests)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSWASeparateBatches(t *testing.T) {
|
func TestSWASeparateBatches(t *testing.T) {
|
||||||
backend := &testBackend{}
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
cache := NewSWACache(1, nil)
|
cache := NewSWACache(1, nil)
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
@@ -174,10 +186,11 @@ func TestSWASeparateBatches(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
testCache(t, backend, cache, tests)
|
testCache(t, backend, cache, tests)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSWAMem(t *testing.T) {
|
func TestSWAMem(t *testing.T) {
|
||||||
backend := &testBackend{}
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
cache := NewSWAMemCache(1, 3, nil)
|
cache := NewSWAMemCache(1, 3, nil)
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
@@ -217,19 +230,20 @@ func TestSWAMem(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
testCache(t, backend, cache, tests)
|
testCache(t, backend, cache, tests)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestChunkedAttention(t *testing.T) {
|
func TestChunkedAttention(t *testing.T) {
|
||||||
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
cache := NewChunkedAttentionCache(2, nil)
|
cache := NewChunkedAttentionCache(2, nil)
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
var b testBackend
|
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||||
cache.Init(&b, ml.DTypeF16, 1, 16, 16)
|
|
||||||
|
|
||||||
x := float32(math.Inf(-1))
|
x := float32(math.Inf(-1))
|
||||||
|
|
||||||
testCache(
|
testCache(
|
||||||
t, &b, cache,
|
t, backend, cache,
|
||||||
[]testCase{
|
[]testCase{
|
||||||
{
|
{
|
||||||
name: "FirstBatch",
|
name: "FirstBatch",
|
||||||
@@ -275,10 +289,11 @@ func TestChunkedAttention(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSequences(t *testing.T) {
|
func TestSequences(t *testing.T) {
|
||||||
backend := &testBackend{}
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
cache := NewCausalCache(nil)
|
cache := NewCausalCache(nil)
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
@@ -308,10 +323,11 @@ func TestSequences(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
testCache(t, backend, cache, tests)
|
testCache(t, backend, cache, tests)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRemove(t *testing.T) {
|
func TestRemove(t *testing.T) {
|
||||||
backend := &testBackend{}
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
return key.Add(ctx, shift), nil
|
return key.Add(ctx, shift), nil
|
||||||
})
|
})
|
||||||
@@ -386,10 +402,11 @@ func TestRemove(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
testCache(t, backend, cache, tests)
|
testCache(t, backend, cache, tests)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCopy(t *testing.T) {
|
func TestCopy(t *testing.T) {
|
||||||
backend := &testBackend{}
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
|
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
|
|
||||||
@@ -426,6 +443,7 @@ func TestCopy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
testCache(t, backend, cache, tests)
|
testCache(t, backend, cache, tests)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
|
func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
|
||||||
@@ -463,7 +481,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCanResume(t *testing.T) {
|
func TestCanResume(t *testing.T) {
|
||||||
backend := &testBackend{}
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
windowSize := int32(4)
|
windowSize := int32(4)
|
||||||
cache := NewSWACache(windowSize, nil)
|
cache := NewSWACache(windowSize, nil)
|
||||||
defer cache.Close()
|
defer cache.Close()
|
||||||
@@ -534,10 +552,11 @@ func TestCanResume(t *testing.T) {
|
|||||||
if !cache.CanResume(0, 5) {
|
if !cache.CanResume(0, 5) {
|
||||||
t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
|
t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCanResumeSWAMem(t *testing.T) {
|
func TestCanResumeSWAMem(t *testing.T) {
|
||||||
backend := &testBackend{}
|
runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
|
||||||
windowSize := int32(4)
|
windowSize := int32(4)
|
||||||
memSize := int32(5)
|
memSize := int32(5)
|
||||||
cache := NewSWAMemCache(windowSize, memSize, nil)
|
cache := NewSWAMemCache(windowSize, memSize, nil)
|
||||||
@@ -598,10 +617,12 @@ func TestCanResumeSWAMem(t *testing.T) {
|
|||||||
if !cache.CanResume(0, 7) {
|
if !cache.CanResume(0, 7) {
|
||||||
t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
|
t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type testBackend struct {
|
type testBackend struct {
|
||||||
ml.Backend
|
ml.Backend
|
||||||
|
permutedV bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *testBackend) NewContext() ml.Context {
|
func (b *testBackend) NewContext() ml.Context {
|
||||||
@@ -612,6 +633,10 @@ func (b *testBackend) NewContextSize(int) ml.Context {
|
|||||||
return &testContext{}
|
return &testContext{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *testBackend) CacheConfig() ml.CacheConfig {
|
||||||
|
return ml.CacheConfig{PermutedV: b.permutedV}
|
||||||
|
}
|
||||||
|
|
||||||
type testContext struct {
|
type testContext struct {
|
||||||
ml.Context
|
ml.Context
|
||||||
}
|
}
|
||||||
@@ -766,6 +791,102 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
|||||||
return view
|
return view
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
|
||||||
|
if len(t.shape) > 4 || len(order) > 4 {
|
||||||
|
panic("permute only supports up to 4 dimensions")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(order) != len(t.shape) && len(order) != 4 {
|
||||||
|
panic("invalid number of dimensions for permute")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ggml_permute expects 4 axes, so fill in any missing dimensions.
|
||||||
|
orderFull := append(make([]int, 0, 4), order...)
|
||||||
|
for len(orderFull) < 4 {
|
||||||
|
orderFull = append(orderFull, len(orderFull))
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := [4]bool{}
|
||||||
|
|
||||||
|
shape4 := [4]int{1, 1, 1, 1}
|
||||||
|
for i := 0; i < len(t.shape) && i < 4; i++ {
|
||||||
|
shape4[i] = t.shape[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
newShape4 := [4]int{1, 1, 1, 1}
|
||||||
|
for axis := range 4 {
|
||||||
|
dst := orderFull[axis]
|
||||||
|
if dst < 0 || dst >= 4 {
|
||||||
|
panic("invalid axis for permute")
|
||||||
|
}
|
||||||
|
if seen[dst] {
|
||||||
|
panic("duplicate axis for permute")
|
||||||
|
}
|
||||||
|
seen[dst] = true
|
||||||
|
newShape4[dst] = shape4[axis]
|
||||||
|
}
|
||||||
|
|
||||||
|
total := len(t.data)
|
||||||
|
newData := make([]float32, total)
|
||||||
|
|
||||||
|
if total > 0 {
|
||||||
|
oldDims := shape4
|
||||||
|
newDims := newShape4
|
||||||
|
|
||||||
|
oldStride := [4]int{1, 1, 1, 1}
|
||||||
|
newStride := [4]int{1, 1, 1, 1}
|
||||||
|
for i := 1; i < 4; i++ {
|
||||||
|
oldStride[i] = oldStride[i-1] * oldDims[i-1]
|
||||||
|
newStride[i] = newStride[i-1] * newDims[i-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
var coords [4]int
|
||||||
|
var newCoords [4]int
|
||||||
|
|
||||||
|
for idx := range total {
|
||||||
|
remainder := idx
|
||||||
|
for axis := range 4 {
|
||||||
|
dim := oldDims[axis]
|
||||||
|
if dim == 0 {
|
||||||
|
coords[axis] = 0
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
coords[axis] = remainder % dim
|
||||||
|
remainder /= dim
|
||||||
|
}
|
||||||
|
|
||||||
|
for axis := range 4 {
|
||||||
|
newCoords[orderFull[axis]] = coords[axis]
|
||||||
|
}
|
||||||
|
|
||||||
|
newIndex := 0
|
||||||
|
for axis := range 4 {
|
||||||
|
if newDims[axis] == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newIndex += newCoords[axis] * newStride[axis]
|
||||||
|
}
|
||||||
|
|
||||||
|
newData[newIndex] = t.data[idx]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
numDims := 4
|
||||||
|
for numDims > 1 && newShape4[numDims-1] <= 1 {
|
||||||
|
numDims--
|
||||||
|
}
|
||||||
|
|
||||||
|
newShape := make([]int, numDims)
|
||||||
|
copy(newShape, newShape4[:numDims])
|
||||||
|
|
||||||
|
return &testTensor{
|
||||||
|
dtype: t.dtype,
|
||||||
|
elementSize: t.elementSize,
|
||||||
|
data: newData,
|
||||||
|
shape: newShape,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
|
func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
|
||||||
dst := t
|
dst := t
|
||||||
srcTensor := src.(*testTensor)
|
srcTensor := src.(*testTensor)
|
||||||
|
|||||||
Reference in New Issue
Block a user