mirror of
https://github.com/ollama/ollama.git
synced 2025-11-11 06:17:55 +01:00
chunk, chunksections
This commit is contained in:
@@ -199,6 +199,8 @@ type Tensor interface {
|
|||||||
Duplicate(ctx Context) Tensor
|
Duplicate(ctx Context) Tensor
|
||||||
|
|
||||||
Slice(ctx Context, dim, low, high, step int) Tensor
|
Slice(ctx Context, dim, low, high, step int) Tensor
|
||||||
|
Chunk(ctx Context, dim int, size int) []Tensor
|
||||||
|
ChunkSections(ctx Context, dim int, sections ...int) []Tensor
|
||||||
|
|
||||||
TopK(ctx Context, k int) Tensor
|
TopK(ctx Context, k int) Tensor
|
||||||
Argsort(ctx Context) Tensor
|
Argsort(ctx Context) Tensor
|
||||||
|
|||||||
@@ -1780,3 +1780,28 @@ func (t *Tensor) Slice(ctx ml.Context, dim int, low, high, step int) ml.Tensor {
|
|||||||
return t.View(ctx, args[0], args[1:]...)
|
return t.View(ctx, args[0], args[1:]...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Chunk the tensor into chunk sized tensors along dim. Each sub-tensor is a view of
|
||||||
|
// the original.
|
||||||
|
func (t *Tensor) Chunk(ctx ml.Context, dim, chunk int) []ml.Tensor {
|
||||||
|
sections := make([]int, 0, t.Dim(dim)/chunk+1)
|
||||||
|
for rest := t.Dim(dim); rest > 0; rest -= chunk {
|
||||||
|
sections = append(sections, min(chunk, rest))
|
||||||
|
}
|
||||||
|
return t.ChunkSections(ctx, dim, sections...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChunkSections split the tensor into section sized tensors along dim. Each sub-tensor is a
|
||||||
|
// view of the original. The size of the dim must equal the sum of sections.
|
||||||
|
func (t *Tensor) ChunkSections(ctx ml.Context, dim int, sections ...int) []ml.Tensor {
|
||||||
|
var offset int
|
||||||
|
s := make([]ml.Tensor, len(sections))
|
||||||
|
for i, section := range sections {
|
||||||
|
s[i] = t.Slice(ctx, dim, offset, offset+section, 1)
|
||||||
|
offset += section
|
||||||
|
}
|
||||||
|
if offset != t.Dim(dim) {
|
||||||
|
panic("sections do not sum to tensor dimension")
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|||||||
@@ -369,7 +369,8 @@ func TestPermute(t *testing.T) {
|
|||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := setup(t)
|
ctx := setup(t)
|
||||||
got := tt.input(ctx).Permute(ctx, tt.shape...).Contiguous(ctx)
|
got := tt.input(ctx).Permute(ctx, tt.shape...)
|
||||||
|
got = got.Contiguous(ctx)
|
||||||
if diff := cmp.Diff(tt.want(ctx), got, EquateTensors(ctx)); diff != "" {
|
if diff := cmp.Diff(tt.want(ctx), got, EquateTensors(ctx)); diff != "" {
|
||||||
t.Errorf("Permute() result mismatch (-want +got):\n%s", diff)
|
t.Errorf("Permute() result mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
@@ -880,10 +881,202 @@ func TestSlice(t *testing.T) {
|
|||||||
name := fmt.Sprintf("dim=%d,low=%d,high=%d,step=%d", tt.dim, tt.low, tt.high, tt.step)
|
name := fmt.Sprintf("dim=%d,low=%d,high=%d,step=%d", tt.dim, tt.low, tt.high, tt.step)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
ctx := setup(t)
|
ctx := setup(t)
|
||||||
got := tt.input(ctx).Slice(ctx, tt.dim, tt.low, tt.high, tt.step).Contiguous(ctx)
|
got := tt.input(ctx).Slice(ctx, tt.dim, tt.low, tt.high, tt.step)
|
||||||
|
got = got.Contiguous(ctx)
|
||||||
if diff := cmp.Diff(tt.want(ctx), got, EquateTensors(ctx)); diff != "" {
|
if diff := cmp.Diff(tt.want(ctx), got, EquateTensors(ctx)); diff != "" {
|
||||||
t.Errorf("Slice() result mismatch (-want +got):\n%s", diff)
|
t.Errorf("Slice() result mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSplitSections(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
dim int
|
||||||
|
sections []int
|
||||||
|
input func(ml.Context) ml.Tensor
|
||||||
|
want []func(ml.Context) ml.Tensor
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
dim: 0, sections: []int{1, 1, 1},
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4)
|
||||||
|
},
|
||||||
|
want: []func(ml.Context) ml.Tensor{
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{0, 3, 6, 9}, 1, 4)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{1, 4, 7, 10}, 1, 4)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{2, 5, 8, 11}, 1, 4)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
dim: 1, sections: []int{1, 3},
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4)
|
||||||
|
},
|
||||||
|
want: []func(ml.Context) ml.Tensor{
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{0, 1, 2}, 3, 1)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
3, 4, 5,
|
||||||
|
6, 7, 8,
|
||||||
|
9, 10, 11,
|
||||||
|
}, 3, 3)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
dim: 0, sections: []int{2, 2},
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 4, 3)
|
||||||
|
},
|
||||||
|
want: []func(ml.Context) ml.Tensor{
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
0, 1,
|
||||||
|
4, 5,
|
||||||
|
8, 9,
|
||||||
|
}, 2, 3)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
2, 3,
|
||||||
|
6, 7,
|
||||||
|
10, 11,
|
||||||
|
}, 2, 3)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
dim: 1, sections: []int{1, 2},
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 4, 3)
|
||||||
|
},
|
||||||
|
want: []func(ml.Context) ml.Tensor{
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{0, 1, 2, 3}, 4, 1)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
4, 5, 6, 7,
|
||||||
|
8, 9, 10, 11,
|
||||||
|
}, 4, 2)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(fmt.Sprintf("sections=%v", tt.sections), func(t *testing.T) {
|
||||||
|
ctx := setup(t)
|
||||||
|
got := tt.input(ctx).ChunkSections(ctx, tt.dim, tt.sections...)
|
||||||
|
|
||||||
|
for i := range got {
|
||||||
|
got[i] = got[i].Contiguous(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Forward(got...).Compute(got...)
|
||||||
|
for i, want := range tt.want {
|
||||||
|
if diff := cmp.Diff(want(ctx), got[i], EquateTensors(ctx)); diff != "" {
|
||||||
|
t.Errorf("SplitSections() section %d mismatch (-want +got):\n%s", i, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChunk(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
dim int
|
||||||
|
chunk int
|
||||||
|
input func(ml.Context) ml.Tensor
|
||||||
|
want []func(ml.Context) ml.Tensor
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
dim: 0, chunk: 1,
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4)
|
||||||
|
},
|
||||||
|
want: []func(ml.Context) ml.Tensor{
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{0, 3, 6, 9}, 1, 4)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{1, 4, 7, 10}, 1, 4)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{2, 5, 8, 11}, 1, 4)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
dim: 1, chunk: 2,
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4)
|
||||||
|
},
|
||||||
|
want: []func(ml.Context) ml.Tensor{
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
0, 1, 2,
|
||||||
|
3, 4, 5,
|
||||||
|
}, 3, 2)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
6, 7, 8,
|
||||||
|
9, 10, 11,
|
||||||
|
}, 3, 2)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
dim: 0, chunk: 2,
|
||||||
|
input: func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4)
|
||||||
|
},
|
||||||
|
want: []func(ml.Context) ml.Tensor{
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
0, 1,
|
||||||
|
3, 4,
|
||||||
|
6, 7,
|
||||||
|
9, 10,
|
||||||
|
}, 2, 4)
|
||||||
|
},
|
||||||
|
func(ctx ml.Context) ml.Tensor {
|
||||||
|
return ctx.FromFloats([]float32{
|
||||||
|
2,
|
||||||
|
5,
|
||||||
|
8,
|
||||||
|
11,
|
||||||
|
}, 1, 4)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(fmt.Sprintf("dim=%d,chunk=%d", tt.dim, tt.chunk), func(t *testing.T) {
|
||||||
|
ctx := setup(t)
|
||||||
|
got := tt.input(ctx).Chunk(ctx, tt.dim, tt.chunk)
|
||||||
|
|
||||||
|
for i := range got {
|
||||||
|
got[i] = got[i].Contiguous(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Forward(got...).Compute(got...)
|
||||||
|
for i, want := range tt.want {
|
||||||
|
if diff := cmp.Diff(want(ctx), got[i], EquateTensors(ctx)); diff != "" {
|
||||||
|
t.Errorf("Split() section %d mismatch (-want +got):\n%s", i, diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user