mirror of
https://github.com/ollama/ollama.git
synced 2025-04-07 03:18:24 +02:00
Row Order model definitions
Functional implementation on the latest backend and caching code. Still has some debugging that needs rebasing/cleanup, and one unit test fails, which still needs work...
This commit is contained in:
parent
e27e4a3c1b
commit
7b3c3135de
2
.gitignore
vendored
2
.gitignore
vendored
@ -14,3 +14,5 @@ __debug_bin*
|
||||
llama/build
|
||||
llama/vendor
|
||||
/ollama
|
||||
build/
|
||||
|
||||
|
@ -251,7 +251,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
||||
mask[i] = float32(math.Inf(-1))
|
||||
}
|
||||
|
||||
maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
|
||||
maskTensor, err := ctx.Input().FromFloatSlice(mask, batchSize, length)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -271,12 +271,12 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
||||
continue
|
||||
}
|
||||
|
||||
kHeadDim := key.Dim(0)
|
||||
kHeadDim := key.Dim(2)
|
||||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
rowSize := key.Stride(0)
|
||||
|
||||
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
|
||||
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
|
||||
kSrcView := key.View(ctx, rowSize*src, []int{kHeadDim * numKVHeads * len}, nil)
|
||||
kDstView := key.View(ctx, rowSize*dst, []int{kHeadDim * numKVHeads * len}, nil)
|
||||
|
||||
value := c.values[i]
|
||||
var vSrcView, vDstView ml.Tensor
|
||||
@ -284,14 +284,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
||||
vHeadDim := value.Dim(1)
|
||||
elemSize := value.Stride(0)
|
||||
|
||||
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
||||
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
||||
vSrcView = value.View(ctx, elemSize*src, []int{vHeadDim * numKVHeads, len}, []int{int(c.Capacity) * elemSize})
|
||||
vDstView = value.View(ctx, elemSize*dst, []int{vHeadDim * numKVHeads, len}, []int{int(c.Capacity) * elemSize})
|
||||
} else {
|
||||
vHeadDim := value.Dim(0)
|
||||
rowSize := value.Stride(2)
|
||||
vHeadDim := value.Dim(2)
|
||||
rowSize := value.Stride(0)
|
||||
|
||||
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
|
||||
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
|
||||
vSrcView = value.View(ctx, rowSize*src, []int{vHeadDim * numKVHeads * len}, nil)
|
||||
vDstView = value.View(ctx, rowSize*dst, []int{vHeadDim * numKVHeads * len}, nil)
|
||||
}
|
||||
|
||||
ctx.Forward(
|
||||
@ -430,45 +430,52 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
key := c.keys[c.curLayer]
|
||||
value := c.values[c.curLayer]
|
||||
|
||||
kHeadDim := key.Dim(0)
|
||||
kHeadDim := key.Dim(2)
|
||||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
cachedSize := c.curMask.Dim(0)
|
||||
rowSize := key.Stride(0)
|
||||
cachedSize := c.curMask.Dim(1)
|
||||
// slog.Info("Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "rowSize", rowSize, "cachedSize", cachedSize)
|
||||
|
||||
key = key.View(ctx, rowSize*c.curCellRange.min,
|
||||
kHeadDim, key.Stride(1),
|
||||
numKVHeads, key.Stride(2),
|
||||
cachedSize,
|
||||
[]int{cachedSize, numKVHeads, kHeadDim},
|
||||
[]int{key.Stride(0), key.Stride(1)},
|
||||
)
|
||||
// slog.Info("Get", "key", key)
|
||||
// panic("XXX")
|
||||
|
||||
if c.config.PermutedV {
|
||||
vHeadDim := value.Dim(1)
|
||||
elemSize := value.Stride(0)
|
||||
elemSize := value.Stride(2)
|
||||
|
||||
value = value.View(ctx, elemSize*c.curCellRange.min,
|
||||
cachedSize, value.Stride(1),
|
||||
vHeadDim, value.Stride(2),
|
||||
numKVHeads,
|
||||
[]int{numKVHeads, vHeadDim, cachedSize},
|
||||
[]int{value.Stride(0), value.Stride(1)},
|
||||
)
|
||||
} else {
|
||||
vHeadDim := value.Dim(0)
|
||||
rowSize := value.Stride(2)
|
||||
vHeadDim := value.Dim(2)
|
||||
rowSize := value.Stride(0)
|
||||
|
||||
value = value.View(ctx, rowSize*c.curCellRange.min,
|
||||
vHeadDim, value.Stride(1),
|
||||
numKVHeads, value.Stride(2),
|
||||
cachedSize,
|
||||
[]int{cachedSize, numKVHeads, vHeadDim},
|
||||
[]int{value.Stride(0), value.Stride(1)},
|
||||
)
|
||||
}
|
||||
|
||||
// TODO The mask changes from X,X to 1,X, and with the Row-order change
|
||||
// the 1 becomes trailing and messes up later operations
|
||||
// This isn't the right solution, but works around it...
|
||||
if c.curMask.Dim(1) == 1 {
|
||||
return key, value, c.curMask.Permute(ctx, 1, 0, 2, 3)
|
||||
}
|
||||
|
||||
return key, value, c.curMask
|
||||
}
|
||||
|
||||
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
kHeadDim := key.Dim(0)
|
||||
vHeadDim := value.Dim(0)
|
||||
kHeadDim := key.Dim(2)
|
||||
vHeadDim := value.Dim(2)
|
||||
numKVHeads := key.Dim(1)
|
||||
batchSize := key.Dim(2)
|
||||
batchSize := key.Dim(0)
|
||||
|
||||
if c.curBatchSize != batchSize {
|
||||
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
|
||||
@ -479,29 +486,29 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
}
|
||||
|
||||
if _, ok := c.keys[c.curLayer]; !ok {
|
||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
|
||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), numKVHeads, kHeadDim)
|
||||
}
|
||||
|
||||
if _, ok := c.values[c.curLayer]; !ok {
|
||||
if c.config.PermutedV {
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, int(c.Capacity))
|
||||
} else {
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
|
||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), numKVHeads, vHeadDim)
|
||||
}
|
||||
}
|
||||
|
||||
rowSize := c.keys[c.curLayer].Stride(2)
|
||||
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
|
||||
rowSize := c.keys[c.curLayer].Stride(0)
|
||||
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, []int{kHeadDim * numKVHeads * batchSize}, nil)))
|
||||
|
||||
if c.config.PermutedV {
|
||||
elemSize := c.values[c.curLayer].Stride(0)
|
||||
elemSize := c.values[c.curLayer].Stride(2)
|
||||
|
||||
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, []int{vHeadDim * numKVHeads, batchSize}, []int{int(c.Capacity) * elemSize})))
|
||||
} else {
|
||||
rowSize := c.values[c.curLayer].Stride(2)
|
||||
rowSize := c.values[c.curLayer].Stride(0)
|
||||
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, []int{vHeadDim * numKVHeads * batchSize}, nil)))
|
||||
}
|
||||
}
|
||||
|
||||
@ -558,14 +565,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
continue
|
||||
}
|
||||
|
||||
kHeadDim := key.Dim(0)
|
||||
kHeadDim := key.Dim(2)
|
||||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
rowSize := key.Stride(0)
|
||||
|
||||
key = key.View(ctx, rowSize*seqRange.min,
|
||||
kHeadDim, key.Stride(1),
|
||||
numKVHeads, key.Stride(2),
|
||||
size,
|
||||
[]int{size, numKVHeads, kHeadDim},
|
||||
[]int{key.Stride(0), key.Stride(1)},
|
||||
)
|
||||
|
||||
roped, err := c.shiftFn(ctx, i, key, kShift)
|
||||
|
@ -31,21 +31,21 @@ func TestStore(t *testing.T) {
|
||||
{
|
||||
name: "FirstBatch",
|
||||
in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
||||
inShape: []int{2, 3, 4},
|
||||
inShape: []int{4, 3, 2},
|
||||
seqs: []int{0, 0, 0, 0},
|
||||
pos: []int32{0, 1, 2, 3},
|
||||
expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
|
||||
expectedShape: []int{2, 3, 4},
|
||||
expectedShape: []int{4, 3, 2},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
||||
},
|
||||
{
|
||||
name: "SecondBatch",
|
||||
in: []float32{115, 215, 125, 225, 135, 235},
|
||||
inShape: []int{2, 3, 1},
|
||||
inShape: []int{1, 3, 2},
|
||||
seqs: []int{0},
|
||||
pos: []int32{4},
|
||||
expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
|
||||
expectedShape: []int{2, 3, 5},
|
||||
expectedShape: []int{5, 3, 2},
|
||||
expectedMask: []float32{0, 0, 0, 0, 0},
|
||||
},
|
||||
}
|
||||
@ -64,11 +64,11 @@ func TestSWA(t *testing.T) {
|
||||
{
|
||||
name: "SlidingWindow",
|
||||
in: []float32{1, 2, 3, 4},
|
||||
inShape: []int{1, 1, 4},
|
||||
inShape: []int{4, 1, 1},
|
||||
seqs: []int{0, 0, 0, 0},
|
||||
pos: []int32{0, 1, 2, 3},
|
||||
expected: []float32{1, 2, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedShape: []int{4, 1, 1},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
},
|
||||
}
|
||||
@ -87,21 +87,21 @@ func TestSequences(t *testing.T) {
|
||||
{
|
||||
name: "FirstBatch",
|
||||
in: []float32{1, 2, 3, 4},
|
||||
inShape: []int{1, 1, 4},
|
||||
inShape: []int{4, 1, 1},
|
||||
seqs: []int{0, 0, 1, 1},
|
||||
pos: []int32{0, 1, 0, 1},
|
||||
expected: []float32{1, 2, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedShape: []int{4, 1, 1},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
},
|
||||
{
|
||||
name: "SecondBatch",
|
||||
in: []float32{5, 6},
|
||||
inShape: []int{1, 1, 2},
|
||||
inShape: []int{2, 1, 1},
|
||||
seqs: []int{0, 1},
|
||||
pos: []int32{2, 2},
|
||||
expected: []float32{1, 2, 3, 4, 5, 6},
|
||||
expectedShape: []int{1, 1, 6},
|
||||
expectedShape: []int{6, 1, 1},
|
||||
expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
|
||||
},
|
||||
}
|
||||
@ -122,11 +122,11 @@ func TestRemove(t *testing.T) {
|
||||
{
|
||||
name: "FirstBatch",
|
||||
in: []float32{1, 2, 3, 4},
|
||||
inShape: []int{1, 1, 4},
|
||||
inShape: []int{4, 1, 1},
|
||||
seqs: []int{0, 0, 1, 1},
|
||||
pos: []int32{0, 1, 0, 1},
|
||||
expected: []float32{1, 2, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedShape: []int{4, 1, 1},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
},
|
||||
}
|
||||
@ -142,11 +142,11 @@ func TestRemove(t *testing.T) {
|
||||
{
|
||||
name: "RemoveEnd",
|
||||
in: []float32{5, 6},
|
||||
inShape: []int{1, 1, 2},
|
||||
inShape: []int{2, 1, 1},
|
||||
seqs: []int{0, 1},
|
||||
pos: []int32{1, 2},
|
||||
expected: []float32{1, 2, 3, 4, 5, 6},
|
||||
expectedShape: []int{1, 1, 6},
|
||||
expectedShape: []int{6, 1, 1},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
|
||||
},
|
||||
}
|
||||
@ -162,11 +162,11 @@ func TestRemove(t *testing.T) {
|
||||
{
|
||||
name: "RemoveMiddle",
|
||||
in: []float32{7, 8},
|
||||
inShape: []int{1, 1, 2},
|
||||
inShape: []int{2, 1, 1},
|
||||
seqs: []int{0, 0},
|
||||
pos: []int32{1, 2},
|
||||
expected: []float32{7, 8, 3, 4, 4},
|
||||
expectedShape: []int{1, 1, 5},
|
||||
expectedShape: []int{5, 1, 1},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0},
|
||||
},
|
||||
}
|
||||
@ -187,11 +187,11 @@ func TestDefrag(t *testing.T) {
|
||||
{
|
||||
name: "FirstBatch",
|
||||
in: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||
inShape: []int{1, 1, 16},
|
||||
inShape: []int{16, 1, 1},
|
||||
seqs: []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
pos: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
|
||||
expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
|
||||
expectedShape: []int{1, 1, 16},
|
||||
expectedShape: []int{16, 1, 1},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 
float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
},
|
||||
}
|
||||
@ -212,11 +212,11 @@ func TestDefrag(t *testing.T) {
|
||||
{
|
||||
name: "Defrag",
|
||||
in: []float32{17, 18, 19},
|
||||
inShape: []int{1, 1, 3},
|
||||
inShape: []int{3, 1, 1},
|
||||
seqs: []int{0, 0, 0},
|
||||
pos: []int32{16, 17, 18},
|
||||
expected: []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19},
|
||||
expectedShape: []int{1, 1, 16},
|
||||
expectedShape: []int{16, 1, 1},
|
||||
expectedMask: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
},
|
||||
}
|
||||
@ -235,11 +235,11 @@ func TestCopy(t *testing.T) {
|
||||
{
|
||||
name: "FirstBatch",
|
||||
in: []float32{1, 2, 3, 4},
|
||||
inShape: []int{1, 1, 4},
|
||||
inShape: []int{4, 1, 1},
|
||||
seqs: []int{0, 0, 0, 0},
|
||||
pos: []int32{0, 1, 2, 3},
|
||||
expected: []float32{1, 2, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedShape: []int{4, 1, 1},
|
||||
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
|
||||
},
|
||||
}
|
||||
@ -252,11 +252,11 @@ func TestCopy(t *testing.T) {
|
||||
{
|
||||
name: "Copy",
|
||||
in: []float32{5, 6},
|
||||
inShape: []int{1, 1, 2},
|
||||
inShape: []int{2, 1, 1},
|
||||
seqs: []int{1, 1},
|
||||
pos: []int32{3, 4},
|
||||
expected: []float32{1, 2, 3, 4, 5, 6},
|
||||
expectedShape: []int{1, 1, 6},
|
||||
expectedShape: []int{6, 1, 1},
|
||||
expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
|
||||
},
|
||||
}
|
||||
@ -365,6 +365,9 @@ func (c *testContext) MaxGraphNodes() int {
|
||||
|
||||
// Close is a no-op for the test context.
func (c *testContext) Close() {}
|
||||
|
||||
// Abort is a no-op in the test context; the tensor argument is ignored.
// TODO remove this before merging - temporary debugging aid
|
||||
func (c *testContext) Abort(t ml.Tensor) {}
|
||||
|
||||
type testTensor struct {
|
||||
dtype ml.DType
|
||||
elementSize int
|
||||
@ -378,7 +381,8 @@ func (t *testTensor) Dim(n int) int {
|
||||
|
||||
func (t *testTensor) Stride(n int) int {
|
||||
stride := t.elementSize
|
||||
for i := range n {
|
||||
// Reverse to mimic ggml's row order impl
|
||||
for i := len(t.shape) - 1; i > n; i-- {
|
||||
stride *= t.shape[i]
|
||||
}
|
||||
|
||||
@ -473,7 +477,7 @@ func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
func (t *testTensor) View(ctx ml.Context, offset int, shape []int, stride []int) ml.Tensor {
|
||||
offset /= t.elementSize
|
||||
|
||||
var s []int
|
||||
@ -481,8 +485,8 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
switch len(shape) {
|
||||
case 1:
|
||||
s = []int{shape[0]}
|
||||
case 5:
|
||||
s = []int{shape[0], shape[2], shape[4]}
|
||||
case 3:
|
||||
s = []int{shape[0], shape[1], shape[2]}
|
||||
default:
|
||||
panic("unsupported number of dimensions")
|
||||
}
|
||||
@ -496,7 +500,32 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
}
|
||||
|
||||
func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
switch len(t.shape) {
|
||||
case 2:
|
||||
sh := make([]int, len(t.shape))
|
||||
size := 1
|
||||
for i := range sh {
|
||||
sh[i] = t.shape[shape[i]]
|
||||
size *= sh[i]
|
||||
}
|
||||
data := make([]float32, size)
|
||||
dn := 0
|
||||
for j := range t.shape[1] {
|
||||
for i := range t.shape[0] {
|
||||
offset := i*t.shape[1] + j
|
||||
data[dn] = t.data[offset]
|
||||
dn++
|
||||
}
|
||||
}
|
||||
return &testTensor{
|
||||
dtype: t.dtype,
|
||||
elementSize: t.elementSize,
|
||||
data: data,
|
||||
shape: sh,
|
||||
}
|
||||
default:
|
||||
panic("not implemented")
|
||||
}
|
||||
}
|
||||
|
||||
func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
|
||||
|
@ -97,6 +97,10 @@ func NewBackend(f *os.File, params BackendParams) (Backend, error) {
|
||||
type Context interface {
|
||||
Empty(dtype DType, shape ...int) Tensor
|
||||
Zeros(dtype DType, shape ...int) Tensor
|
||||
|
||||
// TODO - the (Tensor, error) return pattern makes this impossible to
|
||||
// one-line in cases where we need to pass a scalar into a function that
|
||||
// requires a Tensor leading to overly verbose impls. Consider a Must* API.
|
||||
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
|
||||
FromIntSlice(s []int32, shape ...int) (Tensor, error)
|
||||
|
||||
@ -113,6 +117,9 @@ type Context interface {
|
||||
|
||||
// Layer returns a context appropriate for creating intermediate tensors
|
||||
Layer(int) Context
|
||||
|
||||
// TODO remove this before merging - temporary debugging aid
|
||||
Abort(Tensor) // Evaluate the graph up to this point, retrieve the data from the tensor and dump it to a json file for comparison
|
||||
}
|
||||
|
||||
type Tensor interface {
|
||||
@ -130,7 +137,7 @@ type Tensor interface {
|
||||
Mulmat(ctx Context, t2 Tensor) Tensor
|
||||
MulmatFullPrec(ctx Context, t2 Tensor) Tensor
|
||||
|
||||
Softmax(ctx Context) Tensor
|
||||
Softmax(ctx Context) Tensor // TODO axis parameter?
|
||||
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
|
||||
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
|
||||
Scale(ctx Context, s float64) Tensor
|
||||
@ -145,7 +152,7 @@ type Tensor interface {
|
||||
SILU(ctx Context) Tensor
|
||||
|
||||
Reshape(ctx Context, shape ...int) Tensor
|
||||
View(ctx Context, offset int, shape ...int) Tensor
|
||||
View(ctx Context, offset int, shape, stride []int) Tensor
|
||||
Permute(ctx Context, shape ...int) Tensor
|
||||
Contiguous(ctx Context) Tensor
|
||||
Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
|
||||
@ -298,3 +305,16 @@ const (
|
||||
DTypeQ40
|
||||
DTypeI32
|
||||
)
|
||||
|
||||
// String returns a human-readable name for the data type: "float32" for
// DTypeF32, "float16" for DTypeF16 and "int32" for DTypeI32. Every other
// value has no dedicated case and is reported as "unknown".
func (dt DType) String() string {
|
||||
switch dt {
|
||||
case DTypeF32:
|
||||
return "float32"
|
||||
case DTypeF16:
|
||||
return "float16"
|
||||
case DTypeI32:
|
||||
return "int32"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
@ -9,12 +9,15 @@ package ggml
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"math"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -24,10 +27,13 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
fs "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/ml"
|
||||
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
|
||||
)
|
||||
|
||||
// rev is the reversed index table {3, 2, 1, 0}. Tensor.Dim and Tensor.Stride
// slice it as rev[4-nDims:] and use the result to translate a caller-facing
// dimension index into an index into the underlying ggml ne/nb arrays.
var rev = []C.int{3, 2, 1, 0}
|
||||
|
||||
func devices() []*C.struct_ggml_backend_device {
|
||||
ggml.OnceLoad()
|
||||
ds := make([]*C.struct_ggml_backend_device, C.ggml_backend_dev_count())
|
||||
@ -65,7 +71,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
}
|
||||
|
||||
slog.Info(
|
||||
"",
|
||||
"initializing GGML backend",
|
||||
"architecture", meta.KV().Architecture(),
|
||||
"file_type", meta.KV().FileType(),
|
||||
"name", meta.KV().String("general.name"),
|
||||
@ -396,7 +402,7 @@ func (b *Backend) Config() ml.Config {
|
||||
|
||||
func (b *Backend) Get(name string) ml.Tensor {
|
||||
if t, ok := b.tensors[name]; ok {
|
||||
return &Tensor{b: b, t: t}
|
||||
return &Tensor{b: b, t: t, nDims: int(C.ggml_n_dims(t))}
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -529,7 +535,7 @@ func pad(length, pad C.size_t) C.size_t {
|
||||
return ((length + pad - 1) / pad) * pad
|
||||
}
|
||||
|
||||
func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
|
||||
func (c Context) newTensor(dtype ml.DType, rshape []int) ml.Tensor {
|
||||
if c.buft == nil {
|
||||
panic("set Input, Output, or Layer before creating tensors")
|
||||
}
|
||||
@ -550,24 +556,31 @@ func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
|
||||
panic("unsupported dtype")
|
||||
}
|
||||
|
||||
if len(shape) < 1 || shape[0] == 0 {
|
||||
if len(rshape) < 1 || rshape[0] == 0 {
|
||||
var shape C.int64_t = 0
|
||||
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
|
||||
} else if len(shape) > 4 {
|
||||
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape), nDims: 1}
|
||||
} else if len(rshape) > 4 {
|
||||
panic("unsupported number of dimensions")
|
||||
}
|
||||
|
||||
for _, dim := range shape {
|
||||
for _, dim := range rshape {
|
||||
if dim < 1 {
|
||||
panic("invalid shape")
|
||||
}
|
||||
}
|
||||
// Inverted
|
||||
shape := make([]int, len(rshape))
|
||||
i := len(rshape) - 1
|
||||
for _, dim := range rshape {
|
||||
shape[i] = dim
|
||||
i--
|
||||
}
|
||||
|
||||
t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
|
||||
size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
|
||||
b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
|
||||
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
|
||||
return &Tensor{b: c.b, t: t}
|
||||
return &Tensor{b: c.b, t: t, nDims: len(shape)}
|
||||
}
|
||||
|
||||
func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
@ -631,8 +644,14 @@ func (c *Context) Close() {
|
||||
}
|
||||
|
||||
type Tensor struct {
|
||||
b *Backend
|
||||
t *C.struct_ggml_tensor
|
||||
b *Backend
|
||||
t *C.struct_ggml_tensor
|
||||
|
||||
// keep track of the number of dimensions
|
||||
// Since we reverse the shape, GGML considers a trailing "1" dimension as not present
|
||||
// and we can't actually trust the output of ggml_n_dims
|
||||
nDims int
|
||||
|
||||
sync func()
|
||||
}
|
||||
|
||||
@ -641,23 +660,35 @@ func (t *Tensor) LogValue() slog.Value {
|
||||
slog.String("name", C.GoString(C.ggml_get_name(t.t))),
|
||||
slog.String("type", C.GoString(C.ggml_type_name(t.t._type))),
|
||||
slog.Any("shape", t.Shape()),
|
||||
slog.Any("underlying shape", t.t.ne),
|
||||
slog.Any("underlying stride", t.t.nb),
|
||||
)
|
||||
}
|
||||
|
||||
func (t *Tensor) Dim(n int) int {
|
||||
return int(t.t.ne[n])
|
||||
if t.nDims == 0 {
|
||||
// If this hits we likely forgot to copy the dimension to the returned tensor in some operation
|
||||
panic("zero dimension tensor")
|
||||
}
|
||||
r := rev[4-t.nDims:]
|
||||
return int(t.t.ne[r[n]])
|
||||
}
|
||||
|
||||
func (t *Tensor) Stride(n int) int {
|
||||
return int(t.t.nb[n])
|
||||
if t.nDims == 0 {
|
||||
slog.Error("Stride", "tensor", t, "dim", n)
|
||||
panic("zero dimension tensor")
|
||||
}
|
||||
r := rev[4-t.nDims:]
|
||||
s := int(t.t.nb[r[n]])
|
||||
return s
|
||||
}
|
||||
|
||||
func (t *Tensor) Shape() []int {
|
||||
shape := make([]int, C.ggml_n_dims(t.t))
|
||||
shape := make([]int, t.nDims)
|
||||
for i := range shape {
|
||||
shape[i] = t.Dim(i)
|
||||
}
|
||||
|
||||
return shape
|
||||
}
|
||||
|
||||
@ -702,8 +733,9 @@ func (t *Tensor) DType() ml.DType {
|
||||
|
||||
func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
b: t.b,
|
||||
t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
nDims: t.nDims,
|
||||
}
|
||||
}
|
||||
|
||||
@ -717,29 +749,38 @@ func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
|
||||
|
||||
func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
|
||||
b: t.b,
|
||||
t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
|
||||
nDims: max(t.nDims, t2.(*Tensor).nDims),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_cont(ctx.(*Context).ctx, t.t),
|
||||
b: t.b,
|
||||
t: C.ggml_cont(ctx.(*Context).ctx, t.t),
|
||||
nDims: t.nDims,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
b: t.b,
|
||||
t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
nDims: t.nDims, // TODO should this be max(t.nDims, t2.nDims)?
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
if t.t.ne[0] != t2.(*Tensor).t.ne[0] {
|
||||
slog.Error("incorrect tensor shapes for Mulmat", "t", t, "t2", t2)
|
||||
panic("malformed tensors passed to Mulmat")
|
||||
}
|
||||
r := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
b: t.b,
|
||||
t: r,
|
||||
nDims: max(t.nDims, t2.(*Tensor).nDims),
|
||||
}
|
||||
}
|
||||
|
||||
@ -748,13 +789,14 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: mul,
|
||||
b: t.b,
|
||||
t: mul,
|
||||
nDims: max(t.nDims, t2.(*Tensor).nDims),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
|
||||
tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
||||
tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps)), nDims: t.nDims}).Mul(ctx, w)
|
||||
if b != nil {
|
||||
tt = tt.Add(ctx, b)
|
||||
}
|
||||
@ -763,17 +805,28 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
|
||||
}
|
||||
|
||||
func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
|
||||
return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
||||
return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps)), nDims: t.nDims}).Mul(ctx, w)
|
||||
}
|
||||
|
||||
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
if len(shape) != 4 {
|
||||
panic("expected 4 dimensions")
|
||||
}
|
||||
|
||||
var r *C.struct_ggml_tensor
|
||||
switch t.nDims {
|
||||
case 1:
|
||||
r = C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3]))
|
||||
case 2:
|
||||
r = C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[1]), C.int(shape[0]), C.int(shape[2]), C.int(shape[3]))
|
||||
case 3:
|
||||
r = C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[2]), C.int(shape[1]), C.int(shape[0]), C.int(shape[3]))
|
||||
default:
|
||||
r = C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[3]), C.int(shape[2]), C.int(shape[1]), C.int(shape[0]))
|
||||
}
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
|
||||
b: t.b,
|
||||
t: r,
|
||||
nDims: t.nDims,
|
||||
}
|
||||
}
|
||||
|
||||
@ -781,48 +834,132 @@ func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
if len(shape) != 4 {
|
||||
panic("expected 4 dimensions")
|
||||
}
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
|
||||
rshape := []C.int{0, 1, 2, 3}
|
||||
switch t.nDims {
|
||||
case 2:
|
||||
// TODO make sure this isn't wonky...
|
||||
rshape[0] = rev[2:][shape[1]]
|
||||
rshape[1] = rev[2:][shape[0]]
|
||||
case 3:
|
||||
// TODO has to be a better way...
|
||||
rshape[0] = C.int(shape[0])
|
||||
rshape[1] = C.int(shape[1])
|
||||
rshape[2] = C.int(shape[2])
|
||||
switch shape[0]*100 + shape[1]*10 + shape[2] {
|
||||
case 21:
|
||||
rshape[0], rshape[1], rshape[2] = 1, 0, 2
|
||||
case 102:
|
||||
rshape[0], rshape[1], rshape[2] = 0, 2, 1
|
||||
}
|
||||
case 4:
|
||||
// TODO has to be a better way...
|
||||
rshape[0] = C.int(shape[0])
|
||||
rshape[1] = C.int(shape[1])
|
||||
rshape[2] = C.int(shape[2])
|
||||
rshape[3] = C.int(shape[3])
|
||||
switch shape[0]*1000 + shape[1]*100 + shape[2]*10 + shape[3] {
|
||||
case 132:
|
||||
rshape[0], rshape[1], rshape[2], rshape[3] = 1, 0, 2, 3
|
||||
case 231:
|
||||
rshape[0], rshape[1], rshape[2], rshape[3] = 1, 2, 0, 3
|
||||
case 312:
|
||||
rshape[0], rshape[1], rshape[2], rshape[3] = 2, 0, 1, 3
|
||||
case 321:
|
||||
rshape[0], rshape[1], rshape[2], rshape[3] = 2, 1, 0, 3
|
||||
case 1023:
|
||||
rshape[0], rshape[1], rshape[2], rshape[3] = 0, 1, 3, 2
|
||||
case 1203:
|
||||
rshape[0], rshape[1], rshape[2], rshape[3] = 0, 2, 3, 1
|
||||
case 1302:
|
||||
rshape[0], rshape[1], rshape[2], rshape[3] = 2, 0, 3, 1
|
||||
case 1320:
|
||||
rshape[0], rshape[1], rshape[2], rshape[3] = 2, 1, 3, 0
|
||||
case 2013:
|
||||
rshape[0], rshape[1], rshape[2], rshape[3] = 0, 3, 1, 2
|
||||
case 2031:
|
||||
rshape[0], rshape[1], rshape[2], rshape[3] = 1, 3, 0, 2
|
||||
case 2103:
|
||||
rshape[0], rshape[1], rshape[2], rshape[3] = 0, 3, 2, 1
|
||||
case 2130:
|
||||
rshape[0], rshape[1], rshape[2], rshape[3] = 1, 3, 2, 0
|
||||
case 3021:
|
||||
rshape[0], rshape[1], rshape[2], rshape[3] = 3, 1, 0, 2
|
||||
case 3102:
|
||||
rshape[0], rshape[1], rshape[2], rshape[3] = 3, 0, 2, 1
|
||||
}
|
||||
}
|
||||
|
||||
r := &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_permute(ctx.(*Context).ctx, t.t, rshape[0], rshape[1], rshape[2], rshape[3]),
|
||||
nDims: t.nDims,
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
b: t.b,
|
||||
t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
nDims: t.nDims,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
r := &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
nDims: t.nDims,
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
// GGML does not handle -1 natively
|
||||
for i, sh := range shape {
|
||||
if sh == -1 {
|
||||
totalElems := 1
|
||||
for d := range t.nDims {
|
||||
totalElems *= int(t.t.ne[d])
|
||||
}
|
||||
otherElems := 1
|
||||
for _, osh := range shape {
|
||||
if osh != -1 {
|
||||
otherElems *= osh
|
||||
}
|
||||
}
|
||||
if otherElems > totalElems {
|
||||
slog.Error("Invalid request", "req", shape, "actual", t.Shape(), "totalElems", totalElems, "otherElems", otherElems)
|
||||
panic("impossible -1 shape request")
|
||||
}
|
||||
shape[i] = int(float64(totalElems) / float64(otherElems))
|
||||
break
|
||||
}
|
||||
}
|
||||
switch len(shape) {
|
||||
case 1:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
|
||||
b: t.b,
|
||||
t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
|
||||
nDims: len(shape),
|
||||
}
|
||||
case 2:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
|
||||
b: t.b,
|
||||
t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[1]), C.int64_t(shape[0])),
|
||||
nDims: len(shape),
|
||||
}
|
||||
case 3:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
|
||||
b: t.b,
|
||||
t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[2]), C.int64_t(shape[1]), C.int64_t(shape[0])),
|
||||
nDims: len(shape),
|
||||
}
|
||||
case 4:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
|
||||
b: t.b,
|
||||
t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[3]), C.int64_t(shape[2]), C.int64_t(shape[1]), C.int64_t(shape[0])),
|
||||
nDims: len(shape),
|
||||
}
|
||||
default:
|
||||
panic("unsupported number of dimensions")
|
||||
@ -831,22 +968,25 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
|
||||
func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
|
||||
b: t.b,
|
||||
t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
|
||||
nDims: t.nDims,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
|
||||
b: t.b,
|
||||
t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
|
||||
nDims: t.nDims,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
|
||||
b: t.b,
|
||||
t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
|
||||
nDims: t.nDims,
|
||||
}
|
||||
}
|
||||
|
||||
@ -856,41 +996,50 @@ func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
}
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
|
||||
b: t.b,
|
||||
t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[3]), C.int(shape[2]), C.int(shape[1]), C.int(shape[0])),
|
||||
nDims: t.nDims, // TODO is this right?
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
func (t *Tensor) View(ctx ml.Context, offset int, shape, stride []int) ml.Tensor {
|
||||
if len(stride)+1 != len(shape) {
|
||||
panic(fmt.Sprintf("malformed view request: shape=%v stride=%v", shape, stride))
|
||||
}
|
||||
|
||||
switch len(shape) {
|
||||
case 1:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
|
||||
nDims: 1,
|
||||
}
|
||||
case 2:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
|
||||
t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
|
||||
C.int64_t(shape[1]), C.int64_t(shape[0]),
|
||||
C.size_t(stride[0]),
|
||||
C.size_t(offset)),
|
||||
nDims: 2,
|
||||
}
|
||||
case 3:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
|
||||
C.int64_t(shape[0]), C.int64_t(shape[2]),
|
||||
C.size_t(shape[1]),
|
||||
C.size_t(offset)),
|
||||
}
|
||||
case 5:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
|
||||
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
|
||||
C.size_t(shape[1]), C.size_t(shape[3]),
|
||||
C.int64_t(shape[2]), C.int64_t(shape[1]), C.int64_t(shape[0]),
|
||||
C.size_t(stride[1]), C.size_t(stride[0]),
|
||||
C.size_t(offset)),
|
||||
nDims: 3,
|
||||
}
|
||||
case 7:
|
||||
case 4:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
|
||||
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
|
||||
C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
|
||||
C.int64_t(shape[3]), C.int64_t(shape[2]), C.int64_t(shape[1]), C.int64_t(shape[0]),
|
||||
C.size_t(stride[2]), C.size_t(stride[1]), C.size_t(stride[0]),
|
||||
C.size_t(offset)),
|
||||
nDims: 4,
|
||||
}
|
||||
default:
|
||||
panic("unsupported number of dimensions")
|
||||
@ -906,14 +1055,13 @@ const (
|
||||
|
||||
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
|
||||
if ropeFactors == nil {
|
||||
ropeFactors = &Tensor{b: t.b}
|
||||
ropeFactors = &Tensor{b: t.b, nDims: 0}
|
||||
}
|
||||
|
||||
dequant := t.t
|
||||
if C.ggml_is_quantized(t.t._type) {
|
||||
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
|
||||
}
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_rope_ext(
|
||||
@ -928,34 +1076,39 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
|
||||
32., // YaRN beta_fast
|
||||
1., // YaRN beta_slow
|
||||
),
|
||||
nDims: t.nDims,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
|
||||
b: t.b,
|
||||
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
|
||||
nDims: t.nDims,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
|
||||
b: t.b,
|
||||
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
|
||||
nDims: t.nDims,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
|
||||
b: t.b,
|
||||
t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
|
||||
nDims: t.nDims,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)),
|
||||
b: t.b,
|
||||
t: C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)),
|
||||
nDims: t.nDims,
|
||||
}
|
||||
}
|
||||
|
||||
@ -970,7 +1123,7 @@ func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) m
|
||||
panic("unsupported number of dimensions")
|
||||
}
|
||||
|
||||
return &Tensor{b: t.b, t: tt}
|
||||
return &Tensor{b: t.b, t: tt, nDims: t.nDims}
|
||||
}
|
||||
|
||||
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
|
||||
@ -979,23 +1132,60 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
|
||||
kqMask = mask.(*Tensor).t
|
||||
}
|
||||
|
||||
query := t.Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
query := t.Permute(ctx, 1, 0, 2, 3)
|
||||
key = key.Permute(ctx, 1, 0, 2, 3)
|
||||
|
||||
if t.b.flashAttention {
|
||||
value = value.Permute(ctx, 0, 2, 1, 3)
|
||||
value = value.Permute(ctx, 1, 0, 2, 3)
|
||||
|
||||
kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
|
||||
C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
|
||||
return &Tensor{b: t.b, t: kqv}
|
||||
return &Tensor{b: t.b, t: kqv, nDims: t.nDims}
|
||||
} else {
|
||||
kq := key.MulmatFullPrec(ctx, query)
|
||||
kq = &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
|
||||
b: t.b,
|
||||
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
|
||||
nDims: t.nDims,
|
||||
}
|
||||
|
||||
kqv := value.Mulmat(ctx, kq)
|
||||
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
return kqv.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO remove this before merging - temporary debugging aid
|
||||
func (c *Context) Abort(t ml.Tensor) {
|
||||
// Hack to make sure we're f32, otherwise r.Floats will fail due to short read
|
||||
if t.(*Tensor).t._type != C.GGML_TYPE_F32 {
|
||||
t.(*Tensor).t = C.ggml_cast(c.ctx, t.(*Tensor).t, C.GGML_TYPE_F32)
|
||||
}
|
||||
c.Forward(t)
|
||||
c.Compute(t)
|
||||
f32 := t.Floats()
|
||||
// Convert [-]Inf to serializable values
|
||||
for i, v := range f32 {
|
||||
if v > math.MaxFloat32 {
|
||||
f32[i] = math.MaxFloat32
|
||||
}
|
||||
if v < -math.SmallestNonzeroFloat32 {
|
||||
f32[i] = -math.MaxFloat32
|
||||
}
|
||||
}
|
||||
debug.PrintStack()
|
||||
|
||||
filename := "ggml.json"
|
||||
slog.Info("Writing tensors to", "filename", filename)
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer f.Close()
|
||||
encoder := json.NewEncoder(f)
|
||||
err = encoder.Encode(f32)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
os.Exit(1)
|
||||
}
|
||||
|
@ -12,9 +12,9 @@ import (
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for tensor operations
|
||||
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
|
||||
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
|
||||
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
|
||||
// - query: Query tensor (Q) with shape [seq_len_q, heads, d_k]
|
||||
// - key: Key tensor (K) with shape [seq_len_k, kv_heads, d_k], can be nil to read from cache only
|
||||
// - value: Value tensor (V) with shape [seq_len_k, kv_heads, d_v], can be nil to read from cache only
|
||||
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
|
||||
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
|
||||
//
|
||||
@ -23,16 +23,16 @@ import (
|
||||
// Attention output with shape [d_v, heads, seq_len_q]
|
||||
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
if key != nil && value != nil {
|
||||
if query.Dim(0) != key.Dim(0) {
|
||||
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
||||
if query.Dim(2) != key.Dim(2) {
|
||||
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(2), key.Dim(2)))
|
||||
}
|
||||
|
||||
if key.Dim(1) != value.Dim(1) {
|
||||
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
|
||||
}
|
||||
|
||||
if key.Dim(2) != value.Dim(2) {
|
||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
|
||||
if key.Dim(0) != value.Dim(0) {
|
||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(0), value.Dim(0)))
|
||||
}
|
||||
|
||||
if cache != nil {
|
||||
@ -52,8 +52,8 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
|
||||
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
|
||||
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
|
||||
} else {
|
||||
query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
query = query.Permute(ctx, 1, 0, 2, 3)
|
||||
key = key.Permute(ctx, 1, 0, 2, 3)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
kq := key.MulmatFullPrec(ctx, query)
|
||||
@ -65,6 +65,6 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
|
||||
kq = kq.Softmax(ctx)
|
||||
|
||||
kqv := value.Mulmat(ctx, kq)
|
||||
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
return kqv.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
}
|
||||
}
|
||||
|
62
model/README.md
Normal file
62
model/README.md
Normal file
@ -0,0 +1,62 @@
|
||||
# Ollama Models
|
||||
|
||||
Large Language Models in Ollama are defined in the Go programming language within this directory.
|
||||
|
||||
|
||||
## Model Implementation Guide
|
||||
|
||||
Ollama supports multiple backends, and provides an abstracted interface for model implementers. [Backend API](../ml/backend.go)
|
||||
|
||||
This API is designed to be similar to other popular Python libraries such as
|
||||
PyTorch, with row-major tensors and a forward function that takes a sequence of
|
||||
inputs.
|
||||
|
||||
Use an existing model as an initial reference, such as [llama](./models/llama/)
|
||||
|
||||
Cheatsheet:
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td><b>PyTorch</b></td>
|
||||
<td><b>Ollama</b></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>torch.zeros((2, 2))</td>
|
||||
<td>ctx.Zeros(ml.DTypeF32, 2, 2)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>tensor.view((2, 2))</td>
|
||||
<td>t.Reshape(ctx, 2, 2)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>torch.permute(t1, (1, 0, 2, 3))</td>
|
||||
<td>t1.Permute(ctx, 1, 0, 2, 3)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>torch.add(t1, t2)</td>
|
||||
<td>t1.Add(ctx, t2)</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
|
||||
```python
|
||||
class Attention(nn.Module):
|
||||
def __call__(self, ...):
|
||||
...
|
||||
```
|
||||
|
||||
</td>
|
||||
<td>
|
||||
|
||||
```go
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context,
|
||||
hiddenState, positionIDs ml.Tensor,
|
||||
cache kvcache.Cache,
|
||||
opts *Options) ml.Tensor {
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
@ -306,3 +306,14 @@ func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// ArangeF32 returns the values start, start+step, start+2*step, ... that are
// strictly less than end, accumulated in float32 arithmetic.
//
// It returns nil when step is not positive or when start >= end. A negative
// step is rejected because the sequence would move away from end and never
// terminate.
func ArangeF32(start, end, step float32) []float32 {
	if step <= 0 || start >= end {
		return nil
	}

	var res []float32
	for i := start; i < end; {
		res = append(res, i)

		next := i + step
		if next == i {
			// step is too small to advance i at this magnitude in float32;
			// stop rather than appending the same value forever.
			break
		}
		i = next
	}
	return res
}
|
||||
|
@ -77,11 +77,11 @@ type SelfAttention struct {
|
||||
}
|
||||
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
batchSize := hiddenState.Dim(0)
|
||||
ropeType := uint32(2)
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
||||
q = q.Reshape(ctx, batchSize, opts.numHeads, opts.attnKeyLen)
|
||||
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
if opts.largeModelScaling {
|
||||
@ -91,17 +91,17 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
}
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
||||
k = k.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnKeyLen)
|
||||
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
||||
v = v.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnValLen)
|
||||
|
||||
cache.Put(ctx, k, v)
|
||||
k, v, mask := cache.Get(ctx)
|
||||
|
||||
q = q.Permute(ctx, 0, 2, 1, 3)
|
||||
k = k.Permute(ctx, 0, 2, 1, 3)
|
||||
q = q.Permute(ctx, 1, 0, 2, 3)
|
||||
k = k.Permute(ctx, 1, 0, 2, 3)
|
||||
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
kq := k.Mulmat(ctx, q)
|
||||
@ -115,8 +115,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
kq = kq.Softmax(ctx)
|
||||
|
||||
kqv := v.Mulmat(ctx, kq)
|
||||
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
|
||||
kqv = kqv.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
kqv = kqv.Reshape(ctx, batchSize, opts.attnValLen*opts.numHeads)
|
||||
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
}
|
||||
|
@ -35,15 +35,15 @@ type MultiModalProjector struct {
|
||||
}
|
||||
|
||||
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
|
||||
l := visionOutputs.Dim(0)
|
||||
l := visionOutputs.Dim(1)
|
||||
|
||||
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
patchesPerImage := imageSize / patchSize
|
||||
visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
|
||||
visionOutputs = visionOutputs.Reshape(ctx, l, patchesPerImage, patchesPerImage)
|
||||
|
||||
kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
|
||||
visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
|
||||
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
|
||||
visionOutputs = visionOutputs.Reshape(ctx, l, visionOutputs.Dim(2)*visionOutputs.Dim(1))
|
||||
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
|
||||
|
||||
@ -98,9 +98,9 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
||||
}
|
||||
|
||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.numChannels,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.imageSize,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -121,13 +121,13 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
inputMultimodal := inp.Multimodal.(ml.Tensor)
|
||||
|
||||
result = append(result,
|
||||
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
|
||||
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(0) + 3}, // "\n\n"
|
||||
input.Input{Token: 255999}, // "<start_of_image>""
|
||||
input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
|
||||
)
|
||||
|
||||
// add image token placeholders
|
||||
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
|
||||
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(0)-1)...)
|
||||
|
||||
result = append(result,
|
||||
input.Input{Token: 256000}, // <end_of_image>
|
||||
|
@ -85,7 +85,7 @@ type TextSelfAttention struct {
|
||||
}
|
||||
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
batchSize := hiddenState.Dim(0)
|
||||
ropeType := uint32(2)
|
||||
|
||||
ropeBase := opts.ropeLocalBase
|
||||
@ -94,7 +94,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
||||
}
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
||||
q = q.Reshape(ctx, batchSize, opts.numHeads, opts.attnKeyLen)
|
||||
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
|
||||
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
|
||||
|
||||
@ -105,16 +105,16 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
||||
}
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
||||
k = k.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnKeyLen)
|
||||
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
|
||||
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
||||
v = v.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnValLen)
|
||||
|
||||
scaleFactor := 1.0
|
||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
|
||||
kqv = kqv.Reshape(ctx, batchSize, opts.attnValLen*opts.numHeads)
|
||||
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
}
|
||||
@ -179,9 +179,9 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
|
||||
var except []int
|
||||
for _, image := range opts.Multimodal {
|
||||
visionOutputs := image.Multimodal.(ml.Tensor)
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(0), []int{visionOutputs.Dim(0) * visionOutputs.Dim(1)}, nil)))
|
||||
|
||||
for i := range visionOutputs.Dim(1) {
|
||||
for i := range visionOutputs.Dim(0) {
|
||||
except = append(except, image.Index+i)
|
||||
}
|
||||
}
|
||||
|
@ -7,8 +7,6 @@ import (
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
var batchSize int = 1
|
||||
|
||||
type VisionSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
@ -23,12 +21,12 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
|
||||
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
|
||||
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
|
||||
query = query.Reshape(ctx, query.Dim(0), opts.numHeads, headDim)
|
||||
key = key.Reshape(ctx, key.Dim(0), opts.numHeads, headDim)
|
||||
value = value.Reshape(ctx, value.Dim(0), opts.numHeads, headDim)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
attention = attention.Reshape(ctx, attention.Dim(0), opts.hiddenSize)
|
||||
|
||||
hiddenState = sa.Output.Forward(ctx, attention)
|
||||
return hiddenState
|
||||
@ -88,7 +86,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
||||
numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
|
||||
|
||||
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
|
||||
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
|
||||
hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, numPatches)
|
||||
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
positions := make([]int32, numPatches)
|
||||
|
@ -74,24 +74,24 @@ type SelfAttention struct {
|
||||
}
|
||||
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
batchSize := hiddenState.Dim(0) // TODO Consider renaming "L" as this is the sequence length, not batch size
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
ropeType := uint32(0)
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = q.Reshape(ctx, batchSize, opts.numHeads, -1)
|
||||
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = k.Reshape(ctx, batchSize, opts.numKVHeads, -1)
|
||||
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
v = v.Reshape(ctx, batchSize, opts.numKVHeads, -1)
|
||||
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
kqv = kqv.Reshape(ctx, batchSize, -1)
|
||||
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
}
|
||||
|
@ -78,10 +78,10 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
||||
}
|
||||
|
||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.numChannels,
|
||||
m.ImageProcessor.maxNumTiles,
|
||||
m.ImageProcessor.numChannels,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.imageSize,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -18,24 +18,24 @@ type TextSelfAttention struct {
|
||||
}
|
||||
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
batchSize := hiddenState.Dim(0)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
ropeType := uint32(0)
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
query = query.Reshape(ctx, batchSize, opts.numHeads, headDim)
|
||||
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
key = key.Reshape(ctx, batchSize, opts.numKVHeads, headDim)
|
||||
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
value = value.Reshape(ctx, batchSize, opts.numKVHeads, headDim)
|
||||
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
attention = attention.Reshape(ctx, batchSize, opts.hiddenSize)
|
||||
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
@ -99,23 +99,23 @@ type TextCrossAttention struct {
|
||||
}
|
||||
|
||||
func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
batchSize := hiddenState.Dim(0)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
|
||||
query := ca.Query.Forward(ctx, hiddenState)
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
query = query.Reshape(ctx, batchSize, opts.numHeads, headDim)
|
||||
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
|
||||
var key, value ml.Tensor
|
||||
if crossAttentionStates != nil {
|
||||
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
|
||||
numVisionTokens, numTiles := crossAttentionStates.Dim(2), crossAttentionStates.Dim(1)
|
||||
|
||||
key = ca.Key.Forward(ctx, crossAttentionStates)
|
||||
key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
|
||||
key = key.Reshape(ctx, numVisionTokens*numTiles, opts.numKVHeads, headDim)
|
||||
key = ca.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
|
||||
value = ca.Value.Forward(ctx, crossAttentionStates)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
|
||||
value = value.Reshape(ctx, numVisionTokens*numTiles, opts.numKVHeads, headDim)
|
||||
|
||||
cache.Put(ctx, key, value)
|
||||
}
|
||||
@ -124,8 +124,8 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
|
||||
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
|
||||
query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
query = query.Permute(ctx, 1, 0, 2, 3)
|
||||
key = key.Permute(ctx, 1, 0, 2, 3)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
kq := key.MulmatFullPrec(ctx, query)
|
||||
@ -134,8 +134,8 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
|
||||
kq = kq.Softmax(ctx)
|
||||
|
||||
kqv := value.Mulmat(ctx, kq)
|
||||
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
attention := kqv.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
attention = attention.Reshape(ctx, batchSize, opts.hiddenSize)
|
||||
|
||||
return ca.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
@ -23,25 +23,25 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
|
||||
query = query.Reshape(ctx, batchSize, query.Dim(1), opts.numHeads, headDim)
|
||||
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
|
||||
key = key.Reshape(ctx, batchSize, key.Dim(1), opts.numHeads, headDim)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
value = value.Reshape(ctx, batchSize, value.Dim(1), opts.numHeads, headDim)
|
||||
value = value.Permute(ctx, 0, 2, 3, 1).Contiguous(ctx)
|
||||
|
||||
scores := key.Mulmat(ctx, query)
|
||||
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
|
||||
scores = scores.Softmax(ctx)
|
||||
|
||||
attention := value.Mulmat(ctx, scores)
|
||||
attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
|
||||
attention = attention.Reshape(ctx, batchSize, opts.numHeads, attention.Dim(2), headDim)
|
||||
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
attention = attention.Reshape(ctx, batchSize, attention.Dim(1), opts.hiddenSize)
|
||||
|
||||
hiddenState = sa.Output.Forward(ctx, attention)
|
||||
if sa.Gate != nil {
|
||||
@ -99,7 +99,7 @@ func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermedi
|
||||
var intermediateHiddenStates []ml.Tensor
|
||||
for i, layer := range e.Layers {
|
||||
if slices.Contains(intermediateLayersIndices, uint32(i)) {
|
||||
intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int{1}, hiddenState.Shape()...)...))
|
||||
intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append(hiddenState.Shape(), 1)...))
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, opts)
|
||||
@ -115,7 +115,7 @@ type PrecomputedAspectRatioEmbedding struct {
|
||||
|
||||
func (e *PrecomputedAspectRatioEmbedding) Forward(ctx ml.Context, hiddenState ml.Tensor, aspectRatioIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
embeddings := e.Embedding.Forward(ctx, aspectRatioIDs)
|
||||
embeddings = embeddings.Reshape(ctx, opts.hiddenSize, 1, opts.numTiles)
|
||||
embeddings = embeddings.Reshape(ctx, opts.numTiles, 1, opts.hiddenSize)
|
||||
if e.Gate != nil {
|
||||
embeddings = embeddings.Mul(ctx, e.Gate)
|
||||
}
|
||||
@ -140,7 +140,7 @@ func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, posi
|
||||
hiddenState = hiddenState.Add(ctx, positionEmbedding)
|
||||
|
||||
tilePositionEmbedding := e.TilePositionEmbedding.Forward(ctx, aspectRatioIDs)
|
||||
tilePositionEmbedding = tilePositionEmbedding.Reshape(ctx, opts.hiddenSize, numPositions, opts.numTiles)
|
||||
tilePositionEmbedding = tilePositionEmbedding.Reshape(ctx, opts.numTiles, numPositions, opts.hiddenSize)
|
||||
if e.TilePositionEmbeddingGate != nil {
|
||||
tilePositionEmbedding = tilePositionEmbedding.Mul(ctx, e.TilePositionEmbeddingGate)
|
||||
}
|
||||
@ -181,8 +181,8 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa
|
||||
}
|
||||
|
||||
hiddenState := m.PatchEmbeddings.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
|
||||
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize, m.numTiles)
|
||||
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
hiddenState = hiddenState.Reshape(ctx, m.numTiles, m.hiddenSize, numPatches)
|
||||
hiddenState = hiddenState.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
|
||||
hiddenState = m.PreTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions)
|
||||
hiddenState = m.ClassEmbedding.Stack(ctx, 2, slices.Repeat([]ml.Tensor{m.ClassEmbedding}, m.numTiles-1)...).Concat(ctx, hiddenState, 1)
|
||||
@ -193,23 +193,23 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa
|
||||
numPaddingPatches := 8 - (hiddenState.Dim(1)%8)%8
|
||||
hiddenState = hiddenState.Pad(ctx, 0, numPaddingPatches, 0, 0)
|
||||
|
||||
hiddenState = hiddenState.Reshape(ctx, hiddenState.Dim(0), hiddenState.Dim(1)*hiddenState.Dim(2), batchSize)
|
||||
hiddenState = hiddenState.Reshape(ctx, batchSize, hiddenState.Dim(1)*hiddenState.Dim(0), hiddenState.Dim(2))
|
||||
hiddenState, intermediateHiddenStates := m.Transformer.Forward(ctx, hiddenState, m.intermediateLayersIndices, m.VisionModelOptions)
|
||||
|
||||
hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
|
||||
|
||||
hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
|
||||
hiddenState = hiddenState.Reshape(ctx, batchSize, m.numTiles, numPositions+numPaddingPatches, m.hiddenSize)
|
||||
hiddenState = m.PostTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions)
|
||||
|
||||
hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, m.numTiles*(numPositions+numPaddingPatches), batchSize)
|
||||
hiddenState = hiddenState.Reshape(ctx, batchSize, m.numTiles*(numPositions+numPaddingPatches), m.hiddenSize)
|
||||
hiddenState, _ = m.GlobalTransformer.Forward(ctx, hiddenState, nil, m.VisionModelOptions)
|
||||
|
||||
hiddenStates := intermediateHiddenStates[0].Stack(ctx, 0, intermediateHiddenStates[1:]...)
|
||||
hiddenStates = hiddenStates.Reshape(ctx, len(intermediateHiddenStates)*m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
|
||||
hiddenStates = hiddenStates.Unpad(ctx, 0, numPaddingPatches, 0, 0)
|
||||
hiddenStates = hiddenStates.Reshape(ctx, batchSize, m.numTiles, numPositions+numPaddingPatches, len(intermediateHiddenStates)*m.hiddenSize)
|
||||
hiddenStates = hiddenStates.Unpad(ctx, 0, 0, numPaddingPatches, 0)
|
||||
|
||||
hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
|
||||
hiddenState = hiddenState.Unpad(ctx, 0, numPaddingPatches, 0, 0)
|
||||
hiddenState = hiddenState.Reshape(ctx, batchSize, m.numTiles, numPositions+numPaddingPatches, m.hiddenSize)
|
||||
hiddenState = hiddenState.Unpad(ctx, 0, 0, numPaddingPatches, 0)
|
||||
return hiddenState.Concat(ctx, hiddenStates, 0)
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user