Row Order model definitions

A functional implementation on the latest backend and caching code.
Some debugging code still needs rebasing/cleanup.

One unit test still fails and needs further work...
This commit is contained in:
Daniel Hiltgen 2025-01-31 08:07:15 -08:00
parent e27e4a3c1b
commit 7b3c3135de
16 changed files with 561 additions and 243 deletions

2
.gitignore vendored
View File

@ -14,3 +14,5 @@ __debug_bin*
llama/build
llama/vendor
/ollama
build/

View File

@ -251,7 +251,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
mask[i] = float32(math.Inf(-1))
}
maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
maskTensor, err := ctx.Input().FromFloatSlice(mask, batchSize, length)
if err != nil {
return nil, err
}
@ -271,12 +271,12 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
continue
}
kHeadDim := key.Dim(0)
kHeadDim := key.Dim(2)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
rowSize := key.Stride(0)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
kSrcView := key.View(ctx, rowSize*src, []int{kHeadDim * numKVHeads * len}, nil)
kDstView := key.View(ctx, rowSize*dst, []int{kHeadDim * numKVHeads * len}, nil)
value := c.values[i]
var vSrcView, vDstView ml.Tensor
@ -284,14 +284,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
vSrcView = value.View(ctx, elemSize*src, []int{vHeadDim * numKVHeads, len}, []int{int(c.Capacity) * elemSize})
vDstView = value.View(ctx, elemSize*dst, []int{vHeadDim * numKVHeads, len}, []int{int(c.Capacity) * elemSize})
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
vHeadDim := value.Dim(2)
rowSize := value.Stride(0)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
vSrcView = value.View(ctx, rowSize*src, []int{vHeadDim * numKVHeads * len}, nil)
vDstView = value.View(ctx, rowSize*dst, []int{vHeadDim * numKVHeads * len}, nil)
}
ctx.Forward(
@ -430,45 +430,52 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
kHeadDim := key.Dim(0)
kHeadDim := key.Dim(2)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
cachedSize := c.curMask.Dim(0)
rowSize := key.Stride(0)
cachedSize := c.curMask.Dim(1)
// slog.Info("Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "rowSize", rowSize, "cachedSize", cachedSize)
key = key.View(ctx, rowSize*c.curCellRange.min,
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
cachedSize,
[]int{cachedSize, numKVHeads, kHeadDim},
[]int{key.Stride(0), key.Stride(1)},
)
// slog.Info("Get", "key", key)
// panic("XXX")
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
elemSize := value.Stride(2)
value = value.View(ctx, elemSize*c.curCellRange.min,
cachedSize, value.Stride(1),
vHeadDim, value.Stride(2),
numKVHeads,
[]int{numKVHeads, vHeadDim, cachedSize},
[]int{value.Stride(0), value.Stride(1)},
)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
vHeadDim := value.Dim(2)
rowSize := value.Stride(0)
value = value.View(ctx, rowSize*c.curCellRange.min,
vHeadDim, value.Stride(1),
numKVHeads, value.Stride(2),
cachedSize,
[]int{cachedSize, numKVHeads, vHeadDim},
[]int{value.Stride(0), value.Stride(1)},
)
}
// TODO The mask changes from X,X to 1,X, and with the Row-order change
// the 1 becomes trailing and messes up later operations
// This isn't the right solution, but works around it...
if c.curMask.Dim(1) == 1 {
return key, value, c.curMask.Permute(ctx, 1, 0, 2, 3)
}
return key, value, c.curMask
}
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
kHeadDim := key.Dim(0)
vHeadDim := value.Dim(0)
kHeadDim := key.Dim(2)
vHeadDim := value.Dim(2)
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)
batchSize := key.Dim(0)
if c.curBatchSize != batchSize {
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
@ -479,29 +486,29 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
}
if _, ok := c.keys[c.curLayer]; !ok {
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), numKVHeads, kHeadDim)
}
if _, ok := c.values[c.curLayer]; !ok {
if c.config.PermutedV {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, int(c.Capacity))
} else {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), numKVHeads, vHeadDim)
}
}
rowSize := c.keys[c.curLayer].Stride(2)
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
rowSize := c.keys[c.curLayer].Stride(0)
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, []int{kHeadDim * numKVHeads * batchSize}, nil)))
if c.config.PermutedV {
elemSize := c.values[c.curLayer].Stride(0)
elemSize := c.values[c.curLayer].Stride(2)
value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, []int{vHeadDim * numKVHeads, batchSize}, []int{int(c.Capacity) * elemSize})))
} else {
rowSize := c.values[c.curLayer].Stride(2)
rowSize := c.values[c.curLayer].Stride(0)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, []int{vHeadDim * numKVHeads * batchSize}, nil)))
}
}
@ -558,14 +565,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
continue
}
kHeadDim := key.Dim(0)
kHeadDim := key.Dim(2)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
rowSize := key.Stride(0)
key = key.View(ctx, rowSize*seqRange.min,
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
size,
[]int{size, numKVHeads, kHeadDim},
[]int{key.Stride(0), key.Stride(1)},
)
roped, err := c.shiftFn(ctx, i, key, kShift)

View File

@ -31,21 +31,21 @@ func TestStore(t *testing.T) {
{
name: "FirstBatch",
in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
inShape: []int{2, 3, 4},
inShape: []int{4, 3, 2},
seqs: []int{0, 0, 0, 0},
pos: []int32{0, 1, 2, 3},
expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
expectedShape: []int{2, 3, 4},
expectedShape: []int{4, 3, 2},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
},
{
name: "SecondBatch",
in: []float32{115, 215, 125, 225, 135, 235},
inShape: []int{2, 3, 1},
inShape: []int{1, 3, 2},
seqs: []int{0},
pos: []int32{4},
expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
expectedShape: []int{2, 3, 5},
expectedShape: []int{5, 3, 2},
expectedMask: []float32{0, 0, 0, 0, 0},
},
}
@ -64,11 +64,11 @@ func TestSWA(t *testing.T) {
{
name: "SlidingWindow",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
inShape: []int{4, 1, 1},
seqs: []int{0, 0, 0, 0},
pos: []int32{0, 1, 2, 3},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedShape: []int{4, 1, 1},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
},
}
@ -87,21 +87,21 @@ func TestSequences(t *testing.T) {
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
inShape: []int{4, 1, 1},
seqs: []int{0, 0, 1, 1},
pos: []int32{0, 1, 0, 1},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedShape: []int{4, 1, 1},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
},
{
name: "SecondBatch",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
inShape: []int{2, 1, 1},
seqs: []int{0, 1},
pos: []int32{2, 2},
expected: []float32{1, 2, 3, 4, 5, 6},
expectedShape: []int{1, 1, 6},
expectedShape: []int{6, 1, 1},
expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
},
}
@ -122,11 +122,11 @@ func TestRemove(t *testing.T) {
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
inShape: []int{4, 1, 1},
seqs: []int{0, 0, 1, 1},
pos: []int32{0, 1, 0, 1},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedShape: []int{4, 1, 1},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
},
}
@ -142,11 +142,11 @@ func TestRemove(t *testing.T) {
{
name: "RemoveEnd",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
inShape: []int{2, 1, 1},
seqs: []int{0, 1},
pos: []int32{1, 2},
expected: []float32{1, 2, 3, 4, 5, 6},
expectedShape: []int{1, 1, 6},
expectedShape: []int{6, 1, 1},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
},
}
@ -162,11 +162,11 @@ func TestRemove(t *testing.T) {
{
name: "RemoveMiddle",
in: []float32{7, 8},
inShape: []int{1, 1, 2},
inShape: []int{2, 1, 1},
seqs: []int{0, 0},
pos: []int32{1, 2},
expected: []float32{7, 8, 3, 4, 4},
expectedShape: []int{1, 1, 5},
expectedShape: []int{5, 1, 1},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0},
},
}
@ -187,11 +187,11 @@ func TestDefrag(t *testing.T) {
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
inShape: []int{1, 1, 16},
inShape: []int{16, 1, 1},
seqs: []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
pos: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
expectedShape: []int{1, 1, 16},
expectedShape: []int{16, 1, 1},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 
float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
}
@ -212,11 +212,11 @@ func TestDefrag(t *testing.T) {
{
name: "Defrag",
in: []float32{17, 18, 19},
inShape: []int{1, 1, 3},
inShape: []int{3, 1, 1},
seqs: []int{0, 0, 0},
pos: []int32{16, 17, 18},
expected: []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19},
expectedShape: []int{1, 1, 16},
expectedShape: []int{16, 1, 1},
expectedMask: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
}
@ -235,11 +235,11 @@ func TestCopy(t *testing.T) {
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
inShape: []int{4, 1, 1},
seqs: []int{0, 0, 0, 0},
pos: []int32{0, 1, 2, 3},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedShape: []int{4, 1, 1},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
},
}
@ -252,11 +252,11 @@ func TestCopy(t *testing.T) {
{
name: "Copy",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
inShape: []int{2, 1, 1},
seqs: []int{1, 1},
pos: []int32{3, 4},
expected: []float32{1, 2, 3, 4, 5, 6},
expectedShape: []int{1, 1, 6},
expectedShape: []int{6, 1, 1},
expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
},
}
@ -365,6 +365,9 @@ func (c *testContext) MaxGraphNodes() int {
func (c *testContext) Close() {}
// TODO remove this before merging - temporary debugging aid
func (c *testContext) Abort(t ml.Tensor) {}
type testTensor struct {
dtype ml.DType
elementSize int
@ -378,7 +381,8 @@ func (t *testTensor) Dim(n int) int {
func (t *testTensor) Stride(n int) int {
stride := t.elementSize
for i := range n {
// Reverse to mimic ggml's row order impl
for i := len(t.shape) - 1; i > n; i-- {
stride *= t.shape[i]
}
@ -473,7 +477,7 @@ func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
func (t *testTensor) View(ctx ml.Context, offset int, shape []int, stride []int) ml.Tensor {
offset /= t.elementSize
var s []int
@ -481,8 +485,8 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
switch len(shape) {
case 1:
s = []int{shape[0]}
case 5:
s = []int{shape[0], shape[2], shape[4]}
case 3:
s = []int{shape[0], shape[1], shape[2]}
default:
panic("unsupported number of dimensions")
}
@ -496,7 +500,32 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
switch len(t.shape) {
case 2:
sh := make([]int, len(t.shape))
size := 1
for i := range sh {
sh[i] = t.shape[shape[i]]
size *= sh[i]
}
data := make([]float32, size)
dn := 0
for j := range t.shape[1] {
for i := range t.shape[0] {
offset := i*t.shape[1] + j
data[dn] = t.data[offset]
dn++
}
}
return &testTensor{
dtype: t.dtype,
elementSize: t.elementSize,
data: data,
shape: sh,
}
default:
panic("not implemented")
}
}
func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {

View File

@ -97,6 +97,10 @@ func NewBackend(f *os.File, params BackendParams) (Backend, error) {
type Context interface {
Empty(dtype DType, shape ...int) Tensor
Zeros(dtype DType, shape ...int) Tensor
// TODO - the (Tensor, error) return pattern makes this impossible to
// one-line in cases where we need to pass a scalar into a function that
// requires a Tensor leading to overly verbose impls. Consider a Must* API.
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
FromIntSlice(s []int32, shape ...int) (Tensor, error)
@ -113,6 +117,9 @@ type Context interface {
// Layer returns a context appropriate for creating intermediate tensors
Layer(int) Context
// TODO remove this before merging - temporary debugging aid
Abort(Tensor) // Evaluate the graph up to this point, retrieve the data from the tensor and dump it to a json file for comparison
}
type Tensor interface {
@ -130,7 +137,7 @@ type Tensor interface {
Mulmat(ctx Context, t2 Tensor) Tensor
MulmatFullPrec(ctx Context, t2 Tensor) Tensor
Softmax(ctx Context) Tensor
Softmax(ctx Context) Tensor // TODO axis parameter?
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor
@ -145,7 +152,7 @@ type Tensor interface {
SILU(ctx Context) Tensor
Reshape(ctx Context, shape ...int) Tensor
View(ctx Context, offset int, shape ...int) Tensor
View(ctx Context, offset int, shape, stride []int) Tensor
Permute(ctx Context, shape ...int) Tensor
Contiguous(ctx Context) Tensor
Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
@ -298,3 +305,16 @@ const (
DTypeQ40
DTypeI32
)
// String returns a human-readable name for the data type, implementing
// fmt.Stringer. Unrecognized values yield "unknown".
func (dt DType) String() string {
	switch dt {
	case DTypeF32:
		return "float32"
	case DTypeF16:
		return "float16"
	case DTypeQ40:
		// 4-bit block quantization (ggml Q4_0); declared in the DType const block
		return "q4_0"
	case DTypeI32:
		return "int32"
	default:
		return "unknown"
	}
}

View File

@ -9,12 +9,15 @@ package ggml
import "C"
import (
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"maps"
"math"
"os"
"runtime/debug"
"slices"
"strconv"
"strings"
@ -24,10 +27,13 @@ import (
"github.com/ollama/ollama/format"
fs "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
"golang.org/x/sync/errgroup"
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
)
var rev = []C.int{3, 2, 1, 0}
func devices() []*C.struct_ggml_backend_device {
ggml.OnceLoad()
ds := make([]*C.struct_ggml_backend_device, C.ggml_backend_dev_count())
@ -65,7 +71,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
}
slog.Info(
"",
"initializing GGML backend",
"architecture", meta.KV().Architecture(),
"file_type", meta.KV().FileType(),
"name", meta.KV().String("general.name"),
@ -396,7 +402,7 @@ func (b *Backend) Config() ml.Config {
func (b *Backend) Get(name string) ml.Tensor {
if t, ok := b.tensors[name]; ok {
return &Tensor{b: b, t: t}
return &Tensor{b: b, t: t, nDims: int(C.ggml_n_dims(t))}
}
return nil
@ -529,7 +535,7 @@ func pad(length, pad C.size_t) C.size_t {
return ((length + pad - 1) / pad) * pad
}
func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
func (c Context) newTensor(dtype ml.DType, rshape []int) ml.Tensor {
if c.buft == nil {
panic("set Input, Output, or Layer before creating tensors")
}
@ -550,24 +556,31 @@ func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
panic("unsupported dtype")
}
if len(shape) < 1 || shape[0] == 0 {
if len(rshape) < 1 || rshape[0] == 0 {
var shape C.int64_t = 0
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
} else if len(shape) > 4 {
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape), nDims: 1}
} else if len(rshape) > 4 {
panic("unsupported number of dimensions")
}
for _, dim := range shape {
for _, dim := range rshape {
if dim < 1 {
panic("invalid shape")
}
}
// Inverted
shape := make([]int, len(rshape))
i := len(rshape) - 1
for _, dim := range rshape {
shape[i] = dim
i--
}
t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
return &Tensor{b: c.b, t: t}
return &Tensor{b: c.b, t: t, nDims: len(shape)}
}
func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
@ -631,8 +644,14 @@ func (c *Context) Close() {
}
type Tensor struct {
b *Backend
t *C.struct_ggml_tensor
b *Backend
t *C.struct_ggml_tensor
// keep track of the number of dimensions
// Since we reverse the shape, GGML considers a trailing "1" dimension as not present
// and we can't actually trust the output of ggml_n_dims
nDims int
sync func()
}
@ -641,23 +660,35 @@ func (t *Tensor) LogValue() slog.Value {
slog.String("name", C.GoString(C.ggml_get_name(t.t))),
slog.String("type", C.GoString(C.ggml_type_name(t.t._type))),
slog.Any("shape", t.Shape()),
slog.Any("underlying shape", t.t.ne),
slog.Any("underlying stride", t.t.nb),
)
}
func (t *Tensor) Dim(n int) int {
return int(t.t.ne[n])
if t.nDims == 0 {
// If this hits we likely forgot to copy the dimension to the returned tensor in some operation
panic("zero dimension tensor")
}
r := rev[4-t.nDims:]
return int(t.t.ne[r[n]])
}
func (t *Tensor) Stride(n int) int {
return int(t.t.nb[n])
if t.nDims == 0 {
slog.Error("Stride", "tensor", t, "dim", n)
panic("zero dimension tensor")
}
r := rev[4-t.nDims:]
s := int(t.t.nb[r[n]])
return s
}
func (t *Tensor) Shape() []int {
shape := make([]int, C.ggml_n_dims(t.t))
shape := make([]int, t.nDims)
for i := range shape {
shape[i] = t.Dim(i)
}
return shape
}
@ -702,8 +733,9 @@ func (t *Tensor) DType() ml.DType {
func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
b: t.b,
t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
nDims: t.nDims,
}
}
@ -717,29 +749,38 @@ func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
b: t.b,
t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
nDims: max(t.nDims, t2.(*Tensor).nDims),
}
}
func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cont(ctx.(*Context).ctx, t.t),
b: t.b,
t: C.ggml_cont(ctx.(*Context).ctx, t.t),
nDims: t.nDims,
}
}
func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
b: t.b,
t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
nDims: t.nDims, // TODO should this be max(t.nDims, t2.nDims)?
}
}
func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
if t.t.ne[0] != t2.(*Tensor).t.ne[0] {
slog.Error("incorrect tensor shapes for Mulmat", "t", t, "t2", t2)
panic("malformed tensors passed to Mulmat")
}
r := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
return &Tensor{
b: t.b,
t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
b: t.b,
t: r,
nDims: max(t.nDims, t2.(*Tensor).nDims),
}
}
@ -748,13 +789,14 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
return &Tensor{
b: t.b,
t: mul,
b: t.b,
t: mul,
nDims: max(t.nDims, t2.(*Tensor).nDims),
}
}
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps)), nDims: t.nDims}).Mul(ctx, w)
if b != nil {
tt = tt.Add(ctx, b)
}
@ -763,17 +805,28 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
}
func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps)), nDims: t.nDims}).Mul(ctx, w)
}
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
if len(shape) != 4 {
panic("expected 4 dimensions")
}
var r *C.struct_ggml_tensor
switch t.nDims {
case 1:
r = C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3]))
case 2:
r = C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[1]), C.int(shape[0]), C.int(shape[2]), C.int(shape[3]))
case 3:
r = C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[2]), C.int(shape[1]), C.int(shape[0]), C.int(shape[3]))
default:
r = C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[3]), C.int(shape[2]), C.int(shape[1]), C.int(shape[0]))
}
return &Tensor{
b: t.b,
t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
b: t.b,
t: r,
nDims: t.nDims,
}
}
@ -781,48 +834,132 @@ func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
if len(shape) != 4 {
panic("expected 4 dimensions")
}
return &Tensor{
b: t.b,
t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
rshape := []C.int{0, 1, 2, 3}
switch t.nDims {
case 2:
// TODO make sure this isn't wonky...
rshape[0] = rev[2:][shape[1]]
rshape[1] = rev[2:][shape[0]]
case 3:
// TODO has to be a better way...
rshape[0] = C.int(shape[0])
rshape[1] = C.int(shape[1])
rshape[2] = C.int(shape[2])
switch shape[0]*100 + shape[1]*10 + shape[2] {
case 21:
rshape[0], rshape[1], rshape[2] = 1, 0, 2
case 102:
rshape[0], rshape[1], rshape[2] = 0, 2, 1
}
case 4:
// TODO has to be a better way...
rshape[0] = C.int(shape[0])
rshape[1] = C.int(shape[1])
rshape[2] = C.int(shape[2])
rshape[3] = C.int(shape[3])
switch shape[0]*1000 + shape[1]*100 + shape[2]*10 + shape[3] {
case 132:
rshape[0], rshape[1], rshape[2], rshape[3] = 1, 0, 2, 3
case 231:
rshape[0], rshape[1], rshape[2], rshape[3] = 1, 2, 0, 3
case 312:
rshape[0], rshape[1], rshape[2], rshape[3] = 2, 0, 1, 3
case 321:
rshape[0], rshape[1], rshape[2], rshape[3] = 2, 1, 0, 3
case 1023:
rshape[0], rshape[1], rshape[2], rshape[3] = 0, 1, 3, 2
case 1203:
rshape[0], rshape[1], rshape[2], rshape[3] = 0, 2, 3, 1
case 1302:
rshape[0], rshape[1], rshape[2], rshape[3] = 2, 0, 3, 1
case 1320:
rshape[0], rshape[1], rshape[2], rshape[3] = 2, 1, 3, 0
case 2013:
rshape[0], rshape[1], rshape[2], rshape[3] = 0, 3, 1, 2
case 2031:
rshape[0], rshape[1], rshape[2], rshape[3] = 1, 3, 0, 2
case 2103:
rshape[0], rshape[1], rshape[2], rshape[3] = 0, 3, 2, 1
case 2130:
rshape[0], rshape[1], rshape[2], rshape[3] = 1, 3, 2, 0
case 3021:
rshape[0], rshape[1], rshape[2], rshape[3] = 3, 1, 0, 2
case 3102:
rshape[0], rshape[1], rshape[2], rshape[3] = 3, 0, 2, 1
}
}
r := &Tensor{
b: t.b,
t: C.ggml_permute(ctx.(*Context).ctx, t.t, rshape[0], rshape[1], rshape[2], rshape[3]),
nDims: t.nDims,
}
return r
}
func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
b: t.b,
t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
nDims: t.nDims,
}
}
func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
r := &Tensor{
b: t.b,
t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
nDims: t.nDims,
}
return r
}
func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
// GGML does not handle -1 natively
for i, sh := range shape {
if sh == -1 {
totalElems := 1
for d := range t.nDims {
totalElems *= int(t.t.ne[d])
}
otherElems := 1
for _, osh := range shape {
if osh != -1 {
otherElems *= osh
}
}
if otherElems > totalElems {
slog.Error("Invalid request", "req", shape, "actual", t.Shape(), "totalElems", totalElems, "otherElems", otherElems)
panic("impossible -1 shape request")
}
shape[i] = int(float64(totalElems) / float64(otherElems))
break
}
}
switch len(shape) {
case 1:
return &Tensor{
b: t.b,
t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
b: t.b,
t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
nDims: len(shape),
}
case 2:
return &Tensor{
b: t.b,
t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
b: t.b,
t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[1]), C.int64_t(shape[0])),
nDims: len(shape),
}
case 3:
return &Tensor{
b: t.b,
t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
b: t.b,
t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[2]), C.int64_t(shape[1]), C.int64_t(shape[0])),
nDims: len(shape),
}
case 4:
return &Tensor{
b: t.b,
t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
b: t.b,
t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[3]), C.int64_t(shape[2]), C.int64_t(shape[1]), C.int64_t(shape[0])),
nDims: len(shape),
}
default:
panic("unsupported number of dimensions")
@ -831,22 +968,25 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
b: t.b,
t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
nDims: t.nDims,
}
}
func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
b: t.b,
t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
nDims: t.nDims,
}
}
func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
b: t.b,
t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
nDims: t.nDims,
}
}
@ -856,41 +996,50 @@ func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
}
return &Tensor{
b: t.b,
t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
b: t.b,
t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[3]), C.int(shape[2]), C.int(shape[1]), C.int(shape[0])),
nDims: t.nDims, // TODO is this right?
}
}
func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
func (t *Tensor) View(ctx ml.Context, offset int, shape, stride []int) ml.Tensor {
if len(stride)+1 != len(shape) {
panic(fmt.Sprintf("malformed view request: shape=%v stride=%v", shape, stride))
}
switch len(shape) {
case 1:
return &Tensor{
b: t.b,
t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
nDims: 1,
}
case 2:
return &Tensor{
b: t.b,
t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[1]), C.int64_t(shape[0]),
C.size_t(stride[0]),
C.size_t(offset)),
nDims: 2,
}
case 3:
return &Tensor{
b: t.b,
t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]),
C.size_t(shape[1]),
C.size_t(offset)),
}
case 5:
return &Tensor{
b: t.b,
t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
C.size_t(shape[1]), C.size_t(shape[3]),
C.int64_t(shape[2]), C.int64_t(shape[1]), C.int64_t(shape[0]),
C.size_t(stride[1]), C.size_t(stride[0]),
C.size_t(offset)),
nDims: 3,
}
case 7:
case 4:
return &Tensor{
b: t.b,
t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
C.int64_t(shape[3]), C.int64_t(shape[2]), C.int64_t(shape[1]), C.int64_t(shape[0]),
C.size_t(stride[2]), C.size_t(stride[1]), C.size_t(stride[0]),
C.size_t(offset)),
nDims: 4,
}
default:
panic("unsupported number of dimensions")
@ -906,14 +1055,13 @@ const (
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
if ropeFactors == nil {
ropeFactors = &Tensor{b: t.b}
ropeFactors = &Tensor{b: t.b, nDims: 0}
}
dequant := t.t
if C.ggml_is_quantized(t.t._type) {
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
}
return &Tensor{
b: t.b,
t: C.ggml_rope_ext(
@ -928,34 +1076,39 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
32., // YaRN beta_fast
1., // YaRN beta_slow
),
nDims: t.nDims,
}
}
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
b: t.b,
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
nDims: t.nDims,
}
}
func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
b: t.b,
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
nDims: t.nDims,
}
}
func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
b: t.b,
t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
nDims: t.nDims,
}
}
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)),
b: t.b,
t: C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)),
nDims: t.nDims,
}
}
@ -970,7 +1123,7 @@ func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) m
panic("unsupported number of dimensions")
}
return &Tensor{b: t.b, t: tt}
return &Tensor{b: t.b, t: tt, nDims: t.nDims}
}
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
@ -979,23 +1132,60 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
kqMask = mask.(*Tensor).t
}
query := t.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
query := t.Permute(ctx, 1, 0, 2, 3)
key = key.Permute(ctx, 1, 0, 2, 3)
if t.b.flashAttention {
value = value.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 0, 2, 3)
kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
return &Tensor{b: t.b, t: kqv}
return &Tensor{b: t.b, t: kqv, nDims: t.nDims}
} else {
kq := key.MulmatFullPrec(ctx, query)
kq = &Tensor{
b: t.b,
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
b: t.b,
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
nDims: t.nDims,
}
kqv := value.Mulmat(ctx, kq)
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
return kqv.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
}
}
// TODO remove this before merging - temporary debugging aid
//
// Abort computes the given tensor, dumps its float32 contents to
// ggml.json (with a stack trace on stderr), and terminates the process.
func (c *Context) Abort(t ml.Tensor) {
	// Hack to make sure we're f32, otherwise r.Floats will fail due to short read
	if t.(*Tensor).t._type != C.GGML_TYPE_F32 {
		t.(*Tensor).t = C.ggml_cast(c.ctx, t.(*Tensor).t, C.GGML_TYPE_F32)
	}
	c.Forward(t)
	c.Compute(t)
	f32 := t.Floats()
	// Convert NaN/[-]Inf to serializable values: encoding/json rejects
	// them (UnsupportedValueError), which would panic below.
	for i, v := range f32 {
		if math.IsNaN(float64(v)) {
			f32[i] = 0
			continue
		}
		if v > math.MaxFloat32 { // +Inf
			f32[i] = math.MaxFloat32
		}
		if v < -math.MaxFloat32 { // -Inf (was -SmallestNonzeroFloat32, which clobbered all negatives)
			f32[i] = -math.MaxFloat32
		}
	}
	debug.PrintStack()
	filename := "ggml.json"
	slog.Info("Writing tensors to", "filename", filename)
	f, err := os.Create(filename)
	if err != nil {
		panic(err)
	}
	defer f.Close()
	encoder := json.NewEncoder(f)
	err = encoder.Encode(f32)
	if err != nil {
		panic(err)
	}
	os.Exit(1)
}

View File

@ -12,9 +12,9 @@ import (
//
// Parameters:
// - ctx: Context for tensor operations
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
// - query: Query tensor (Q) with shape [seq_len_q, heads, d_k]
// - key: Key tensor (K) with shape [seq_len_k, kv_heads, d_k], can be nil to read from cache only
// - value: Value tensor (V) with shape [seq_len_k, kv_heads, d_v], can be nil to read from cache only
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
//
@ -23,16 +23,16 @@ import (
// Attention output with shape [seq_len_q, heads, d_v]
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
if query.Dim(2) != key.Dim(2) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(2), key.Dim(2)))
}
if key.Dim(1) != value.Dim(1) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
}
if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
if key.Dim(0) != value.Dim(0) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(0), value.Dim(0)))
}
if cache != nil {
@ -52,8 +52,8 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
} else {
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
query = query.Permute(ctx, 1, 0, 2, 3)
key = key.Permute(ctx, 1, 0, 2, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
@ -65,6 +65,6 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
return kqv.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
}
}

62
model/README.md Normal file
View File

@ -0,0 +1,62 @@
# Ollama Models
Large Language Models in Ollama are defined in the Go programming language within this directory.
## Model Implementation Guide
Ollama supports multiple backends, and provides an abstracted interface for model implementers. [Backend API](../ml/backend.go)
This API is designed to be similar to other popular Python libraries such as
PyTorch, with row-major tensors and a forward function that takes a sequence of
inputs.
Use an existing model as an initial reference, such as [llama](./models/llama/)
Cheatsheet:
<table>
<tr>
<td><b>PyTorch</b></td>
<td><b>Ollama</b></td>
</tr>
<tr>
<td>torch.zeros((2, 2))</td>
<td>ctx.Zeros(ml.DTypeF32, 2, 2)</td>
</tr>
<tr>
<td>tensor.view((2, 2))</td>
<td>t.Reshape(ctx, 2, 2)</td>
</tr>
<tr>
<td>torch.permute(t1, (1, 2, 3))</td>
<td>t1.Permute(ctx, 1, 2, 3)</td>
</tr>
<tr>
<td>torch.add(t1, t2)</td>
<td>t1.Add(ctx, t2)</td>
</tr>
<tr>
<td>
```python
class Attention(nn.Module):
def __call__(self, ...):
...
```
</td>
<td>
```go
func (sa *SelfAttention) Forward(ctx ml.Context,
hiddenState, positionIDs ml.Tensor,
cache kvcache.Cache,
opts *Options) ml.Tensor {
...
}
```
</td>
</tr>
</table>

View File

@ -306,3 +306,14 @@ func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
return t, nil
}
// ArangeF32 returns the half-open range [start, end) as float32 values
// spaced step apart (like numpy's arange). It returns nil for an empty
// range (start >= end) or a non-positive step — a negative step with
// start < end would otherwise loop forever.
func ArangeF32(start, end, step float32) []float32 {
	if step <= 0 || start >= end {
		return nil
	}
	// Pre-size: the range holds at most (end-start)/step + 1 elements.
	res := make([]float32, 0, int((end-start)/step)+1)
	for i := start; i < end; i += step {
		res = append(res, i)
	}
	return res
}

View File

@ -77,11 +77,11 @@ type SelfAttention struct {
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
batchSize := hiddenState.Dim(0)
ropeType := uint32(2)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = q.Reshape(ctx, batchSize, opts.numHeads, opts.attnKeyLen)
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
if opts.largeModelScaling {
@ -91,17 +91,17 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = k.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnKeyLen)
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
v = v.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnValLen)
cache.Put(ctx, k, v)
k, v, mask := cache.Get(ctx)
q = q.Permute(ctx, 0, 2, 1, 3)
k = k.Permute(ctx, 0, 2, 1, 3)
q = q.Permute(ctx, 1, 0, 2, 3)
k = k.Permute(ctx, 1, 0, 2, 3)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := k.Mulmat(ctx, q)
@ -115,8 +115,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
kq = kq.Softmax(ctx)
kqv := v.Mulmat(ctx, kq)
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
kqv = kqv.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
kqv = kqv.Reshape(ctx, batchSize, opts.attnValLen*opts.numHeads)
return sa.Output.Forward(ctx, kqv)
}

View File

@ -35,15 +35,15 @@ type MultiModalProjector struct {
}
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
l := visionOutputs.Dim(0)
l := visionOutputs.Dim(1)
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
patchesPerImage := imageSize / patchSize
visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
visionOutputs = visionOutputs.Reshape(ctx, l, patchesPerImage, patchesPerImage)
kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
visionOutputs = visionOutputs.Reshape(ctx, l, visionOutputs.Dim(2)*visionOutputs.Dim(1))
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
@ -98,9 +98,9 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
}
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
)
if err != nil {
return nil, err
@ -121,13 +121,13 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
inputMultimodal := inp.Multimodal.(ml.Tensor)
result = append(result,
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(0) + 3}, // "\n\n"
input.Input{Token: 255999}, // "<start_of_image>""
input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
)
// add image token placeholders
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(0)-1)...)
result = append(result,
input.Input{Token: 256000}, // <end_of_image>

View File

@ -85,7 +85,7 @@ type TextSelfAttention struct {
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
batchSize := hiddenState.Dim(0)
ropeType := uint32(2)
ropeBase := opts.ropeLocalBase
@ -94,7 +94,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
}
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = q.Reshape(ctx, batchSize, opts.numHeads, opts.attnKeyLen)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
@ -105,16 +105,16 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
}
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = k.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnKeyLen)
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
v = v.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnValLen)
scaleFactor := 1.0
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
kqv = kqv.Reshape(ctx, batchSize, opts.attnValLen*opts.numHeads)
return sa.Output.Forward(ctx, kqv)
}
@ -179,9 +179,9 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
var except []int
for _, image := range opts.Multimodal {
visionOutputs := image.Multimodal.(ml.Tensor)
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(0), []int{visionOutputs.Dim(0) * visionOutputs.Dim(1)}, nil)))
for i := range visionOutputs.Dim(1) {
for i := range visionOutputs.Dim(0) {
except = append(except, image.Index+i)
}
}

View File

@ -7,8 +7,6 @@ import (
"github.com/ollama/ollama/ml/nn"
)
var batchSize int = 1
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
@ -23,12 +21,12 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
query = query.Reshape(ctx, query.Dim(0), opts.numHeads, headDim)
key = key.Reshape(ctx, key.Dim(0), opts.numHeads, headDim)
value = value.Reshape(ctx, value.Dim(0), opts.numHeads, headDim)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
attention = attention.Reshape(ctx, attention.Dim(0), opts.hiddenSize)
hiddenState = sa.Output.Forward(ctx, attention)
return hiddenState
@ -88,7 +86,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, numPatches)
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
positions := make([]int32, numPatches)

View File

@ -74,24 +74,24 @@ type SelfAttention struct {
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
batchSize := hiddenState.Dim(0) // TODO Consider renaming "L" as this is the sequence length, not batch size
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.Reshape(ctx, batchSize, opts.numHeads, -1)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.Reshape(ctx, batchSize, opts.numKVHeads, -1)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
v = v.Reshape(ctx, batchSize, opts.numKVHeads, -1)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
kqv = kqv.Reshape(ctx, batchSize, -1)
return sa.Output.Forward(ctx, kqv)
}

View File

@ -78,10 +78,10 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
}
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
m.ImageProcessor.maxNumTiles,
m.ImageProcessor.numChannels,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
)
if err != nil {
return nil, err

View File

@ -18,24 +18,24 @@ type TextSelfAttention struct {
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
batchSize := hiddenState.Dim(0)
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.Reshape(ctx, batchSize, opts.numHeads, headDim)
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = key.Reshape(ctx, batchSize, opts.numKVHeads, headDim)
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
value = value.Reshape(ctx, batchSize, opts.numKVHeads, headDim)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
attention = attention.Reshape(ctx, batchSize, opts.hiddenSize)
return sa.Output.Forward(ctx, attention)
}
@ -99,23 +99,23 @@ type TextCrossAttention struct {
}
func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
batchSize := hiddenState.Dim(0)
headDim := opts.hiddenSize / opts.numHeads
query := ca.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.Reshape(ctx, batchSize, opts.numHeads, headDim)
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
var key, value ml.Tensor
if crossAttentionStates != nil {
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
numVisionTokens, numTiles := crossAttentionStates.Dim(2), crossAttentionStates.Dim(1)
key = ca.Key.Forward(ctx, crossAttentionStates)
key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
key = key.Reshape(ctx, numVisionTokens*numTiles, opts.numKVHeads, headDim)
key = ca.KeyNorm.Forward(ctx, key, opts.eps)
value = ca.Value.Forward(ctx, crossAttentionStates)
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
value = value.Reshape(ctx, numVisionTokens*numTiles, opts.numKVHeads, headDim)
cache.Put(ctx, key, value)
}
@ -124,8 +124,8 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
query = query.Permute(ctx, 1, 0, 2, 3)
key = key.Permute(ctx, 1, 0, 2, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
@ -134,8 +134,8 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
attention := kqv.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, batchSize, opts.hiddenSize)
return ca.Output.Forward(ctx, attention)
}

View File

@ -23,25 +23,25 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
query = query.Reshape(ctx, batchSize, query.Dim(1), opts.numHeads, headDim)
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
key = key.Reshape(ctx, batchSize, key.Dim(1), opts.numHeads, headDim)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
value = value.Reshape(ctx, batchSize, value.Dim(1), opts.numHeads, headDim)
value = value.Permute(ctx, 0, 2, 3, 1).Contiguous(ctx)
scores := key.Mulmat(ctx, query)
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
scores = scores.Softmax(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
attention = attention.Reshape(ctx, batchSize, opts.numHeads, attention.Dim(2), headDim)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
attention = attention.Reshape(ctx, batchSize, attention.Dim(1), opts.hiddenSize)
hiddenState = sa.Output.Forward(ctx, attention)
if sa.Gate != nil {
@ -99,7 +99,7 @@ func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermedi
var intermediateHiddenStates []ml.Tensor
for i, layer := range e.Layers {
if slices.Contains(intermediateLayersIndices, uint32(i)) {
intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int{1}, hiddenState.Shape()...)...))
intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append(hiddenState.Shape(), 1)...))
}
hiddenState = layer.Forward(ctx, hiddenState, opts)
@ -115,7 +115,7 @@ type PrecomputedAspectRatioEmbedding struct {
func (e *PrecomputedAspectRatioEmbedding) Forward(ctx ml.Context, hiddenState ml.Tensor, aspectRatioIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
embeddings := e.Embedding.Forward(ctx, aspectRatioIDs)
embeddings = embeddings.Reshape(ctx, opts.hiddenSize, 1, opts.numTiles)
embeddings = embeddings.Reshape(ctx, opts.numTiles, 1, opts.hiddenSize)
if e.Gate != nil {
embeddings = embeddings.Mul(ctx, e.Gate)
}
@ -140,7 +140,7 @@ func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, posi
hiddenState = hiddenState.Add(ctx, positionEmbedding)
tilePositionEmbedding := e.TilePositionEmbedding.Forward(ctx, aspectRatioIDs)
tilePositionEmbedding = tilePositionEmbedding.Reshape(ctx, opts.hiddenSize, numPositions, opts.numTiles)
tilePositionEmbedding = tilePositionEmbedding.Reshape(ctx, opts.numTiles, numPositions, opts.hiddenSize)
if e.TilePositionEmbeddingGate != nil {
tilePositionEmbedding = tilePositionEmbedding.Mul(ctx, e.TilePositionEmbeddingGate)
}
@ -181,8 +181,8 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa
}
hiddenState := m.PatchEmbeddings.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize, m.numTiles)
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenState = hiddenState.Reshape(ctx, m.numTiles, m.hiddenSize, numPatches)
hiddenState = hiddenState.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
hiddenState = m.PreTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions)
hiddenState = m.ClassEmbedding.Stack(ctx, 2, slices.Repeat([]ml.Tensor{m.ClassEmbedding}, m.numTiles-1)...).Concat(ctx, hiddenState, 1)
@ -193,23 +193,23 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa
numPaddingPatches := 8 - (hiddenState.Dim(1)%8)%8
hiddenState = hiddenState.Pad(ctx, 0, numPaddingPatches, 0, 0)
hiddenState = hiddenState.Reshape(ctx, hiddenState.Dim(0), hiddenState.Dim(1)*hiddenState.Dim(2), batchSize)
hiddenState = hiddenState.Reshape(ctx, batchSize, hiddenState.Dim(1)*hiddenState.Dim(0), hiddenState.Dim(2))
hiddenState, intermediateHiddenStates := m.Transformer.Forward(ctx, hiddenState, m.intermediateLayersIndices, m.VisionModelOptions)
hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
hiddenState = hiddenState.Reshape(ctx, batchSize, m.numTiles, numPositions+numPaddingPatches, m.hiddenSize)
hiddenState = m.PostTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions)
hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, m.numTiles*(numPositions+numPaddingPatches), batchSize)
hiddenState = hiddenState.Reshape(ctx, batchSize, m.numTiles*(numPositions+numPaddingPatches), m.hiddenSize)
hiddenState, _ = m.GlobalTransformer.Forward(ctx, hiddenState, nil, m.VisionModelOptions)
hiddenStates := intermediateHiddenStates[0].Stack(ctx, 0, intermediateHiddenStates[1:]...)
hiddenStates = hiddenStates.Reshape(ctx, len(intermediateHiddenStates)*m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
hiddenStates = hiddenStates.Unpad(ctx, 0, numPaddingPatches, 0, 0)
hiddenStates = hiddenStates.Reshape(ctx, batchSize, m.numTiles, numPositions+numPaddingPatches, len(intermediateHiddenStates)*m.hiddenSize)
hiddenStates = hiddenStates.Unpad(ctx, 0, 0, numPaddingPatches, 0)
hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
hiddenState = hiddenState.Unpad(ctx, 0, numPaddingPatches, 0, 0)
hiddenState = hiddenState.Reshape(ctx, batchSize, m.numTiles, numPositions+numPaddingPatches, m.hiddenSize)
hiddenState = hiddenState.Unpad(ctx, 0, 0, numPaddingPatches, 0)
return hiddenState.Concat(ctx, hiddenStates, 0)
}