diff --git a/.gitignore b/.gitignore index 3a2af0bd1..4ab4875f3 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,5 @@ __debug_bin* llama/build llama/vendor /ollama +build/ + diff --git a/kvcache/causal.go b/kvcache/causal.go index edf6666da..17c4a83a1 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -251,7 +251,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { mask[i] = float32(math.Inf(-1)) } - maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize) + maskTensor, err := ctx.Input().FromFloatSlice(mask, batchSize, length) if err != nil { return nil, err } @@ -271,12 +271,12 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) { continue } - kHeadDim := key.Dim(0) + kHeadDim := key.Dim(2) numKVHeads := key.Dim(1) - rowSize := key.Stride(2) + rowSize := key.Stride(0) - kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len) - kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len) + kSrcView := key.View(ctx, rowSize*src, []int{kHeadDim * numKVHeads * len}, nil) + kDstView := key.View(ctx, rowSize*dst, []int{kHeadDim * numKVHeads * len}, nil) value := c.values[i] var vSrcView, vDstView ml.Tensor @@ -284,14 +284,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) { vHeadDim := value.Dim(1) elemSize := value.Stride(0) - vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads) - vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads) + vSrcView = value.View(ctx, elemSize*src, []int{vHeadDim * numKVHeads, len}, []int{int(c.Capacity) * elemSize}) + vDstView = value.View(ctx, elemSize*dst, []int{vHeadDim * numKVHeads, len}, []int{int(c.Capacity) * elemSize}) } else { - vHeadDim := value.Dim(0) - rowSize := value.Stride(2) + vHeadDim := value.Dim(2) + rowSize := value.Stride(0) - vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len) - vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len) + 
vSrcView = value.View(ctx, rowSize*src, []int{vHeadDim * numKVHeads * len}, nil) + vDstView = value.View(ctx, rowSize*dst, []int{vHeadDim * numKVHeads * len}, nil) } ctx.Forward( @@ -430,45 +430,52 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { key := c.keys[c.curLayer] value := c.values[c.curLayer] - kHeadDim := key.Dim(0) + kHeadDim := key.Dim(2) numKVHeads := key.Dim(1) - rowSize := key.Stride(2) - cachedSize := c.curMask.Dim(0) + rowSize := key.Stride(0) + cachedSize := c.curMask.Dim(1) + // slog.Info("Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "rowSize", rowSize, "cachedSize", cachedSize) key = key.View(ctx, rowSize*c.curCellRange.min, - kHeadDim, key.Stride(1), - numKVHeads, key.Stride(2), - cachedSize, + []int{cachedSize, numKVHeads, kHeadDim}, + []int{key.Stride(0), key.Stride(1)}, ) + // slog.Info("Get", "key", key) + // panic("XXX") if c.config.PermutedV { vHeadDim := value.Dim(1) - elemSize := value.Stride(0) + elemSize := value.Stride(2) value = value.View(ctx, elemSize*c.curCellRange.min, - cachedSize, value.Stride(1), - vHeadDim, value.Stride(2), - numKVHeads, + []int{numKVHeads, vHeadDim, cachedSize}, + []int{value.Stride(0), value.Stride(1)}, ) } else { - vHeadDim := value.Dim(0) - rowSize := value.Stride(2) + vHeadDim := value.Dim(2) + rowSize := value.Stride(0) value = value.View(ctx, rowSize*c.curCellRange.min, - vHeadDim, value.Stride(1), - numKVHeads, value.Stride(2), - cachedSize, + []int{cachedSize, numKVHeads, vHeadDim}, + []int{value.Stride(0), value.Stride(1)}, ) } + // TODO The mask changes from X,X to 1,X, and with the Row-order change + // the 1 becomes trailing and messes up later operations + // This isn't the right solution, but works around it... 
+ if c.curMask.Dim(1) == 1 { + return key, value, c.curMask.Permute(ctx, 1, 0, 2, 3) + } + return key, value, c.curMask } func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { - kHeadDim := key.Dim(0) - vHeadDim := value.Dim(0) + kHeadDim := key.Dim(2) + vHeadDim := value.Dim(2) numKVHeads := key.Dim(1) - batchSize := key.Dim(2) + batchSize := key.Dim(0) if c.curBatchSize != batchSize { panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize)) @@ -479,29 +486,29 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { } if _, ok := c.keys[c.curLayer]; !ok { - c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity)) + c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), numKVHeads, kHeadDim) } if _, ok := c.values[c.curLayer]; !ok { if c.config.PermutedV { - c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads) + c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, int(c.Capacity)) } else { - c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity)) + c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), numKVHeads, vHeadDim) } } - rowSize := c.keys[c.curLayer].Stride(2) - ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize))) + rowSize := c.keys[c.curLayer].Stride(0) + ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, []int{kHeadDim * numKVHeads * batchSize}, nil))) if c.config.PermutedV { - elemSize := c.values[c.curLayer].Stride(0) + elemSize := c.values[c.curLayer].Stride(2) value = value.Permute(ctx, 1, 2, 0, 3) - ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads))) + ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, 
elemSize*c.curLoc, []int{vHeadDim * numKVHeads, batchSize}, []int{int(c.Capacity) * elemSize}))) } else { - rowSize := c.values[c.curLayer].Stride(2) + rowSize := c.values[c.curLayer].Stride(0) - ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize))) + ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, []int{vHeadDim * numKVHeads * batchSize}, nil))) } } @@ -558,14 +565,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { continue } - kHeadDim := key.Dim(0) + kHeadDim := key.Dim(2) numKVHeads := key.Dim(1) - rowSize := key.Stride(2) + rowSize := key.Stride(0) key = key.View(ctx, rowSize*seqRange.min, - kHeadDim, key.Stride(1), - numKVHeads, key.Stride(2), - size, + []int{size, numKVHeads, kHeadDim}, + []int{key.Stride(0), key.Stride(1)}, ) roped, err := c.shiftFn(ctx, i, key, kShift) diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 56d85ceb6..87cc60558 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -31,21 +31,21 @@ func TestStore(t *testing.T) { { name: "FirstBatch", in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234}, - inShape: []int{2, 3, 4}, + inShape: []int{4, 3, 2}, seqs: []int{0, 0, 0, 0}, pos: []int32{0, 1, 2, 3}, expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234}, - expectedShape: []int{2, 3, 4}, + expectedShape: []int{4, 3, 2}, expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0}, }, { name: "SecondBatch", in: []float32{115, 215, 125, 225, 135, 235}, - inShape: []int{2, 3, 1}, + inShape: []int{1, 3, 2}, seqs: []int{0}, pos: []int32{4}, expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 
113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235}, - expectedShape: []int{2, 3, 5}, + expectedShape: []int{5, 3, 2}, expectedMask: []float32{0, 0, 0, 0, 0}, }, } @@ -64,11 +64,11 @@ func TestSWA(t *testing.T) { { name: "SlidingWindow", in: []float32{1, 2, 3, 4}, - inShape: []int{1, 1, 4}, + inShape: []int{4, 1, 1}, seqs: []int{0, 0, 0, 0}, pos: []int32{0, 1, 2, 3}, expected: []float32{1, 2, 3, 4}, - expectedShape: []int{1, 1, 4}, + expectedShape: []int{4, 1, 1}, expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, }, } @@ -87,21 +87,21 @@ func TestSequences(t *testing.T) { { name: "FirstBatch", in: []float32{1, 2, 3, 4}, - inShape: []int{1, 1, 4}, + inShape: []int{4, 1, 1}, seqs: []int{0, 0, 1, 1}, pos: []int32{0, 1, 0, 1}, expected: []float32{1, 2, 3, 4}, - expectedShape: []int{1, 1, 4}, + expectedShape: []int{4, 1, 1}, expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, }, { name: "SecondBatch", in: []float32{5, 6}, - inShape: []int{1, 1, 2}, + inShape: []int{2, 1, 1}, seqs: []int{0, 1}, pos: []int32{2, 2}, expected: []float32{1, 2, 3, 4, 5, 6}, - expectedShape: []int{1, 1, 6}, + expectedShape: []int{6, 1, 1}, expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0}, }, } @@ -122,11 +122,11 @@ func TestRemove(t *testing.T) { { name: "FirstBatch", in: []float32{1, 2, 3, 4}, - inShape: []int{1, 1, 4}, + inShape: []int{4, 1, 1}, seqs: []int{0, 0, 1, 1}, pos: []int32{0, 1, 0, 1}, expected: 
[]float32{1, 2, 3, 4}, - expectedShape: []int{1, 1, 4}, + expectedShape: []int{4, 1, 1}, expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, }, } @@ -142,11 +142,11 @@ func TestRemove(t *testing.T) { { name: "RemoveEnd", in: []float32{5, 6}, - inShape: []int{1, 1, 2}, + inShape: []int{2, 1, 1}, seqs: []int{0, 1}, pos: []int32{1, 2}, expected: []float32{1, 2, 3, 4, 5, 6}, - expectedShape: []int{1, 1, 6}, + expectedShape: []int{6, 1, 1}, expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0}, }, } @@ -162,11 +162,11 @@ func TestRemove(t *testing.T) { { name: "RemoveMiddle", in: []float32{7, 8}, - inShape: []int{1, 1, 2}, + inShape: []int{2, 1, 1}, seqs: []int{0, 0}, pos: []int32{1, 2}, expected: []float32{7, 8, 3, 4, 4}, - expectedShape: []int{1, 1, 5}, + expectedShape: []int{5, 1, 1}, expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0}, }, } @@ -187,11 +187,11 @@ func TestDefrag(t *testing.T) { { name: "FirstBatch", in: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - inShape: []int{1, 1, 16}, + inShape: []int{16, 1, 1}, seqs: []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, pos: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - expectedShape: []int{1, 1, 16}, + expectedShape: []int{16, 1, 1}, expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 
float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 
float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, }, } @@ -212,11 +212,11 @@ func TestDefrag(t *testing.T) { { name: "Defrag", in: []float32{17, 18, 19}, - inShape: []int{1, 1, 3}, + inShape: []int{3, 1, 1}, seqs: []int{0, 0, 0}, pos: []int32{16, 17, 18}, expected: []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19}, - expectedShape: []int{1, 1, 16}, + expectedShape: []int{16, 1, 1}, expectedMask: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, }, } @@ -235,11 +235,11 @@ func TestCopy(t *testing.T) { { name: "FirstBatch", in: []float32{1, 2, 3, 4}, - inShape: []int{1, 1, 4}, + inShape: []int{4, 1, 1}, seqs: []int{0, 0, 0, 0}, pos: []int32{0, 1, 2, 3}, expected: []float32{1, 2, 3, 4}, - expectedShape: []int{1, 1, 4}, + expectedShape: []int{4, 1, 1}, expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 
0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0}, }, } @@ -252,11 +252,11 @@ func TestCopy(t *testing.T) { { name: "Copy", in: []float32{5, 6}, - inShape: []int{1, 1, 2}, + inShape: []int{2, 1, 1}, seqs: []int{1, 1}, pos: []int32{3, 4}, expected: []float32{1, 2, 3, 4, 5, 6}, - expectedShape: []int{1, 1, 6}, + expectedShape: []int{6, 1, 1}, expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, }, } @@ -365,6 +365,9 @@ func (c *testContext) MaxGraphNodes() int { func (c *testContext) Close() {} +// TODO remove this before merging - temporary debugging aid +func (c *testContext) Abort(t ml.Tensor) {} + type testTensor struct { dtype ml.DType elementSize int @@ -378,7 +381,8 @@ func (t *testTensor) Dim(n int) int { func (t *testTensor) Stride(n int) int { stride := t.elementSize - for i := range n { + // Reverse to mimic ggml's row order impl + for i := len(t.shape) - 1; i > n; i-- { stride *= t.shape[i] } @@ -473,7 +477,7 @@ func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor { panic("not implemented") } -func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { +func (t *testTensor) View(ctx ml.Context, offset int, shape []int, stride []int) ml.Tensor { offset /= t.elementSize var s []int @@ -481,8 +485,8 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { switch len(shape) { case 1: s = []int{shape[0]} - case 5: - s = []int{shape[0], shape[2], shape[4]} + case 3: + s = []int{shape[0], shape[1], shape[2]} default: panic("unsupported number of dimensions") } @@ -496,7 +500,32 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { } func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor { - panic("not implemented") + switch len(t.shape) { + case 2: + sh := make([]int, len(t.shape)) + size := 1 + for i := 
range sh { + sh[i] = t.shape[shape[i]] + size *= sh[i] + } + data := make([]float32, size) + dn := 0 + for j := range t.shape[1] { + for i := range t.shape[0] { + offset := i*t.shape[1] + j + data[dn] = t.data[offset] + dn++ + } + } + return &testTensor{ + dtype: t.dtype, + elementSize: t.elementSize, + data: data, + shape: sh, + } + default: + panic("not implemented") + } } func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor { diff --git a/ml/backend.go b/ml/backend.go index c63c73d46..134317750 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -97,6 +97,10 @@ func NewBackend(f *os.File, params BackendParams) (Backend, error) { type Context interface { Empty(dtype DType, shape ...int) Tensor Zeros(dtype DType, shape ...int) Tensor + + // TODO - the (Tensor, error) return pattern makes this impossible to + // one-line in cases where we need to pass a scalar into a function that + // requires a Tensor leading to overly verbose impls. Consider a Must* API. FromFloatSlice(s []float32, shape ...int) (Tensor, error) FromIntSlice(s []int32, shape ...int) (Tensor, error) @@ -113,6 +117,9 @@ type Context interface { // Layer returns a context appropriate for creating intermediate tensors Layer(int) Context + + // TODO remove this before merging - temporary debugging aid + Abort(Tensor) // Evaluate the graph up to this point, retrieve the data from the tensor and dump it to a json file for comparison } type Tensor interface { @@ -130,7 +137,7 @@ type Tensor interface { Mulmat(ctx Context, t2 Tensor) Tensor MulmatFullPrec(ctx Context, t2 Tensor) Tensor - Softmax(ctx Context) Tensor + Softmax(ctx Context) Tensor // TODO axis parameter? 
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor RMSNorm(ctx Context, weight Tensor, eps float32) Tensor Scale(ctx Context, s float64) Tensor @@ -145,7 +152,7 @@ type Tensor interface { SILU(ctx Context) Tensor Reshape(ctx Context, shape ...int) Tensor - View(ctx Context, offset int, shape ...int) Tensor + View(ctx Context, offset int, shape, stride []int) Tensor Permute(ctx Context, shape ...int) Tensor Contiguous(ctx Context) Tensor Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor @@ -298,3 +305,16 @@ const ( DTypeQ40 DTypeI32 ) + +func (dt DType) String() string { + switch dt { + case DTypeF32: + return "float32" + case DTypeF16: + return "float16" + case DTypeI32: + return "int32" + default: + return "unknown" + } +} diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 03b9acb32..df2679a36 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -9,12 +9,15 @@ package ggml import "C" import ( + "encoding/json" "errors" "fmt" "io" "log/slog" "maps" + "math" "os" + "runtime/debug" "slices" "strconv" "strings" @@ -24,10 +27,13 @@ import ( "github.com/ollama/ollama/format" fs "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/ml" - ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src" "golang.org/x/sync/errgroup" + + ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src" ) +var rev = []C.int{3, 2, 1, 0} + func devices() []*C.struct_ggml_backend_device { ggml.OnceLoad() ds := make([]*C.struct_ggml_backend_device, C.ggml_backend_dev_count()) @@ -65,7 +71,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { } slog.Info( - "", + "initializing GGML backend", "architecture", meta.KV().Architecture(), "file_type", meta.KV().FileType(), "name", meta.KV().String("general.name"), @@ -396,7 +402,7 @@ func (b *Backend) Config() ml.Config { func (b *Backend) Get(name string) ml.Tensor { if t, ok := b.tensors[name]; ok { - return &Tensor{b: b, t: t} + return &Tensor{b: b, t: t, nDims: 
int(C.ggml_n_dims(t))} } return nil @@ -529,7 +535,7 @@ func pad(length, pad C.size_t) C.size_t { return ((length + pad - 1) / pad) * pad } -func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor { +func (c Context) newTensor(dtype ml.DType, rshape []int) ml.Tensor { if c.buft == nil { panic("set Input, Output, or Layer before creating tensors") } @@ -550,24 +556,31 @@ func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor { panic("unsupported dtype") } - if len(shape) < 1 || shape[0] == 0 { + if len(rshape) < 1 || rshape[0] == 0 { var shape C.int64_t = 0 - return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)} - } else if len(shape) > 4 { + return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape), nDims: 1} + } else if len(rshape) > 4 { panic("unsupported number of dimensions") } - for _, dim := range shape { + for _, dim := range rshape { if dim < 1 { panic("invalid shape") } } + // Inverted + shape := make([]int, len(rshape)) + i := len(rshape) - 1 + for _, dim := range rshape { + shape[i] = dim + i-- + } t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape)) size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft)) b := C.ggml_backend_buft_alloc_buffer(c.buft, size) C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b)) - return &Tensor{b: c.b, t: t} + return &Tensor{b: c.b, t: t, nDims: len(shape)} } func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor { @@ -631,8 +644,14 @@ func (c *Context) Close() { } type Tensor struct { - b *Backend - t *C.struct_ggml_tensor + b *Backend + t *C.struct_ggml_tensor + + // keep track of the number of dimensions + // Since we reverse the shape, GGML considers a trailing "1" dimension as not present + // and we can't actually trust the output of ggml_n_dims + nDims int + sync func() } @@ -641,23 +660,35 @@ func (t *Tensor) LogValue() slog.Value { slog.String("name", 
C.GoString(C.ggml_get_name(t.t))), slog.String("type", C.GoString(C.ggml_type_name(t.t._type))), slog.Any("shape", t.Shape()), + slog.Any("underlying shape", t.t.ne), + slog.Any("underlying stride", t.t.nb), ) } func (t *Tensor) Dim(n int) int { - return int(t.t.ne[n]) + if t.nDims == 0 { + // If this hits we likely forgot to copy the dimension to the returned tensor in some operation + panic("zero dimension tensor") + } + r := rev[4-t.nDims:] + return int(t.t.ne[r[n]]) } func (t *Tensor) Stride(n int) int { - return int(t.t.nb[n]) + if t.nDims == 0 { + slog.Error("Stride", "tensor", t, "dim", n) + panic("zero dimension tensor") + } + r := rev[4-t.nDims:] + s := int(t.t.nb[r[n]]) + return s } func (t *Tensor) Shape() []int { - shape := make([]int, C.ggml_n_dims(t.t)) + shape := make([]int, t.nDims) for i := range shape { shape[i] = t.Dim(i) } - return shape } @@ -702,8 +733,9 @@ func (t *Tensor) DType() ml.DType { func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ - b: t.b, - t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), + b: t.b, + t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), + nDims: t.nDims, } } @@ -717,29 +749,38 @@ func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor { func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor { return &Tensor{ - b: t.b, - t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)), + b: t.b, + t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)), + nDims: max(t.nDims, t2.(*Tensor).nDims), } } func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor { return &Tensor{ - b: t.b, - t: C.ggml_cont(ctx.(*Context).ctx, t.t), + b: t.b, + t: C.ggml_cont(ctx.(*Context).ctx, t.t), + nDims: t.nDims, } } func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ - b: t.b, - t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), + b: t.b, + t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), + nDims: t.nDims, 
// TODO should this be max(t.nDims, t2.nDims)? } } func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + if t.t.ne[0] != t2.(*Tensor).t.ne[0] { + slog.Error("incorrect tensor shapes for Mulmat", "t", t, "t2", t2) + panic("malformed tensors passed to Mulmat") + } + r := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t) return &Tensor{ - b: t.b, - t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), + b: t.b, + t: r, + nDims: max(t.nDims, t2.(*Tensor).nDims), } } @@ -748,13 +789,14 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor { C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32) return &Tensor{ - b: t.b, - t: mul, + b: t.b, + t: mul, + nDims: max(t.nDims, t2.(*Tensor).nDims), } } func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor { - tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) + tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps)), nDims: t.nDims}).Mul(ctx, w) if b != nil { tt = tt.Add(ctx, b) } @@ -763,17 +805,28 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso } func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor { - return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) + return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps)), nDims: t.nDims}).Mul(ctx, w) } func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor { if len(shape) != 4 { panic("expected 4 dimensions") } - + var r *C.struct_ggml_tensor + switch t.nDims { + case 1: + r = C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])) + case 2: + r = C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[1]), C.int(shape[0]), C.int(shape[2]), C.int(shape[3])) + case 3: + r = C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[2]), C.int(shape[1]), C.int(shape[0]), C.int(shape[3])) + default: + 
r = C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[3]), C.int(shape[2]), C.int(shape[1]), C.int(shape[0])) + } return &Tensor{ - b: t.b, - t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])), + b: t.b, + t: r, + nDims: t.nDims, } } @@ -781,48 +834,132 @@ func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor { if len(shape) != 4 { panic("expected 4 dimensions") } - - return &Tensor{ - b: t.b, - t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])), + rshape := []C.int{0, 1, 2, 3} + switch t.nDims { + case 2: + // TODO make sure this isn't wonky... + rshape[0] = rev[2:][shape[1]] + rshape[1] = rev[2:][shape[0]] + case 3: + // TODO has to be a better way... + rshape[0] = C.int(shape[0]) + rshape[1] = C.int(shape[1]) + rshape[2] = C.int(shape[2]) + switch shape[0]*100 + shape[1]*10 + shape[2] { + case 21: + rshape[0], rshape[1], rshape[2] = 1, 0, 2 + case 102: + rshape[0], rshape[1], rshape[2] = 0, 2, 1 + } + case 4: + // TODO has to be a better way... 
+ rshape[0] = C.int(shape[0]) + rshape[1] = C.int(shape[1]) + rshape[2] = C.int(shape[2]) + rshape[3] = C.int(shape[3]) + switch shape[0]*1000 + shape[1]*100 + shape[2]*10 + shape[3] { + case 132: + rshape[0], rshape[1], rshape[2], rshape[3] = 1, 0, 2, 3 + case 231: + rshape[0], rshape[1], rshape[2], rshape[3] = 1, 2, 0, 3 + case 312: + rshape[0], rshape[1], rshape[2], rshape[3] = 2, 0, 1, 3 + case 321: + rshape[0], rshape[1], rshape[2], rshape[3] = 2, 1, 0, 3 + case 1023: + rshape[0], rshape[1], rshape[2], rshape[3] = 0, 1, 3, 2 + case 1203: + rshape[0], rshape[1], rshape[2], rshape[3] = 0, 2, 3, 1 + case 1302: + rshape[0], rshape[1], rshape[2], rshape[3] = 2, 0, 3, 1 + case 1320: + rshape[0], rshape[1], rshape[2], rshape[3] = 2, 1, 3, 0 + case 2013: + rshape[0], rshape[1], rshape[2], rshape[3] = 0, 3, 1, 2 + case 2031: + rshape[0], rshape[1], rshape[2], rshape[3] = 1, 3, 0, 2 + case 2103: + rshape[0], rshape[1], rshape[2], rshape[3] = 0, 3, 2, 1 + case 2130: + rshape[0], rshape[1], rshape[2], rshape[3] = 1, 3, 2, 0 + case 3021: + rshape[0], rshape[1], rshape[2], rshape[3] = 3, 1, 0, 2 + case 3102: + rshape[0], rshape[1], rshape[2], rshape[3] = 3, 0, 2, 1 + } } + + r := &Tensor{ + b: t.b, + t: C.ggml_permute(ctx.(*Context).ctx, t.t, rshape[0], rshape[1], rshape[2], rshape[3]), + nDims: t.nDims, + } + return r } func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor { return &Tensor{ - b: t.b, - t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), + b: t.b, + t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), + nDims: t.nDims, } } func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor { - return &Tensor{ - b: t.b, - t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), + r := &Tensor{ + b: t.b, + t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t), + nDims: t.nDims, } + return r } func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor { + // GGML does not handle -1 natively + for i, sh := range shape { + if sh == -1 { 
+ totalElems := 1 + for d := range t.nDims { + totalElems *= int(t.t.ne[d]) + } + otherElems := 1 + for _, osh := range shape { + if osh != -1 { + otherElems *= osh + } + } + if otherElems > totalElems { + slog.Error("Invalid request", "req", shape, "actual", t.Shape(), "totalElems", totalElems, "otherElems", otherElems) + panic("impossible -1 shape request") + } + shape[i] = int(float64(totalElems) / float64(otherElems)) + break + } + } switch len(shape) { case 1: return &Tensor{ - b: t.b, - t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])), + b: t.b, + t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])), + nDims: len(shape), } case 2: return &Tensor{ - b: t.b, - t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])), + b: t.b, + t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[1]), C.int64_t(shape[0])), + nDims: len(shape), } case 3: return &Tensor{ - b: t.b, - t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])), + b: t.b, + t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[2]), C.int64_t(shape[1]), C.int64_t(shape[0])), + nDims: len(shape), } case 4: return &Tensor{ - b: t.b, - t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])), + b: t.b, + t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[3]), C.int64_t(shape[2]), C.int64_t(shape[1]), C.int64_t(shape[0])), + nDims: len(shape), } default: panic("unsupported number of dimensions") @@ -831,22 +968,25 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor { func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor { return &Tensor{ - b: t.b, - t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)), + b: t.b, + t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)), + nDims: t.nDims, } } func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor { return &Tensor{ - b: t.b, - 
t: C.ggml_soft_max(ctx.(*Context).ctx, t.t), + b: t.b, + t: C.ggml_soft_max(ctx.(*Context).ctx, t.t), + nDims: t.nDims, } } func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor { return &Tensor{ - b: t.b, - t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t), + b: t.b, + t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t), + nDims: t.nDims, } } @@ -856,41 +996,50 @@ func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor { } return &Tensor{ - b: t.b, - t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])), + b: t.b, + t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[3]), C.int(shape[2]), C.int(shape[1]), C.int(shape[0])), + nDims: t.nDims, // TODO is this right? } } -func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { +func (t *Tensor) View(ctx ml.Context, offset int, shape, stride []int) ml.Tensor { + if len(stride)+1 != len(shape) { + panic(fmt.Sprintf("malformed view request: shape=%v stride=%v", shape, stride)) + } + switch len(shape) { case 1: + return &Tensor{ + b: t.b, + t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)), + nDims: 1, + } + case 2: return &Tensor{ b: t.b, - t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)), + t: C.ggml_view_2d(ctx.(*Context).ctx, t.t, + C.int64_t(shape[1]), C.int64_t(shape[0]), + C.size_t(stride[0]), + C.size_t(offset)), + nDims: 2, } case 3: - return &Tensor{ - b: t.b, - t: C.ggml_view_2d(ctx.(*Context).ctx, t.t, - C.int64_t(shape[0]), C.int64_t(shape[2]), - C.size_t(shape[1]), - C.size_t(offset)), - } - case 5: return &Tensor{ b: t.b, t: C.ggml_view_3d(ctx.(*Context).ctx, t.t, - C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), - C.size_t(shape[1]), C.size_t(shape[3]), + C.int64_t(shape[2]), C.int64_t(shape[1]), C.int64_t(shape[0]), + C.size_t(stride[1]), C.size_t(stride[0]), C.size_t(offset)), + nDims: 3, } - case 7: + case 4: return &Tensor{ b: t.b, t: 
C.ggml_view_4d(ctx.(*Context).ctx, t.t, - C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]), - C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]), + C.int64_t(shape[3]), C.int64_t(shape[2]), C.int64_t(shape[1]), C.int64_t(shape[0]), + C.size_t(stride[2]), C.size_t(stride[1]), C.size_t(stride[0]), C.size_t(offset)), + nDims: 4, } default: panic("unsupported number of dimensions") @@ -906,14 +1055,13 @@ const ( func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor { if ropeFactors == nil { - ropeFactors = &Tensor{b: t.b} + ropeFactors = &Tensor{b: t.b, nDims: 0} } dequant := t.t if C.ggml_is_quantized(t.t._type) { dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32) } - return &Tensor{ b: t.b, t: C.ggml_rope_ext( @@ -928,34 +1076,39 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi 32., // YaRN beta_fast 1., // YaRN beta_slow ), + nDims: t.nDims, } } func (t *Tensor) GELU(ctx ml.Context) ml.Tensor { return &Tensor{ - b: t.b, - t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t), + b: t.b, + t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t), + nDims: t.nDims, } } func (t *Tensor) SILU(ctx ml.Context) ml.Tensor { return &Tensor{ - b: t.b, - t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t), + b: t.b, + t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t), + nDims: t.nDims, } } func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { return &Tensor{ - b: t.b, - t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)), + b: t.b, + t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)), + nDims: t.nDims, } } func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor { return &Tensor{ - b: t.b, - t: 
C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)), + b: t.b, + t: C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)), + nDims: t.nDims, } } @@ -970,7 +1123,7 @@ func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) m panic("unsupported number of dimensions") } - return &Tensor{b: t.b, t: tt} + return &Tensor{b: t.b, t: tt, nDims: t.nDims} } func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor { @@ -979,23 +1132,60 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T kqMask = mask.(*Tensor).t } - query := t.Permute(ctx, 0, 2, 1, 3) - key = key.Permute(ctx, 0, 2, 1, 3) + query := t.Permute(ctx, 1, 0, 2, 3) + key = key.Permute(ctx, 1, 0, 2, 3) if t.b.flashAttention { - value = value.Permute(ctx, 0, 2, 1, 3) + value = value.Permute(ctx, 1, 0, 2, 3) kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0) C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32) - return &Tensor{b: t.b, t: kqv} + return &Tensor{b: t.b, t: kqv, nDims: t.nDims} } else { kq := key.MulmatFullPrec(ctx, query) kq = &Tensor{ - b: t.b, - t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0), + b: t.b, + t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0), + nDims: t.nDims, } kqv := value.Mulmat(ctx, kq) - return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + return kqv.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) } } + +// TODO remove this before merging - temporary debugging aid +func (c *Context) Abort(t ml.Tensor) { + // Hack to make sure we're f32, otherwise r.Floats will fail due to short read + if t.(*Tensor).t._type != C.GGML_TYPE_F32 { + t.(*Tensor).t = C.ggml_cast(c.ctx, t.(*Tensor).t, 
C.GGML_TYPE_F32) + } + c.Forward(t) + c.Compute(t) + f32 := t.Floats() + // Convert [-]Inf to serializable values + for i, v := range f32 { + if v > math.MaxFloat32 { + f32[i] = math.MaxFloat32 + } + if v < -math.SmallestNonzeroFloat32 { + f32[i] = -math.MaxFloat32 + } + } + debug.PrintStack() + + filename := "ggml.json" + slog.Info("Writing tensors to", "filename", filename) + f, err := os.Create(filename) + if err != nil { + panic(err) + } + defer f.Close() + encoder := json.NewEncoder(f) + err = encoder.Encode(f32) + if err != nil { + panic(err) + } + + os.Exit(1) +} diff --git a/ml/nn/attention.go b/ml/nn/attention.go index a3f43a1ea..91d4368f5 100644 --- a/ml/nn/attention.go +++ b/ml/nn/attention.go @@ -12,9 +12,9 @@ import ( // // Parameters: // - ctx: Context for tensor operations -// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q] -// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only -// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only +// - query: Query tensor (Q) with shape [seq_len_q, heads, d_k] +// - key: Key tensor (K) with shape [seq_len_k, kv_heads, d_k], can be nil to read from cache only +// - value: Value tensor (V) with shape [seq_len_k, kv_heads, d_v], can be nil to read from cache only // - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension // - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value // @@ -23,16 +23,16 @@ import ( // Attention output with shape [d_v, heads, seq_len_q] func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { if key != nil && value != nil { - if query.Dim(0) != key.Dim(0) { - panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) + if query.Dim(2) != key.Dim(2) { + panic(fmt.Errorf("d_k in attention operation does not match 
between query(%v) and key(%v)", query.Dim(2), key.Dim(2))) } if key.Dim(1) != value.Dim(1) { panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1))) } - if key.Dim(2) != value.Dim(2) { - panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) + if key.Dim(0) != value.Dim(0) { + panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(0), value.Dim(0))) } if cache != nil { @@ -52,8 +52,8 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil { return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale) } else { - query = query.Permute(ctx, 0, 2, 1, 3) - key = key.Permute(ctx, 0, 2, 1, 3) + query = query.Permute(ctx, 1, 0, 2, 3) + key = key.Permute(ctx, 1, 0, 2, 3) value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) kq := key.MulmatFullPrec(ctx, query) @@ -65,6 +65,6 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kq = kq.Softmax(ctx) kqv := value.Mulmat(ctx, kq) - return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + return kqv.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) } } diff --git a/model/README.md b/model/README.md new file mode 100644 index 000000000..367227810 --- /dev/null +++ b/model/README.md @@ -0,0 +1,62 @@ +# Ollama Models + +Large Language Models in Ollama are defined in the Go programming language within this directory. + + +## Model Implementation Guide + +Ollama supports multiple backends, and provides an abstracted interface for model implementers. [Backend API](../ml/backend.go) + +This API is designed to be similar to other popular Python libraries such as +PyTorch, with row-major tensors and a forward function that takes a sequence of +inputs.
+ +Use an existing model as an initial reference, such as [llama](./models/llama/) + +Cheatsheet: + + + + + + + + + + + + + + + + + + + + + + + + + + +
PyTorch | Ollama
torch.zeros((2, 2)) | ctx.Zeros(ml.DTypeF32, 2, 2)
tensor.view((2, 2)) | t.Reshape(ctx, 2, 2)
torch.permute(t1, (1, 0, 2, 3)) | t1.Permute(ctx, 1, 0, 2, 3)
torch.add(t1, t2) | t1.Add(ctx, t2)
+ +```python +class Attention(nn.Module): + def __call__(self, ...): + ... +``` + + + +```go +func (sa *SelfAttention) Forward(ctx ml.Context, + hiddenState, positionIDs ml.Tensor, + cache kvcache.Cache, + opts *Options) ml.Tensor { + ... +} +``` + +
\ No newline at end of file diff --git a/model/model.go b/model/model.go index 53e47add9..bf911ea97 100644 --- a/model/model.go +++ b/model/model.go @@ -306,3 +306,14 @@ func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) { return t, nil } + +func ArangeF32(start, end, step float32) []float32 { + if step == 0 || start >= end { + return nil + } + var res []float32 + for i := float32(start); i < end; i += step { + res = append(res, i) + } + return res +} diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 2b8597c42..6087f8d79 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -77,11 +77,11 @@ type SelfAttention struct { } func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { - batchSize := hiddenState.Dim(1) + batchSize := hiddenState.Dim(0) ropeType := uint32(2) q := sa.Query.Forward(ctx, hiddenState) - q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) + q = q.Reshape(ctx, batchSize, opts.numHeads, opts.attnKeyLen) q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale) if opts.largeModelScaling { @@ -91,17 +91,17 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } k := sa.Key.Forward(ctx, hiddenState) - k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) + k = k.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnKeyLen) k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale) v := sa.Value.Forward(ctx, hiddenState) - v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) + v = v.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnValLen) cache.Put(ctx, k, v) k, v, mask := cache.Get(ctx) - q = q.Permute(ctx, 0, 2, 1, 3) - k = k.Permute(ctx, 0, 2, 1, 3) + q = q.Permute(ctx, 1, 0, 2, 3) + k = k.Permute(ctx, 1, 0, 2, 3) v = v.Permute(ctx, 1, 2, 0, 
3).Contiguous(ctx) kq := k.Mulmat(ctx, q) @@ -115,8 +115,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten kq = kq.Softmax(ctx) kqv := v.Mulmat(ctx, kq) - kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize) + kqv = kqv.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + kqv = kqv.Reshape(ctx, batchSize, opts.attnValLen*opts.numHeads) return sa.Output.Forward(ctx, kqv) } diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 32ad80f43..53c308eb5 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -35,15 +35,15 @@ type MultiModalProjector struct { } func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor { - l := visionOutputs.Dim(0) + l := visionOutputs.Dim(1) visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) patchesPerImage := imageSize / patchSize - visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l) + visionOutputs = visionOutputs.Reshape(ctx, l, patchesPerImage, patchesPerImage) kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage))) visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0) - visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l) + visionOutputs = visionOutputs.Reshape(ctx, l, visionOutputs.Dim(2)*visionOutputs.Dim(1)) visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps) @@ -98,9 +98,9 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er } pixelValues, err := ctx.Input().FromFloatSlice(f32s, - m.ImageProcessor.imageSize, - m.ImageProcessor.imageSize, m.ImageProcessor.numChannels, + m.ImageProcessor.imageSize, + m.ImageProcessor.imageSize, ) if err != nil { return nil, err @@ -121,13 +121,13 @@ func (m *Model) 
PostTokenize(inputs []input.Input) ([]input.Input, error) { inputMultimodal := inp.Multimodal.(ml.Tensor) result = append(result, - input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n" + input.Input{Token: 108, SameBatch: inputMultimodal.Dim(0) + 3}, // "\n\n" input.Input{Token: 255999}, // """ input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder ) // add image token placeholders - result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...) + result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(0)-1)...) result = append(result, input.Input{Token: 256000}, // diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 567f65a5e..6a57fc907 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -85,7 +85,7 @@ type TextSelfAttention struct { } func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { - batchSize := hiddenState.Dim(1) + batchSize := hiddenState.Dim(0) ropeType := uint32(2) ropeBase := opts.ropeLocalBase @@ -94,7 +94,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos } q := sa.Query.Forward(ctx, hiddenState) - q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) + q = q.Reshape(ctx, batchSize, opts.numHeads, opts.attnKeyLen) q = sa.QueryNorm.Forward(ctx, q, opts.eps) q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale) @@ -105,16 +105,16 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos } k := sa.Key.Forward(ctx, hiddenState) - k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) + k = k.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnKeyLen) k = sa.KeyNorm.Forward(ctx, k, opts.eps) k = k.RoPE(ctx, positionIDs, nil, 
uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale) v := sa.Value.Forward(ctx, hiddenState) - v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) + v = v.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnValLen) scaleFactor := 1.0 kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) - kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize) + kqv = kqv.Reshape(ctx, batchSize, opts.attnValLen*opts.numHeads) return sa.Output.Forward(ctx, kqv) } @@ -179,9 +179,9 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor var except []int for _, image := range opts.Multimodal { visionOutputs := image.Multimodal.(ml.Tensor) - ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1)))) + ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(0), []int{visionOutputs.Dim(0) * visionOutputs.Dim(1)}, nil))) - for i := range visionOutputs.Dim(1) { + for i := range visionOutputs.Dim(0) { except = append(except, image.Index+i) } } diff --git a/model/models/gemma3/model_vision.go b/model/models/gemma3/model_vision.go index 94aa27bd7..be5956b01 100644 --- a/model/models/gemma3/model_vision.go +++ b/model/models/gemma3/model_vision.go @@ -7,8 +7,6 @@ import ( "github.com/ollama/ollama/ml/nn" ) -var batchSize int = 1 - type VisionSelfAttention struct { Query *nn.Linear `gguf:"attn_q"` Key *nn.Linear `gguf:"attn_k"` @@ -23,12 +21,12 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op key := sa.Key.Forward(ctx, hiddenState) value := sa.Value.Forward(ctx, hiddenState) - query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize) - key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize) - value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize) + query = query.Reshape(ctx, query.Dim(0), opts.numHeads, headDim) + key = key.Reshape(ctx, 
key.Dim(0), opts.numHeads, headDim) + value = value.Reshape(ctx, value.Dim(0), opts.numHeads, headDim) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil) - attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) + attention = attention.Reshape(ctx, attention.Dim(0), opts.hiddenSize) hiddenState = sa.Output.Forward(ctx, attention) return hiddenState @@ -88,7 +86,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize) hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1) - hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize) + hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, numPatches) hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) positions := make([]int32, numPatches) diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 19a2ab8c4..452edadd9 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -74,24 +74,24 @@ type SelfAttention struct { } func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { - batchSize := hiddenState.Dim(1) + batchSize := hiddenState.Dim(0) // TODO Consider renaming "L" as this is the sequence length, not batch size headDim := opts.hiddenSize / opts.numHeads ropeType := uint32(0) q := sa.Query.Forward(ctx, hiddenState) - q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) + q = q.Reshape(ctx, batchSize, opts.numHeads, -1) q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) k := sa.Key.Forward(ctx, hiddenState) - k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + k = k.Reshape(ctx, batchSize, opts.numKVHeads, -1) k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) v := 
sa.Value.Forward(ctx, hiddenState) - v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + v = v.Reshape(ctx, batchSize, opts.numKVHeads, -1) scaleFactor := 1.0 / math.Sqrt(float64(headDim)) kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) - kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize) + kqv = kqv.Reshape(ctx, batchSize, -1) return sa.Output.Forward(ctx, kqv) } diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index fa4d570ca..804d8d293 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -78,10 +78,10 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er } pixelValues, err := ctx.Input().FromFloatSlice(f32s, - m.ImageProcessor.imageSize, - m.ImageProcessor.imageSize, - m.ImageProcessor.numChannels, m.ImageProcessor.maxNumTiles, + m.ImageProcessor.numChannels, + m.ImageProcessor.imageSize, + m.ImageProcessor.imageSize, ) if err != nil { return nil, err diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 1cf30d89b..cc2d0b26e 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -18,24 +18,24 @@ type TextSelfAttention struct { } func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { - batchSize := hiddenState.Dim(1) + batchSize := hiddenState.Dim(0) headDim := opts.hiddenSize / opts.numHeads ropeType := uint32(0) query := sa.Query.Forward(ctx, hiddenState) - query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) + query = query.Reshape(ctx, batchSize, opts.numHeads, headDim) query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) key := sa.Key.Forward(ctx, hiddenState) - key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + key = key.Reshape(ctx, batchSize, opts.numKVHeads, headDim) key = key.RoPE(ctx, positions, sa.RopeFactors, 
opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) value := sa.Value.Forward(ctx, hiddenState) - value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + value = value.Reshape(ctx, batchSize, opts.numKVHeads, headDim) scaleFactor := 1.0 / math.Sqrt(float64(headDim)) attention := nn.Attention(ctx, query, key, value, scaleFactor, cache) - attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) + attention = attention.Reshape(ctx, batchSize, opts.hiddenSize) return sa.Output.Forward(ctx, attention) } @@ -99,23 +99,23 @@ type TextCrossAttention struct { } func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { - batchSize := hiddenState.Dim(1) + batchSize := hiddenState.Dim(0) headDim := opts.hiddenSize / opts.numHeads query := ca.Query.Forward(ctx, hiddenState) - query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) + query = query.Reshape(ctx, batchSize, opts.numHeads, headDim) query = ca.QueryNorm.Forward(ctx, query, opts.eps) var key, value ml.Tensor if crossAttentionStates != nil { - numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2) + numVisionTokens, numTiles := crossAttentionStates.Dim(2), crossAttentionStates.Dim(1) key = ca.Key.Forward(ctx, crossAttentionStates) - key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles) + key = key.Reshape(ctx, numVisionTokens*numTiles, opts.numKVHeads, headDim) key = ca.KeyNorm.Forward(ctx, key, opts.eps) value = ca.Value.Forward(ctx, crossAttentionStates) - value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles) + value = value.Reshape(ctx, numVisionTokens*numTiles, opts.numKVHeads, headDim) cache.Put(ctx, key, value) } @@ -124,8 +124,8 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - query = query.Permute(ctx, 0, 2, 1, 3) - key = 
key.Permute(ctx, 0, 2, 1, 3) + query = query.Permute(ctx, 1, 0, 2, 3) + key = key.Permute(ctx, 1, 0, 2, 3) value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) kq := key.MulmatFullPrec(ctx, query) @@ -134,8 +134,8 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio kq = kq.Softmax(ctx) kqv := value.Mulmat(ctx, kq) - attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) + attention := kqv.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + attention = attention.Reshape(ctx, batchSize, opts.hiddenSize) return ca.Output.Forward(ctx, attention) } diff --git a/model/models/mllama/model_vision.go b/model/models/mllama/model_vision.go index ac777f051..d1fbb6e10 100644 --- a/model/models/mllama/model_vision.go +++ b/model/models/mllama/model_vision.go @@ -23,25 +23,25 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op headDim := opts.hiddenSize / opts.numHeads query := sa.Query.Forward(ctx, hiddenState) - query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize) + query = query.Reshape(ctx, batchSize, query.Dim(1), opts.numHeads, headDim) query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) key := sa.Key.Forward(ctx, hiddenState) - key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize) + key = key.Reshape(ctx, batchSize, key.Dim(1), opts.numHeads, headDim) key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) value := sa.Value.Forward(ctx, hiddenState) - value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize) - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + value = value.Reshape(ctx, batchSize, value.Dim(1), opts.numHeads, headDim) + value = value.Permute(ctx, 0, 2, 3, 1).Contiguous(ctx) scores := key.Mulmat(ctx, query) scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim))) scores = scores.Softmax(ctx) attention := value.Mulmat(ctx, scores) - attention = attention.Reshape(ctx, 
headDim, attention.Dim(1), opts.numHeads, batchSize) + attention = attention.Reshape(ctx, batchSize, opts.numHeads, attention.Dim(2), headDim) attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) + attention = attention.Reshape(ctx, batchSize, attention.Dim(1), opts.hiddenSize) hiddenState = sa.Output.Forward(ctx, attention) if sa.Gate != nil { @@ -99,7 +99,7 @@ func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermedi var intermediateHiddenStates []ml.Tensor for i, layer := range e.Layers { if slices.Contains(intermediateLayersIndices, uint32(i)) { - intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int{1}, hiddenState.Shape()...)...)) + intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append(hiddenState.Shape(), 1)...)) } hiddenState = layer.Forward(ctx, hiddenState, opts) @@ -115,7 +115,7 @@ type PrecomputedAspectRatioEmbedding struct { func (e *PrecomputedAspectRatioEmbedding) Forward(ctx ml.Context, hiddenState ml.Tensor, aspectRatioIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor { embeddings := e.Embedding.Forward(ctx, aspectRatioIDs) - embeddings = embeddings.Reshape(ctx, opts.hiddenSize, 1, opts.numTiles) + embeddings = embeddings.Reshape(ctx, opts.numTiles, 1, opts.hiddenSize) if e.Gate != nil { embeddings = embeddings.Mul(ctx, e.Gate) } @@ -140,7 +140,7 @@ func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, posi hiddenState = hiddenState.Add(ctx, positionEmbedding) tilePositionEmbedding := e.TilePositionEmbedding.Forward(ctx, aspectRatioIDs) - tilePositionEmbedding = tilePositionEmbedding.Reshape(ctx, opts.hiddenSize, numPositions, opts.numTiles) + tilePositionEmbedding = tilePositionEmbedding.Reshape(ctx, opts.numTiles, numPositions, opts.hiddenSize) if e.TilePositionEmbeddingGate != nil { tilePositionEmbedding = 
tilePositionEmbedding.Mul(ctx, e.TilePositionEmbeddingGate) } @@ -181,8 +181,8 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa } hiddenState := m.PatchEmbeddings.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1) - hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize, m.numTiles) - hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + hiddenState = hiddenState.Reshape(ctx, m.numTiles, m.hiddenSize, numPatches) + hiddenState = hiddenState.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) hiddenState = m.PreTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions) hiddenState = m.ClassEmbedding.Stack(ctx, 2, slices.Repeat([]ml.Tensor{m.ClassEmbedding}, m.numTiles-1)...).Concat(ctx, hiddenState, 1) @@ -193,23 +193,23 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa numPaddingPatches := 8 - (hiddenState.Dim(1)%8)%8 hiddenState = hiddenState.Pad(ctx, 0, numPaddingPatches, 0, 0) - hiddenState = hiddenState.Reshape(ctx, hiddenState.Dim(0), hiddenState.Dim(1)*hiddenState.Dim(2), batchSize) + hiddenState = hiddenState.Reshape(ctx, batchSize, hiddenState.Dim(1)*hiddenState.Dim(0), hiddenState.Dim(2)) hiddenState, intermediateHiddenStates := m.Transformer.Forward(ctx, hiddenState, m.intermediateLayersIndices, m.VisionModelOptions) hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps) - hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize) + hiddenState = hiddenState.Reshape(ctx, batchSize, m.numTiles, numPositions+numPaddingPatches, m.hiddenSize) hiddenState = m.PostTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions) - hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, m.numTiles*(numPositions+numPaddingPatches), batchSize) + hiddenState = hiddenState.Reshape(ctx, batchSize, m.numTiles*(numPositions+numPaddingPatches), m.hiddenSize) hiddenState, _ = 
m.GlobalTransformer.Forward(ctx, hiddenState, nil, m.VisionModelOptions) hiddenStates := intermediateHiddenStates[0].Stack(ctx, 0, intermediateHiddenStates[1:]...) - hiddenStates = hiddenStates.Reshape(ctx, len(intermediateHiddenStates)*m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize) - hiddenStates = hiddenStates.Unpad(ctx, 0, numPaddingPatches, 0, 0) + hiddenStates = hiddenStates.Reshape(ctx, batchSize, m.numTiles, numPositions+numPaddingPatches, len(intermediateHiddenStates)*m.hiddenSize) + hiddenStates = hiddenStates.Unpad(ctx, 0, 0, numPaddingPatches, 0) - hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize) - hiddenState = hiddenState.Unpad(ctx, 0, numPaddingPatches, 0, 0) + hiddenState = hiddenState.Reshape(ctx, batchSize, m.numTiles, numPositions+numPaddingPatches, m.hiddenSize) + hiddenState = hiddenState.Unpad(ctx, 0, 0, numPaddingPatches, 0) return hiddenState.Concat(ctx, hiddenStates, 0) }