diff --git a/ml/backend.go b/ml/backend.go index 8518024d11..c6fadb7f96 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -199,6 +199,8 @@ type Tensor interface { Duplicate(ctx Context) Tensor Slice(ctx Context, dim, low, high, step int) Tensor + Chunk(ctx Context, dim int, size int) []Tensor + ChunkSections(ctx Context, dim int, sections ...int) []Tensor TopK(ctx Context, k int) Tensor Argsort(ctx Context) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 2219e2b09e..cd0f436f81 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -1780,3 +1780,28 @@ func (t *Tensor) Slice(ctx ml.Context, dim int, low, high, step int) ml.Tensor { return t.View(ctx, args[0], args[1:]...) } } + +// Chunk the tensor into chunk sized tensors along dim. Each sub-tensor is a view of +// the original. +func (t *Tensor) Chunk(ctx ml.Context, dim, chunk int) []ml.Tensor { + sections := make([]int, 0, t.Dim(dim)/chunk+1) + for rest := t.Dim(dim); rest > 0; rest -= chunk { + sections = append(sections, min(chunk, rest)) + } + return t.ChunkSections(ctx, dim, sections...) +} + +// ChunkSections split the tensor into section sized tensors along dim. Each sub-tensor is a +// view of the original. The size of the dim must equal the sum of sections. +func (t *Tensor) ChunkSections(ctx ml.Context, dim int, sections ...int) []ml.Tensor { + var offset int + s := make([]ml.Tensor, len(sections)) + for i, section := range sections { + s[i] = t.Slice(ctx, dim, offset, offset+section, 1) + offset += section + } + if offset != t.Dim(dim) { + panic("sections do not sum to tensor dimension") + } + return s +} diff --git a/ml/backend/ggml/ggml_test.go b/ml/backend/ggml/ggml_test.go index fd5039b0af..efd3a455cf 100644 --- a/ml/backend/ggml/ggml_test.go +++ b/ml/backend/ggml/ggml_test.go @@ -369,7 +369,8 @@ func TestPermute(t *testing.T) { for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { ctx := setup(t) - got := tt.input(ctx).Permute(ctx, tt.shape...).Contiguous(ctx) + got := tt.input(ctx).Permute(ctx, tt.shape...) + got = got.Contiguous(ctx) if diff := cmp.Diff(tt.want(ctx), got, EquateTensors(ctx)); diff != "" { t.Errorf("Permute() result mismatch (-want +got):\n%s", diff) } @@ -880,10 +881,202 @@ func TestSlice(t *testing.T) { name := fmt.Sprintf("dim=%d,low=%d,high=%d,step=%d", tt.dim, tt.low, tt.high, tt.step) t.Run(name, func(t *testing.T) { ctx := setup(t) - got := tt.input(ctx).Slice(ctx, tt.dim, tt.low, tt.high, tt.step).Contiguous(ctx) + got := tt.input(ctx).Slice(ctx, tt.dim, tt.low, tt.high, tt.step) + got = got.Contiguous(ctx) if diff := cmp.Diff(tt.want(ctx), got, EquateTensors(ctx)); diff != "" { t.Errorf("Slice() result mismatch (-want +got):\n%s", diff) } }) } } + +func TestSplitSections(t *testing.T) { + cases := []struct { + dim int + sections []int + input func(ml.Context) ml.Tensor + want []func(ml.Context) ml.Tensor + }{ + { + dim: 0, sections: []int{1, 1, 1}, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4) + }, + want: []func(ml.Context) ml.Tensor{ + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{0, 3, 6, 9}, 1, 4) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{1, 4, 7, 10}, 1, 4) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{2, 5, 8, 11}, 1, 4) + }, + }, + }, + { + dim: 1, sections: []int{1, 3}, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4) + }, + want: []func(ml.Context) ml.Tensor{ + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{0, 1, 2}, 3, 1) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 3, 4, 5, + 6, 7, 8, + 9, 10, 11, + }, 3, 3) + }, + }, + }, + { + dim: 0, sections: []int{2, 2}, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 4, 3) + }, + want: []func(ml.Context) ml.Tensor{ + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 0, 1, + 4, 5, + 8, 9, + }, 2, 3) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 2, 3, + 6, 7, + 10, 11, + }, 2, 3) + }, + }, + }, + { + dim: 1, sections: []int{1, 2}, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 4, 3) + }, + want: []func(ml.Context) ml.Tensor{ + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{0, 1, 2, 3}, 4, 1) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 4, 5, 6, 7, + 8, 9, 10, 11, + }, 4, 2) + }, + }, + }, + } + + for _, tt := range cases { + t.Run(fmt.Sprintf("sections=%v", tt.sections), func(t *testing.T) { + ctx := setup(t) + got := tt.input(ctx).ChunkSections(ctx, tt.dim, tt.sections...) + + for i := range got { + got[i] = got[i].Contiguous(ctx) + } + + ctx.Forward(got...).Compute(got...) + for i, want := range tt.want { + if diff := cmp.Diff(want(ctx), got[i], EquateTensors(ctx)); diff != "" { + t.Errorf("SplitSections() section %d mismatch (-want +got):\n%s", i, diff) + } + } + }) + } +} + +func TestChunk(t *testing.T) { + cases := []struct { + dim int + chunk int + input func(ml.Context) ml.Tensor + want []func(ml.Context) ml.Tensor + }{ + { + dim: 0, chunk: 1, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4) + }, + want: []func(ml.Context) ml.Tensor{ + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{0, 3, 6, 9}, 1, 4) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{1, 4, 7, 10}, 1, 4) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{2, 5, 8, 11}, 1, 4) + }, + }, + }, + { + dim: 1, chunk: 2, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4) + }, + want: []func(ml.Context) ml.Tensor{ + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 0, 1, 2, + 3, 4, 5, + }, 3, 2) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 6, 7, 8, + 9, 10, 11, + }, 3, 2) + }, + }, + }, + { + dim: 0, chunk: 2, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4) + }, + want: []func(ml.Context) ml.Tensor{ + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 0, 1, + 3, 4, + 6, 7, + 9, 10, + }, 2, 4) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 2, + 5, + 8, + 11, + }, 1, 4) + }, + }, + }, + } + + for _, tt := range cases { + t.Run(fmt.Sprintf("dim=%d,chunk=%d", tt.dim, tt.chunk), func(t *testing.T) { + ctx := setup(t) + got := tt.input(ctx).Chunk(ctx, tt.dim, tt.chunk) + + for i := range got { + got[i] = got[i].Contiguous(ctx) + } + + ctx.Forward(got...).Compute(got...) + for i, want := range tt.want { + if diff := cmp.Diff(want(ctx), got[i], EquateTensors(ctx)); diff != "" { + t.Errorf("Split() section %d mismatch (-want +got):\n%s", i, diff) + } + } + }) + } +}