mirror of
https://github.com/ollama/ollama.git
synced 2025-11-10 23:17:59 +01:00
tests: add tests and docs for commonly used ops (#12844)
* mulmat * permute
This commit is contained in:
@@ -1231,6 +1231,11 @@ func (t *Tensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mulmat performs matrix multiplication between two tensors.
|
||||||
|
// If t has shape [m, p, ...] and t2 has shape [m, n, ...],
|
||||||
|
// Mulmat returns a new Tensor with shape [p, n, ...].
|
||||||
|
//
|
||||||
|
// Note: this is similar to matmul(t2, t.tranpose(-1, -2)) in other libraries.
|
||||||
func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
b: t.b,
|
b: t.b,
|
||||||
@@ -1303,14 +1308,21 @@ func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
|
// Permute permutes t according to order. Permute panics if the number of dimensions
|
||||||
if len(shape) != 4 {
|
// in order does not match the number of dimensions in t.
|
||||||
panic("expected 4 dimensions")
|
func (t *Tensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
|
||||||
|
if len(order) != len(t.Shape()) && len(order) != 4 {
|
||||||
|
panic("invalid number of dimensions for permute")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ggml_permute requires 4 dimensions so fill in the rest
|
||||||
|
for i := len(order); i < 4; i++ {
|
||||||
|
order = append(order, i)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Tensor{
|
return &Tensor{
|
||||||
b: t.b,
|
b: t.b,
|
||||||
t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
|
t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(order[0]), C.int(order[1]), C.int(order[2]), C.int(order[3])),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ func setup(tb testing.TB) ml.Context {
|
|||||||
tb.Fatal(err)
|
tb.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
b, err := ml.NewBackend(f.Name(), ml.BackendParams{})
|
b, err := ml.NewBackend(f.Name(), ml.BackendParams{AllocMemory: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tb.Fatal(err)
|
tb.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -124,3 +124,254 @@ func TestInferShape(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func EquateTensors(ctx ml.Context) cmp.Option {
|
||||||
|
return cmp.Comparer(func(x, y ml.Tensor) bool {
|
||||||
|
ctx.Forward(x, y).Compute(x, y)
|
||||||
|
return cmp.Equal(x.Shape(), y.Shape()) &&
|
||||||
|
cmp.Equal(x.DType(), y.DType()) &&
|
||||||
|
cmp.Equal(x.Bytes(), y.Bytes())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMulmat(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
a, b, c func(ml.Context) ml.Tensor
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "vector x vector",
|
||||||
|
a: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 3, 1, ml.DTypeF32)
|
||||||
|
},
|
||||||
|
b: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 3, 1, ml.DTypeF32)
|
||||||
|
},
|
||||||
|
c: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{5}, 1)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "vector x matrix",
|
||||||
|
a: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 4, 1, ml.DTypeF32)
|
||||||
|
},
|
||||||
|
b: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 4, 3)
|
||||||
|
},
|
||||||
|
c: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
14, 38, 62,
|
||||||
|
}, 1, 3)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "broadcast vector x batched matrix",
|
||||||
|
a: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 4, 1, ml.DTypeF32)
|
||||||
|
},
|
||||||
|
b: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 10*3*4, 1, ml.DTypeF32).Reshape(ctx, 4, 3, 10)
|
||||||
|
},
|
||||||
|
c: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
14, 38, 62,
|
||||||
|
86, 110, 134,
|
||||||
|
158, 182, 206,
|
||||||
|
230, 254, 278,
|
||||||
|
302, 326, 350,
|
||||||
|
374, 398, 422,
|
||||||
|
446, 470, 494,
|
||||||
|
518, 542, 566,
|
||||||
|
590, 614, 638,
|
||||||
|
662, 686, 710,
|
||||||
|
}, 1, 3, 10)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "batched matrix x batched matrix",
|
||||||
|
a: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 4*5*10, 1, ml.DTypeF32).Reshape(ctx, 4, 5, 10)
|
||||||
|
},
|
||||||
|
b: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 4*3*10, 1, ml.DTypeF32).Reshape(ctx, 4, 3, 10)
|
||||||
|
},
|
||||||
|
c: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
14, 38, 62, 86, 110,
|
||||||
|
38, 126, 214, 302, 390,
|
||||||
|
62, 214, 366, 518, 670,
|
||||||
|
|
||||||
|
1166, 1382, 1598, 1814, 2030,
|
||||||
|
1510, 1790, 2070, 2350, 2630,
|
||||||
|
1854, 2198, 2542, 2886, 3230,
|
||||||
|
|
||||||
|
4238, 4646, 5054, 5462, 5870,
|
||||||
|
4902, 5374, 5846, 6318, 6790,
|
||||||
|
5566, 6102, 6638, 7174, 7710,
|
||||||
|
|
||||||
|
9230, 9830, 10430, 11030, 11630,
|
||||||
|
10214, 10878, 11542, 12206, 12870,
|
||||||
|
11198, 11926, 12654, 13382, 14110,
|
||||||
|
|
||||||
|
16142, 16934, 17726, 18518, 19310,
|
||||||
|
17446, 18302, 19158, 20014, 20870,
|
||||||
|
18750, 19670, 20590, 21510, 22430,
|
||||||
|
|
||||||
|
24974, 25958, 26942, 27926, 28910,
|
||||||
|
26598, 27646, 28694, 29742, 30790,
|
||||||
|
28222, 29334, 30446, 31558, 32670,
|
||||||
|
|
||||||
|
35726, 36902, 38078, 39254, 40430,
|
||||||
|
37670, 38910, 40150, 41390, 42630,
|
||||||
|
39614, 40918, 42222, 43526, 44830,
|
||||||
|
|
||||||
|
48398, 49766, 51134, 52502, 53870,
|
||||||
|
50662, 52094, 53526, 54958, 56390,
|
||||||
|
52926, 54422, 55918, 57414, 58910,
|
||||||
|
|
||||||
|
62990, 64550, 66110, 67670, 69230,
|
||||||
|
65574, 67198, 68822, 70446, 72070,
|
||||||
|
68158, 69846, 71534, 73222, 74910,
|
||||||
|
|
||||||
|
79502, 81254, 83006, 84758, 86510,
|
||||||
|
82406, 84222, 86038, 87854, 89670,
|
||||||
|
85310, 87190, 89070, 90950, 92830,
|
||||||
|
}, 5, 3, 10)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "broadcast matrix x batched matrix",
|
||||||
|
a: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 4*5, 1, ml.DTypeF32).Reshape(ctx, 4, 5)
|
||||||
|
},
|
||||||
|
b: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 4*3*10, 1, ml.DTypeF32).Reshape(ctx, 4, 3, 10)
|
||||||
|
},
|
||||||
|
c: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
14, 38, 62, 86, 110,
|
||||||
|
38, 126, 214, 302, 390,
|
||||||
|
62, 214, 366, 518, 670,
|
||||||
|
|
||||||
|
86, 302, 518, 734, 950,
|
||||||
|
110, 390, 670, 950, 1230,
|
||||||
|
134, 478, 822, 1166, 1510,
|
||||||
|
|
||||||
|
158, 566, 974, 1382, 1790,
|
||||||
|
182, 654, 1126, 1598, 2070,
|
||||||
|
206, 742, 1278, 1814, 2350,
|
||||||
|
|
||||||
|
230, 830, 1430, 2030, 2630,
|
||||||
|
254, 918, 1582, 2246, 2910,
|
||||||
|
278, 1006, 1734, 2462, 3190,
|
||||||
|
|
||||||
|
302, 1094, 1886, 2678, 3470,
|
||||||
|
326, 1182, 2038, 2894, 3750,
|
||||||
|
350, 1270, 2190, 3110, 4030,
|
||||||
|
|
||||||
|
374, 1358, 2342, 3326, 4310,
|
||||||
|
398, 1446, 2494, 3542, 4590,
|
||||||
|
422, 1534, 2646, 3758, 4870,
|
||||||
|
|
||||||
|
446, 1622, 2798, 3974, 5150,
|
||||||
|
470, 1710, 2950, 4190, 5430,
|
||||||
|
494, 1798, 3102, 4406, 5710,
|
||||||
|
|
||||||
|
518, 1886, 3254, 4622, 5990,
|
||||||
|
542, 1974, 3406, 4838, 6270,
|
||||||
|
566, 2062, 3558, 5054, 6550,
|
||||||
|
|
||||||
|
590, 2150, 3710, 5270, 6830,
|
||||||
|
614, 2238, 3862, 5486, 7110,
|
||||||
|
638, 2326, 4014, 5702, 7390,
|
||||||
|
|
||||||
|
662, 2414, 4166, 5918, 7670,
|
||||||
|
686, 2502, 4318, 6134, 7950,
|
||||||
|
710, 2590, 4470, 6350, 8230,
|
||||||
|
}, 5, 3, 10)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := setup(t)
|
||||||
|
a, b := tt.a(ctx), tt.b(ctx)
|
||||||
|
c := a.Mulmat(ctx, b)
|
||||||
|
if diff := cmp.Diff(tt.c(ctx), c, EquateTensors(ctx)); diff != "" {
|
||||||
|
t.Errorf("MulMat() result mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPermute(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
input func(ml.Context) ml.Tensor
|
||||||
|
shape []int
|
||||||
|
want func(ml.Context) ml.Tensor
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "transpose",
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 3*2, 1, ml.DTypeF32).Reshape(ctx, 3, 2)
|
||||||
|
},
|
||||||
|
shape: []int{1, 0, 2, 3},
|
||||||
|
want: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
0, 3,
|
||||||
|
1, 4,
|
||||||
|
2, 5,
|
||||||
|
}, 2, 3)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "transpose fill dims",
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 3*2, 1, ml.DTypeF32).Reshape(ctx, 3, 2)
|
||||||
|
},
|
||||||
|
shape: []int{1, 0},
|
||||||
|
want: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
0, 3,
|
||||||
|
1, 4,
|
||||||
|
2, 5,
|
||||||
|
}, 2, 3)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "permute 3d",
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 5*3*2, 1, ml.DTypeF32).Reshape(ctx, 2, 3, 5)
|
||||||
|
},
|
||||||
|
shape: []int{2, 0, 1, 3},
|
||||||
|
want: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
0, 2, 4,
|
||||||
|
6, 8, 10,
|
||||||
|
12, 14, 16,
|
||||||
|
18, 20, 22,
|
||||||
|
24, 26, 28,
|
||||||
|
|
||||||
|
1, 3, 5,
|
||||||
|
7, 9, 11,
|
||||||
|
13, 15, 17,
|
||||||
|
19, 21, 23,
|
||||||
|
25, 27, 29,
|
||||||
|
}, 3, 5, 2)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := setup(t)
|
||||||
|
got := tt.input(ctx).Permute(ctx, tt.shape...).Contiguous(ctx)
|
||||||
|
if diff := cmp.Diff(tt.want(ctx), got, EquateTensors(ctx)); diff != "" {
|
||||||
|
t.Errorf("Permute() result mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user