mirror of
https://github.com/ollama/ollama.git
synced 2025-12-10 09:32:19 +01:00
@@ -198,6 +198,10 @@ type Tensor interface {
|
|||||||
Copy(ctx Context, t2 Tensor) Tensor
|
Copy(ctx Context, t2 Tensor) Tensor
|
||||||
Duplicate(ctx Context) Tensor
|
Duplicate(ctx Context) Tensor
|
||||||
|
|
||||||
|
Slice(ctx Context, dim, low, high, step int) Tensor
|
||||||
|
Chunk(ctx Context, dim int, size int) []Tensor
|
||||||
|
ChunkSections(ctx Context, dim int, sections ...int) []Tensor
|
||||||
|
|
||||||
TopK(ctx Context, k int) Tensor
|
TopK(ctx Context, k int) Tensor
|
||||||
Argsort(ctx Context) Tensor
|
Argsort(ctx Context) Tensor
|
||||||
Mean(ctx Context) Tensor
|
Mean(ctx Context) Tensor
|
||||||
|
|||||||
@@ -1738,3 +1738,66 @@ func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
|
|||||||
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
|
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Slice returns a view of the tensor sliced along dim from low to high in step steps.
|
||||||
|
// Slice panics if the dimension is invalid or the slice parameters are out of range.
|
||||||
|
// If dim=0 and step>1, the tensor is a copy rather than a view to ensure proper shape.
|
||||||
|
func (t *Tensor) Slice(ctx ml.Context, dim int, low, high, step int) ml.Tensor {
|
||||||
|
if dim < 0 || dim >= C.GGML_MAX_DIMS {
|
||||||
|
panic("invalid dimension")
|
||||||
|
} else if low < 0 || high > t.Dim(dim) || low >= high || step < 1 {
|
||||||
|
panic("invalid slice parameters")
|
||||||
|
}
|
||||||
|
|
||||||
|
if dim == 0 && step > 1 {
|
||||||
|
// dim=0,step>1 is a special case so handle it here first
|
||||||
|
return t.View(ctx,
|
||||||
|
low*t.Stride(0), 1,
|
||||||
|
step*t.Stride(0), (high-low+1)/step,
|
||||||
|
t.Stride(1), t.Dim(1),
|
||||||
|
// preserve dim 3 by merging it into dim 2
|
||||||
|
t.Stride(2), t.Dim(2)*t.Dim(3),
|
||||||
|
).Contiguous(ctx, (high-low+1)/step, t.Dim(1), t.Dim(2), t.Dim(3))
|
||||||
|
}
|
||||||
|
|
||||||
|
args := []int{
|
||||||
|
low * t.Stride(dim), t.Dim(0),
|
||||||
|
t.Stride(1), t.Dim(1),
|
||||||
|
t.Stride(2), t.Dim(2),
|
||||||
|
t.Stride(3), t.Dim(3),
|
||||||
|
}
|
||||||
|
|
||||||
|
if step == 1 {
|
||||||
|
args[dim*2+1] = high - low
|
||||||
|
return t.View(ctx, args[0], args[1:]...)
|
||||||
|
} else {
|
||||||
|
args[dim*2] = step * t.Stride(dim)
|
||||||
|
args[dim*2+1] = (high - low + 1) / step
|
||||||
|
return t.View(ctx, args[0], args[1:]...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chunk the tensor into chunk sized tensors along dim. Each sub-tensor is a view of
|
||||||
|
// the original.
|
||||||
|
func (t *Tensor) Chunk(ctx ml.Context, dim, chunk int) []ml.Tensor {
|
||||||
|
sections := make([]int, 0, t.Dim(dim)/chunk+1)
|
||||||
|
for rest := t.Dim(dim); rest > 0; rest -= chunk {
|
||||||
|
sections = append(sections, min(chunk, rest))
|
||||||
|
}
|
||||||
|
return t.ChunkSections(ctx, dim, sections...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChunkSections split the tensor into section sized tensors along dim. Each sub-tensor is a
|
||||||
|
// view of the original. The size of the dim must equal the sum of sections.
|
||||||
|
func (t *Tensor) ChunkSections(ctx ml.Context, dim int, sections ...int) []ml.Tensor {
|
||||||
|
var offset int
|
||||||
|
s := make([]ml.Tensor, len(sections))
|
||||||
|
for i, section := range sections {
|
||||||
|
s[i] = t.Slice(ctx, dim, offset, offset+section, 1)
|
||||||
|
offset += section
|
||||||
|
}
|
||||||
|
if offset != t.Dim(dim) {
|
||||||
|
panic("sections do not sum to tensor dimension")
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package ggml
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -368,10 +369,714 @@ func TestPermute(t *testing.T) {
|
|||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := setup(t)
|
ctx := setup(t)
|
||||||
got := tt.input(ctx).Permute(ctx, tt.shape...).Contiguous(ctx)
|
got := tt.input(ctx).Permute(ctx, tt.shape...)
|
||||||
|
got = got.Contiguous(ctx)
|
||||||
if diff := cmp.Diff(tt.want(ctx), got, EquateTensors(ctx)); diff != "" {
|
if diff := cmp.Diff(tt.want(ctx), got, EquateTensors(ctx)); diff != "" {
|
||||||
t.Errorf("Permute() result mismatch (-want +got):\n%s", diff)
|
t.Errorf("Permute() result mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSlice(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
dim int
|
||||||
|
low int
|
||||||
|
high int
|
||||||
|
step int
|
||||||
|
input func(ml.Context) ml.Tensor
|
||||||
|
want func(ml.Context) ml.Tensor
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
dim: 0, low: 1, high: 3, step: 1,
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 4*4*4*4, 1, ml.DTypeF32).Reshape(ctx, 4, 4, 4, 4)
|
||||||
|
},
|
||||||
|
want: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
1, 2,
|
||||||
|
5, 6,
|
||||||
|
9, 10,
|
||||||
|
13, 14,
|
||||||
|
|
||||||
|
17, 18,
|
||||||
|
21, 22,
|
||||||
|
25, 26,
|
||||||
|
29, 30,
|
||||||
|
|
||||||
|
33, 34,
|
||||||
|
37, 38,
|
||||||
|
41, 42,
|
||||||
|
45, 46,
|
||||||
|
|
||||||
|
49, 50,
|
||||||
|
53, 54,
|
||||||
|
57, 58,
|
||||||
|
61, 62,
|
||||||
|
|
||||||
|
65, 66,
|
||||||
|
69, 70,
|
||||||
|
73, 74,
|
||||||
|
77, 78,
|
||||||
|
|
||||||
|
81, 82,
|
||||||
|
85, 86,
|
||||||
|
89, 90,
|
||||||
|
93, 94,
|
||||||
|
|
||||||
|
97, 98,
|
||||||
|
101, 102,
|
||||||
|
105, 106,
|
||||||
|
109, 110,
|
||||||
|
|
||||||
|
113, 114,
|
||||||
|
117, 118,
|
||||||
|
121, 122,
|
||||||
|
125, 126,
|
||||||
|
|
||||||
|
129, 130,
|
||||||
|
133, 134,
|
||||||
|
137, 138,
|
||||||
|
141, 142,
|
||||||
|
|
||||||
|
145, 146,
|
||||||
|
149, 150,
|
||||||
|
153, 154,
|
||||||
|
157, 158,
|
||||||
|
|
||||||
|
161, 162,
|
||||||
|
165, 166,
|
||||||
|
169, 170,
|
||||||
|
173, 174,
|
||||||
|
|
||||||
|
177, 178,
|
||||||
|
181, 182,
|
||||||
|
185, 186,
|
||||||
|
189, 190,
|
||||||
|
|
||||||
|
193, 194,
|
||||||
|
197, 198,
|
||||||
|
201, 202,
|
||||||
|
205, 206,
|
||||||
|
|
||||||
|
209, 210,
|
||||||
|
213, 214,
|
||||||
|
217, 218,
|
||||||
|
221, 222,
|
||||||
|
|
||||||
|
225, 226,
|
||||||
|
229, 230,
|
||||||
|
233, 234,
|
||||||
|
237, 238,
|
||||||
|
|
||||||
|
241, 242,
|
||||||
|
245, 246,
|
||||||
|
249, 250,
|
||||||
|
253, 254,
|
||||||
|
}, 2, 4, 4, 4)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
dim: 1, low: 1, high: 3, step: 1,
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 4*4*4*4, 1, ml.DTypeF32).Reshape(ctx, 4, 4, 4, 4)
|
||||||
|
},
|
||||||
|
want: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
4, 5, 6, 7,
|
||||||
|
8, 9, 10, 11,
|
||||||
|
|
||||||
|
20, 21, 22, 23,
|
||||||
|
24, 25, 26, 27,
|
||||||
|
|
||||||
|
36, 37, 38, 39,
|
||||||
|
40, 41, 42, 43,
|
||||||
|
|
||||||
|
52, 53, 54, 55,
|
||||||
|
56, 57, 58, 59,
|
||||||
|
|
||||||
|
68, 69, 70, 71,
|
||||||
|
72, 73, 74, 75,
|
||||||
|
|
||||||
|
84, 85, 86, 87,
|
||||||
|
88, 89, 90, 91,
|
||||||
|
|
||||||
|
100, 101, 102, 103,
|
||||||
|
104, 105, 106, 107,
|
||||||
|
|
||||||
|
116, 117, 118, 119,
|
||||||
|
120, 121, 122, 123,
|
||||||
|
|
||||||
|
132, 133, 134, 135,
|
||||||
|
136, 137, 138, 139,
|
||||||
|
|
||||||
|
148, 149, 150, 151,
|
||||||
|
152, 153, 154, 155,
|
||||||
|
|
||||||
|
164, 165, 166, 167,
|
||||||
|
168, 169, 170, 171,
|
||||||
|
|
||||||
|
180, 181, 182, 183,
|
||||||
|
184, 185, 186, 187,
|
||||||
|
|
||||||
|
196, 197, 198, 199,
|
||||||
|
200, 201, 202, 203,
|
||||||
|
|
||||||
|
212, 213, 214, 215,
|
||||||
|
216, 217, 218, 219,
|
||||||
|
|
||||||
|
228, 229, 230, 231,
|
||||||
|
232, 233, 234, 235,
|
||||||
|
|
||||||
|
244, 245, 246, 247,
|
||||||
|
248, 249, 250, 251,
|
||||||
|
}, 4, 2, 4, 4)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
dim: 2, low: 1, high: 3, step: 1,
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 4*4*4*4, 1, ml.DTypeF32).Reshape(ctx, 4, 4, 4, 4)
|
||||||
|
},
|
||||||
|
want: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
16, 17, 18, 19,
|
||||||
|
20, 21, 22, 23,
|
||||||
|
24, 25, 26, 27,
|
||||||
|
28, 29, 30, 31,
|
||||||
|
|
||||||
|
32, 33, 34, 35,
|
||||||
|
36, 37, 38, 39,
|
||||||
|
40, 41, 42, 43,
|
||||||
|
44, 45, 46, 47,
|
||||||
|
|
||||||
|
80, 81, 82, 83,
|
||||||
|
84, 85, 86, 87,
|
||||||
|
88, 89, 90, 91,
|
||||||
|
92, 93, 94, 95,
|
||||||
|
|
||||||
|
96, 97, 98, 99,
|
||||||
|
100, 101, 102, 103,
|
||||||
|
104, 105, 106, 107,
|
||||||
|
108, 109, 110, 111,
|
||||||
|
|
||||||
|
144, 145, 146, 147,
|
||||||
|
148, 149, 150, 151,
|
||||||
|
152, 153, 154, 155,
|
||||||
|
156, 157, 158, 159,
|
||||||
|
|
||||||
|
160, 161, 162, 163,
|
||||||
|
164, 165, 166, 167,
|
||||||
|
168, 169, 170, 171,
|
||||||
|
172, 173, 174, 175,
|
||||||
|
|
||||||
|
208, 209, 210, 211,
|
||||||
|
212, 213, 214, 215,
|
||||||
|
216, 217, 218, 219,
|
||||||
|
220, 221, 222, 223,
|
||||||
|
|
||||||
|
224, 225, 226, 227,
|
||||||
|
228, 229, 230, 231,
|
||||||
|
232, 233, 234, 235,
|
||||||
|
236, 237, 238, 239,
|
||||||
|
}, 4, 4, 2, 4)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
dim: 3, low: 1, high: 3, step: 1,
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 4*4*4*4, 1, ml.DTypeF32).Reshape(ctx, 4, 4, 4, 4)
|
||||||
|
},
|
||||||
|
want: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
64, 65, 66, 67,
|
||||||
|
68, 69, 70, 71,
|
||||||
|
72, 73, 74, 75,
|
||||||
|
76, 77, 78, 79,
|
||||||
|
|
||||||
|
80, 81, 82, 83,
|
||||||
|
84, 85, 86, 87,
|
||||||
|
88, 89, 90, 91,
|
||||||
|
92, 93, 94, 95,
|
||||||
|
|
||||||
|
96, 97, 98, 99,
|
||||||
|
100, 101, 102, 103,
|
||||||
|
104, 105, 106, 107,
|
||||||
|
108, 109, 110, 111,
|
||||||
|
|
||||||
|
112, 113, 114, 115,
|
||||||
|
116, 117, 118, 119,
|
||||||
|
120, 121, 122, 123,
|
||||||
|
124, 125, 126, 127,
|
||||||
|
|
||||||
|
128, 129, 130, 131,
|
||||||
|
132, 133, 134, 135,
|
||||||
|
136, 137, 138, 139,
|
||||||
|
140, 141, 142, 143,
|
||||||
|
|
||||||
|
144, 145, 146, 147,
|
||||||
|
148, 149, 150, 151,
|
||||||
|
152, 153, 154, 155,
|
||||||
|
156, 157, 158, 159,
|
||||||
|
|
||||||
|
160, 161, 162, 163,
|
||||||
|
164, 165, 166, 167,
|
||||||
|
168, 169, 170, 171,
|
||||||
|
172, 173, 174, 175,
|
||||||
|
|
||||||
|
176, 177, 178, 179,
|
||||||
|
180, 181, 182, 183,
|
||||||
|
184, 185, 186, 187,
|
||||||
|
188, 189, 190, 191,
|
||||||
|
}, 4, 4, 4, 2)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
dim: 0, low: 0, high: 4, step: 2,
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 4*4*4*4, 1, ml.DTypeF32).Reshape(ctx, 4, 4, 4, 4)
|
||||||
|
},
|
||||||
|
want: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
0, 2,
|
||||||
|
4, 6,
|
||||||
|
8, 10,
|
||||||
|
12, 14,
|
||||||
|
|
||||||
|
16, 18,
|
||||||
|
20, 22,
|
||||||
|
24, 26,
|
||||||
|
28, 30,
|
||||||
|
|
||||||
|
32, 34,
|
||||||
|
36, 38,
|
||||||
|
40, 42,
|
||||||
|
44, 46,
|
||||||
|
|
||||||
|
48, 50,
|
||||||
|
52, 54,
|
||||||
|
56, 58,
|
||||||
|
60, 62,
|
||||||
|
|
||||||
|
64, 66,
|
||||||
|
68, 70,
|
||||||
|
72, 74,
|
||||||
|
76, 78,
|
||||||
|
|
||||||
|
80, 82,
|
||||||
|
84, 86,
|
||||||
|
88, 90,
|
||||||
|
92, 94,
|
||||||
|
|
||||||
|
96, 98,
|
||||||
|
100, 102,
|
||||||
|
104, 106,
|
||||||
|
108, 110,
|
||||||
|
|
||||||
|
112, 114,
|
||||||
|
116, 118,
|
||||||
|
120, 122,
|
||||||
|
124, 126,
|
||||||
|
|
||||||
|
128, 130,
|
||||||
|
132, 134,
|
||||||
|
136, 138,
|
||||||
|
140, 142,
|
||||||
|
|
||||||
|
144, 146,
|
||||||
|
148, 150,
|
||||||
|
152, 154,
|
||||||
|
156, 158,
|
||||||
|
|
||||||
|
160, 162,
|
||||||
|
164, 166,
|
||||||
|
168, 170,
|
||||||
|
172, 174,
|
||||||
|
|
||||||
|
176, 178,
|
||||||
|
180, 182,
|
||||||
|
184, 186,
|
||||||
|
188, 190,
|
||||||
|
|
||||||
|
192, 194,
|
||||||
|
196, 198,
|
||||||
|
200, 202,
|
||||||
|
204, 206,
|
||||||
|
|
||||||
|
208, 210,
|
||||||
|
212, 214,
|
||||||
|
216, 218,
|
||||||
|
220, 222,
|
||||||
|
|
||||||
|
224, 226,
|
||||||
|
228, 230,
|
||||||
|
232, 234,
|
||||||
|
236, 238,
|
||||||
|
|
||||||
|
240, 242,
|
||||||
|
244, 246,
|
||||||
|
248, 250,
|
||||||
|
252, 254,
|
||||||
|
}, 2, 4, 4, 4)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
dim: 1, low: 0, high: 4, step: 2,
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 4*4*4*4, 1, ml.DTypeF32).Reshape(ctx, 4, 4, 4, 4)
|
||||||
|
},
|
||||||
|
want: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
0, 1, 2, 3,
|
||||||
|
8, 9, 10, 11,
|
||||||
|
|
||||||
|
16, 17, 18, 19,
|
||||||
|
24, 25, 26, 27,
|
||||||
|
|
||||||
|
32, 33, 34, 35,
|
||||||
|
40, 41, 42, 43,
|
||||||
|
|
||||||
|
48, 49, 50, 51,
|
||||||
|
56, 57, 58, 59,
|
||||||
|
|
||||||
|
64, 65, 66, 67,
|
||||||
|
72, 73, 74, 75,
|
||||||
|
|
||||||
|
80, 81, 82, 83,
|
||||||
|
88, 89, 90, 91,
|
||||||
|
|
||||||
|
96, 97, 98, 99,
|
||||||
|
104, 105, 106, 107,
|
||||||
|
|
||||||
|
112, 113, 114, 115,
|
||||||
|
120, 121, 122, 123,
|
||||||
|
|
||||||
|
128, 129, 130, 131,
|
||||||
|
136, 137, 138, 139,
|
||||||
|
|
||||||
|
144, 145, 146, 147,
|
||||||
|
152, 153, 154, 155,
|
||||||
|
|
||||||
|
160, 161, 162, 163,
|
||||||
|
168, 169, 170, 171,
|
||||||
|
|
||||||
|
176, 177, 178, 179,
|
||||||
|
184, 185, 186, 187,
|
||||||
|
|
||||||
|
192, 193, 194, 195,
|
||||||
|
200, 201, 202, 203,
|
||||||
|
|
||||||
|
208, 209, 210, 211,
|
||||||
|
216, 217, 218, 219,
|
||||||
|
|
||||||
|
224, 225, 226, 227,
|
||||||
|
232, 233, 234, 235,
|
||||||
|
|
||||||
|
240, 241, 242, 243,
|
||||||
|
248, 249, 250, 251,
|
||||||
|
}, 4, 2, 4, 4)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
dim: 2, low: 0, high: 4, step: 2,
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 4*4*4*4, 1, ml.DTypeF32).Reshape(ctx, 4, 4, 4, 4)
|
||||||
|
},
|
||||||
|
want: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
0, 1, 2, 3,
|
||||||
|
4, 5, 6, 7,
|
||||||
|
8, 9, 10, 11,
|
||||||
|
12, 13, 14, 15,
|
||||||
|
|
||||||
|
32, 33, 34, 35,
|
||||||
|
36, 37, 38, 39,
|
||||||
|
40, 41, 42, 43,
|
||||||
|
44, 45, 46, 47,
|
||||||
|
|
||||||
|
64, 65, 66, 67,
|
||||||
|
68, 69, 70, 71,
|
||||||
|
72, 73, 74, 75,
|
||||||
|
76, 77, 78, 79,
|
||||||
|
|
||||||
|
96, 97, 98, 99,
|
||||||
|
100, 101, 102, 103,
|
||||||
|
104, 105, 106, 107,
|
||||||
|
108, 109, 110, 111,
|
||||||
|
|
||||||
|
128, 129, 130, 131,
|
||||||
|
132, 133, 134, 135,
|
||||||
|
136, 137, 138, 139,
|
||||||
|
140, 141, 142, 143,
|
||||||
|
|
||||||
|
160, 161, 162, 163,
|
||||||
|
164, 165, 166, 167,
|
||||||
|
168, 169, 170, 171,
|
||||||
|
172, 173, 174, 175,
|
||||||
|
|
||||||
|
192, 193, 194, 195,
|
||||||
|
196, 197, 198, 199,
|
||||||
|
200, 201, 202, 203,
|
||||||
|
204, 205, 206, 207,
|
||||||
|
|
||||||
|
224, 225, 226, 227,
|
||||||
|
228, 229, 230, 231,
|
||||||
|
232, 233, 234, 235,
|
||||||
|
236, 237, 238, 239,
|
||||||
|
}, 4, 4, 2, 4)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
dim: 3, low: 0, high: 4, step: 2,
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 4*4*4*4, 1, ml.DTypeF32).Reshape(ctx, 4, 4, 4, 4)
|
||||||
|
},
|
||||||
|
want: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
0, 1, 2, 3,
|
||||||
|
4, 5, 6, 7,
|
||||||
|
8, 9, 10, 11,
|
||||||
|
12, 13, 14, 15,
|
||||||
|
|
||||||
|
16, 17, 18, 19,
|
||||||
|
20, 21, 22, 23,
|
||||||
|
24, 25, 26, 27,
|
||||||
|
28, 29, 30, 31,
|
||||||
|
|
||||||
|
32, 33, 34, 35,
|
||||||
|
36, 37, 38, 39,
|
||||||
|
40, 41, 42, 43,
|
||||||
|
44, 45, 46, 47,
|
||||||
|
|
||||||
|
48, 49, 50, 51,
|
||||||
|
52, 53, 54, 55,
|
||||||
|
56, 57, 58, 59,
|
||||||
|
60, 61, 62, 63,
|
||||||
|
|
||||||
|
128, 129, 130, 131,
|
||||||
|
132, 133, 134, 135,
|
||||||
|
136, 137, 138, 139,
|
||||||
|
140, 141, 142, 143,
|
||||||
|
|
||||||
|
144, 145, 146, 147,
|
||||||
|
148, 149, 150, 151,
|
||||||
|
152, 153, 154, 155,
|
||||||
|
156, 157, 158, 159,
|
||||||
|
|
||||||
|
160, 161, 162, 163,
|
||||||
|
164, 165, 166, 167,
|
||||||
|
168, 169, 170, 171,
|
||||||
|
172, 173, 174, 175,
|
||||||
|
|
||||||
|
176, 177, 178, 179,
|
||||||
|
180, 181, 182, 183,
|
||||||
|
184, 185, 186, 187,
|
||||||
|
188, 189, 190, 191,
|
||||||
|
}, 4, 4, 4, 2)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
name := fmt.Sprintf("dim=%d,low=%d,high=%d,step=%d", tt.dim, tt.low, tt.high, tt.step)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
ctx := setup(t)
|
||||||
|
got := tt.input(ctx).Slice(ctx, tt.dim, tt.low, tt.high, tt.step)
|
||||||
|
got = got.Contiguous(ctx)
|
||||||
|
if diff := cmp.Diff(tt.want(ctx), got, EquateTensors(ctx)); diff != "" {
|
||||||
|
t.Errorf("Slice() result mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitSections(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
dim int
|
||||||
|
sections []int
|
||||||
|
input func(ml.Context) ml.Tensor
|
||||||
|
want []func(ml.Context) ml.Tensor
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
dim: 0, sections: []int{1, 1, 1},
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4)
|
||||||
|
},
|
||||||
|
want: []func(ml.Context) ml.Tensor{
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{0, 3, 6, 9}, 1, 4)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{1, 4, 7, 10}, 1, 4)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{2, 5, 8, 11}, 1, 4)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
dim: 1, sections: []int{1, 3},
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4)
|
||||||
|
},
|
||||||
|
want: []func(ml.Context) ml.Tensor{
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{0, 1, 2}, 3, 1)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
3, 4, 5,
|
||||||
|
6, 7, 8,
|
||||||
|
9, 10, 11,
|
||||||
|
}, 3, 3)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
dim: 0, sections: []int{2, 2},
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 4, 3)
|
||||||
|
},
|
||||||
|
want: []func(ml.Context) ml.Tensor{
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
0, 1,
|
||||||
|
4, 5,
|
||||||
|
8, 9,
|
||||||
|
}, 2, 3)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
2, 3,
|
||||||
|
6, 7,
|
||||||
|
10, 11,
|
||||||
|
}, 2, 3)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
dim: 1, sections: []int{1, 2},
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 4, 3)
|
||||||
|
},
|
||||||
|
want: []func(ml.Context) ml.Tensor{
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{0, 1, 2, 3}, 4, 1)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
4, 5, 6, 7,
|
||||||
|
8, 9, 10, 11,
|
||||||
|
}, 4, 2)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(fmt.Sprintf("sections=%v", tt.sections), func(t *testing.T) {
|
||||||
|
ctx := setup(t)
|
||||||
|
got := tt.input(ctx).ChunkSections(ctx, tt.dim, tt.sections...)
|
||||||
|
|
||||||
|
for i := range got {
|
||||||
|
got[i] = got[i].Contiguous(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Forward(got...).Compute(got...)
|
||||||
|
for i, want := range tt.want {
|
||||||
|
if diff := cmp.Diff(want(ctx), got[i], EquateTensors(ctx)); diff != "" {
|
||||||
|
t.Errorf("SplitSections() section %d mismatch (-want +got):\n%s", i, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChunk(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
dim int
|
||||||
|
chunk int
|
||||||
|
input func(ml.Context) ml.Tensor
|
||||||
|
want []func(ml.Context) ml.Tensor
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
dim: 0, chunk: 1,
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4)
|
||||||
|
},
|
||||||
|
want: []func(ml.Context) ml.Tensor{
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{0, 3, 6, 9}, 1, 4)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{1, 4, 7, 10}, 1, 4)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{2, 5, 8, 11}, 1, 4)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
dim: 1, chunk: 2,
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4)
|
||||||
|
},
|
||||||
|
want: []func(ml.Context) ml.Tensor{
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
0, 1, 2,
|
||||||
|
3, 4, 5,
|
||||||
|
}, 3, 2)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
6, 7, 8,
|
||||||
|
9, 10, 11,
|
||||||
|
}, 3, 2)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
dim: 0, chunk: 2,
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4)
|
||||||
|
},
|
||||||
|
want: []func(ml.Context) ml.Tensor{
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
0, 1,
|
||||||
|
3, 4,
|
||||||
|
6, 7,
|
||||||
|
9, 10,
|
||||||
|
}, 2, 4)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
2,
|
||||||
|
5,
|
||||||
|
8,
|
||||||
|
11,
|
||||||
|
}, 1, 4)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(fmt.Sprintf("dim=%d,chunk=%d", tt.dim, tt.chunk), func(t *testing.T) {
|
||||||
|
ctx := setup(t)
|
||||||
|
got := tt.input(ctx).Chunk(ctx, tt.dim, tt.chunk)
|
||||||
|
|
||||||
|
for i := range got {
|
||||||
|
got[i] = got[i].Contiguous(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Forward(got...).Compute(got...)
|
||||||
|
for i, want := range tt.want {
|
||||||
|
if diff := cmp.Diff(want(ctx), got[i], EquateTensors(ctx)); diff != "" {
|
||||||
|
t.Errorf("Split() section %d mismatch (-want +got):\n%s", i, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user