From b48083f33f6784374106f7fcfc29f682af74a6ed Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Thu, 13 Nov 2025 13:28:21 -0800
Subject: [PATCH] ml: add slice operation (#12870)

* slice

* chunk, chunksections
---
 ml/backend.go                |   4 +
 ml/backend/ggml/ggml.go      |  63 ++++
 ml/backend/ggml/ggml_test.go | 707 ++++++++++++++++++++++++++++++++++-
 3 files changed, 773 insertions(+), 1 deletion(-)

diff --git a/ml/backend.go b/ml/backend.go
index b07039e217..c6fadb7f96 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -198,6 +198,10 @@ type Tensor interface {
 	Copy(ctx Context, t2 Tensor) Tensor
 	Duplicate(ctx Context) Tensor
 
+	Slice(ctx Context, dim, low, high, step int) Tensor
+	Chunk(ctx Context, dim int, size int) []Tensor
+	ChunkSections(ctx Context, dim int, sections ...int) []Tensor
+
 	TopK(ctx Context, k int) Tensor
 	Argsort(ctx Context) Tensor
 	Mean(ctx Context) Tensor
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 5fa0a9ec0e..e18d2f387e 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -1738,3 +1738,66 @@ func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
 		t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
 	}
 }
+
+// Slice returns a view of the tensor sliced along dim from low to high with the given step.
+// Slice panics if the dimension is invalid or the slice parameters are out of range.
+// If dim=0 and step>1, the result is a copy rather than a view to ensure the proper shape.
+func (t *Tensor) Slice(ctx ml.Context, dim int, low, high, step int) ml.Tensor {
+	if dim < 0 || dim >= C.GGML_MAX_DIMS {
+		panic("invalid dimension")
+	} else if low < 0 || high > t.Dim(dim) || low >= high || step < 1 {
+		panic("invalid slice parameters")
+	}
+
+	if dim == 0 && step > 1 {
+		// dim=0,step>1 is a special case so handle it here first
+		return t.View(ctx,
+			low*t.Stride(0), 1,
+			step*t.Stride(0), (high-low+1)/step,
+			t.Stride(1), t.Dim(1),
+			// preserve dim 3 by merging it into dim 2
+			t.Stride(2), t.Dim(2)*t.Dim(3),
+		).Contiguous(ctx, (high-low+1)/step, t.Dim(1), t.Dim(2), t.Dim(3))
+	}
+
+	args := []int{
+		low * t.Stride(dim), t.Dim(0),
+		t.Stride(1), t.Dim(1),
+		t.Stride(2), t.Dim(2),
+		t.Stride(3), t.Dim(3),
+	}
+
+	if step == 1 {
+		args[dim*2+1] = high - low
+		return t.View(ctx, args[0], args[1:]...)
+	} else {
+		args[dim*2] = step * t.Stride(dim)
+		args[dim*2+1] = (high - low + 1) / step
+		return t.View(ctx, args[0], args[1:]...)
+	}
+}
+
+// Chunk splits the tensor into chunk-sized tensors along dim. Each sub-tensor is a view of
+// the original.
+func (t *Tensor) Chunk(ctx ml.Context, dim, chunk int) []ml.Tensor {
+	sections := make([]int, 0, t.Dim(dim)/chunk+1)
+	for rest := t.Dim(dim); rest > 0; rest -= chunk {
+		sections = append(sections, min(chunk, rest))
+	}
+	return t.ChunkSections(ctx, dim, sections...)
+}
+
+// ChunkSections splits the tensor into section-sized tensors along dim. Each sub-tensor is a
+// view of the original. The size of dim must equal the sum of sections.
+func (t *Tensor) ChunkSections(ctx ml.Context, dim int, sections ...int) []ml.Tensor { + var offset int + s := make([]ml.Tensor, len(sections)) + for i, section := range sections { + s[i] = t.Slice(ctx, dim, offset, offset+section, 1) + offset += section + } + if offset != t.Dim(dim) { + panic("sections do not sum to tensor dimension") + } + return s +} diff --git a/ml/backend/ggml/ggml_test.go b/ml/backend/ggml/ggml_test.go index 31dfdb7bb2..efd3a455cf 100644 --- a/ml/backend/ggml/ggml_test.go +++ b/ml/backend/ggml/ggml_test.go @@ -2,6 +2,7 @@ package ggml import ( "errors" + "fmt" "os" "testing" @@ -368,10 +369,714 @@ func TestPermute(t *testing.T) { for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { ctx := setup(t) - got := tt.input(ctx).Permute(ctx, tt.shape...).Contiguous(ctx) + got := tt.input(ctx).Permute(ctx, tt.shape...) + got = got.Contiguous(ctx) if diff := cmp.Diff(tt.want(ctx), got, EquateTensors(ctx)); diff != "" { t.Errorf("Permute() result mismatch (-want +got):\n%s", diff) } }) } } + +func TestSlice(t *testing.T) { + cases := []struct { + dim int + low int + high int + step int + input func(ml.Context) ml.Tensor + want func(ml.Context) ml.Tensor + }{ + { + dim: 0, low: 1, high: 3, step: 1, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 4*4*4*4, 1, ml.DTypeF32).Reshape(ctx, 4, 4, 4, 4) + }, + want: func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 1, 2, + 5, 6, + 9, 10, + 13, 14, + + 17, 18, + 21, 22, + 25, 26, + 29, 30, + + 33, 34, + 37, 38, + 41, 42, + 45, 46, + + 49, 50, + 53, 54, + 57, 58, + 61, 62, + + 65, 66, + 69, 70, + 73, 74, + 77, 78, + + 81, 82, + 85, 86, + 89, 90, + 93, 94, + + 97, 98, + 101, 102, + 105, 106, + 109, 110, + + 113, 114, + 117, 118, + 121, 122, + 125, 126, + + 129, 130, + 133, 134, + 137, 138, + 141, 142, + + 145, 146, + 149, 150, + 153, 154, + 157, 158, + + 161, 162, + 165, 166, + 169, 170, + 173, 174, + + 177, 178, + 181, 182, + 185, 186, + 189, 190, + + 193, 194, + 197, 198, + 201, 202, + 205, 206, + + 209, 210, + 213, 214, + 217, 218, + 221, 222, + + 225, 226, + 229, 230, + 233, 234, + 237, 238, + + 241, 242, + 245, 246, + 249, 250, + 253, 254, + }, 2, 4, 4, 4) + }, + }, + { + dim: 1, low: 1, high: 3, step: 1, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 4*4*4*4, 1, ml.DTypeF32).Reshape(ctx, 4, 4, 4, 4) + }, + want: func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 4, 5, 6, 7, + 8, 9, 10, 11, + + 20, 21, 22, 23, + 24, 25, 26, 27, + + 36, 37, 38, 39, + 40, 41, 42, 43, + + 52, 53, 54, 55, + 56, 57, 58, 59, + + 68, 69, 70, 71, + 72, 73, 74, 75, + + 84, 85, 86, 87, + 88, 89, 90, 91, + + 100, 101, 102, 103, + 104, 105, 106, 107, + + 116, 117, 118, 119, + 120, 121, 122, 123, + + 132, 133, 134, 135, + 136, 137, 138, 139, + + 148, 149, 150, 151, + 152, 153, 154, 155, + + 164, 165, 166, 167, + 168, 169, 170, 171, + + 180, 181, 182, 183, + 184, 185, 186, 187, + + 196, 197, 198, 199, + 200, 201, 202, 203, + + 212, 213, 214, 215, + 216, 217, 218, 219, + + 228, 229, 230, 231, + 232, 233, 234, 235, + + 244, 245, 246, 247, + 248, 249, 250, 251, + }, 4, 2, 4, 4) + }, + }, + { + dim: 2, low: 1, high: 3, step: 1, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 4*4*4*4, 1, ml.DTypeF32).Reshape(ctx, 4, 4, 4, 4) + }, + want: func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 16, 17, 18, 19, + 20, 21, 22, 23, + 24, 25, 26, 27, + 28, 29, 30, 31, + + 32, 33, 34, 35, + 36, 37, 38, 39, + 40, 41, 42, 43, + 44, 45, 46, 47, + + 80, 81, 82, 83, + 
84, 85, 86, 87, + 88, 89, 90, 91, + 92, 93, 94, 95, + + 96, 97, 98, 99, + 100, 101, 102, 103, + 104, 105, 106, 107, + 108, 109, 110, 111, + + 144, 145, 146, 147, + 148, 149, 150, 151, + 152, 153, 154, 155, + 156, 157, 158, 159, + + 160, 161, 162, 163, + 164, 165, 166, 167, + 168, 169, 170, 171, + 172, 173, 174, 175, + + 208, 209, 210, 211, + 212, 213, 214, 215, + 216, 217, 218, 219, + 220, 221, 222, 223, + + 224, 225, 226, 227, + 228, 229, 230, 231, + 232, 233, 234, 235, + 236, 237, 238, 239, + }, 4, 4, 2, 4) + }, + }, + { + dim: 3, low: 1, high: 3, step: 1, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 4*4*4*4, 1, ml.DTypeF32).Reshape(ctx, 4, 4, 4, 4) + }, + want: func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 64, 65, 66, 67, + 68, 69, 70, 71, + 72, 73, 74, 75, + 76, 77, 78, 79, + + 80, 81, 82, 83, + 84, 85, 86, 87, + 88, 89, 90, 91, + 92, 93, 94, 95, + + 96, 97, 98, 99, + 100, 101, 102, 103, + 104, 105, 106, 107, + 108, 109, 110, 111, + + 112, 113, 114, 115, + 116, 117, 118, 119, + 120, 121, 122, 123, + 124, 125, 126, 127, + + 128, 129, 130, 131, + 132, 133, 134, 135, + 136, 137, 138, 139, + 140, 141, 142, 143, + + 144, 145, 146, 147, + 148, 149, 150, 151, + 152, 153, 154, 155, + 156, 157, 158, 159, + + 160, 161, 162, 163, + 164, 165, 166, 167, + 168, 169, 170, 171, + 172, 173, 174, 175, + + 176, 177, 178, 179, + 180, 181, 182, 183, + 184, 185, 186, 187, + 188, 189, 190, 191, + }, 4, 4, 4, 2) + }, + }, + { + dim: 0, low: 0, high: 4, step: 2, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 4*4*4*4, 1, ml.DTypeF32).Reshape(ctx, 4, 4, 4, 4) + }, + want: func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 0, 2, + 4, 6, + 8, 10, + 12, 14, + + 16, 18, + 20, 22, + 24, 26, + 28, 30, + + 32, 34, + 36, 38, + 40, 42, + 44, 46, + + 48, 50, + 52, 54, + 56, 58, + 60, 62, + + 64, 66, + 68, 70, + 72, 74, + 76, 78, + + 80, 82, + 84, 86, + 88, 90, + 92, 94, + + 96, 98, + 100, 102, + 104, 106, + 108, 110, + + 112, 114, + 116, 118, + 120, 122, + 124, 126, + + 128, 130, + 132, 134, + 136, 138, + 140, 142, + + 144, 146, + 148, 150, + 152, 154, + 156, 158, + + 160, 162, + 164, 166, + 168, 170, + 172, 174, + + 176, 178, + 180, 182, + 184, 186, + 188, 190, + + 192, 194, + 196, 198, + 200, 202, + 204, 206, + + 208, 210, + 212, 214, + 216, 218, + 220, 222, + + 224, 226, + 228, 230, + 232, 234, + 236, 238, + + 240, 242, + 244, 246, + 248, 250, + 252, 254, + }, 2, 4, 4, 4) + }, + }, + { + dim: 1, low: 0, high: 4, step: 2, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 4*4*4*4, 1, ml.DTypeF32).Reshape(ctx, 4, 4, 4, 4) + }, + want: func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 0, 1, 2, 3, + 8, 9, 10, 11, + + 16, 17, 18, 19, + 24, 25, 26, 27, + + 32, 33, 34, 35, + 40, 41, 42, 43, + + 48, 49, 50, 51, + 56, 57, 58, 59, + + 64, 65, 66, 67, + 72, 73, 74, 75, + + 80, 81, 82, 83, + 88, 89, 90, 91, + + 96, 97, 98, 99, + 104, 105, 106, 107, + + 112, 113, 114, 115, + 120, 121, 122, 123, + + 128, 129, 130, 131, + 136, 137, 138, 139, + + 144, 145, 146, 147, + 152, 153, 154, 155, + + 160, 161, 162, 163, + 168, 169, 170, 171, + + 176, 177, 178, 179, + 184, 185, 186, 187, + + 192, 193, 194, 195, + 200, 201, 202, 203, + + 208, 209, 210, 211, + 216, 217, 218, 219, + + 224, 225, 226, 227, + 232, 233, 234, 235, + + 240, 241, 242, 243, + 248, 249, 250, 251, + }, 4, 2, 4, 4) + }, + }, + { + dim: 2, low: 0, high: 4, step: 2, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 4*4*4*4, 1, 
ml.DTypeF32).Reshape(ctx, 4, 4, 4, 4) + }, + want: func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15, + + 32, 33, 34, 35, + 36, 37, 38, 39, + 40, 41, 42, 43, + 44, 45, 46, 47, + + 64, 65, 66, 67, + 68, 69, 70, 71, + 72, 73, 74, 75, + 76, 77, 78, 79, + + 96, 97, 98, 99, + 100, 101, 102, 103, + 104, 105, 106, 107, + 108, 109, 110, 111, + + 128, 129, 130, 131, + 132, 133, 134, 135, + 136, 137, 138, 139, + 140, 141, 142, 143, + + 160, 161, 162, 163, + 164, 165, 166, 167, + 168, 169, 170, 171, + 172, 173, 174, 175, + + 192, 193, 194, 195, + 196, 197, 198, 199, + 200, 201, 202, 203, + 204, 205, 206, 207, + + 224, 225, 226, 227, + 228, 229, 230, 231, + 232, 233, 234, 235, + 236, 237, 238, 239, + }, 4, 4, 2, 4) + }, + }, + { + dim: 3, low: 0, high: 4, step: 2, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 4*4*4*4, 1, ml.DTypeF32).Reshape(ctx, 4, 4, 4, 4) + }, + want: func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15, + + 16, 17, 18, 19, + 20, 21, 22, 23, + 24, 25, 26, 27, + 28, 29, 30, 31, + + 32, 33, 34, 35, + 36, 37, 38, 39, + 40, 41, 42, 43, + 44, 45, 46, 47, + + 48, 49, 50, 51, + 52, 53, 54, 55, + 56, 57, 58, 59, + 60, 61, 62, 63, + + 128, 129, 130, 131, + 132, 133, 134, 135, + 136, 137, 138, 139, + 140, 141, 142, 143, + + 144, 145, 146, 147, + 148, 149, 150, 151, + 152, 153, 154, 155, + 156, 157, 158, 159, + + 160, 161, 162, 163, + 164, 165, 166, 167, + 168, 169, 170, 171, + 172, 173, 174, 175, + + 176, 177, 178, 179, + 180, 181, 182, 183, + 184, 185, 186, 187, + 188, 189, 190, 191, + }, 4, 4, 4, 2) + }, + }, + } + + for _, tt := range cases { + name := fmt.Sprintf("dim=%d,low=%d,high=%d,step=%d", tt.dim, tt.low, tt.high, tt.step) + t.Run(name, func(t *testing.T) { + ctx := setup(t) + got := tt.input(ctx).Slice(ctx, tt.dim, tt.low, tt.high, tt.step) + got = got.Contiguous(ctx) + if diff := cmp.Diff(tt.want(ctx), got, EquateTensors(ctx)); diff != "" { + t.Errorf("Slice() result mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestSplitSections(t *testing.T) { + cases := []struct { + dim int + sections []int + input func(ml.Context) ml.Tensor + want []func(ml.Context) ml.Tensor + }{ + { + dim: 0, sections: []int{1, 1, 1}, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4) + }, + want: []func(ml.Context) ml.Tensor{ + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{0, 3, 6, 9}, 1, 4) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{1, 4, 7, 10}, 1, 4) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{2, 5, 8, 11}, 1, 4) + }, + }, + }, + { + dim: 1, sections: []int{1, 3}, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4) + }, + want: []func(ml.Context) ml.Tensor{ + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{0, 1, 2}, 3, 1) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 3, 4, 5, + 6, 7, 8, + 9, 10, 11, + }, 3, 3) + }, + }, + }, + { + dim: 0, sections: []int{2, 2}, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 4, 3) + }, + want: []func(ml.Context) ml.Tensor{ + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 0, 1, + 4, 5, + 8, 9, + }, 2, 3) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 2, 
3, + 6, 7, + 10, 11, + }, 2, 3) + }, + }, + }, + { + dim: 1, sections: []int{1, 2}, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 4, 3) + }, + want: []func(ml.Context) ml.Tensor{ + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{0, 1, 2, 3}, 4, 1) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 4, 5, 6, 7, + 8, 9, 10, 11, + }, 4, 2) + }, + }, + }, + } + + for _, tt := range cases { + t.Run(fmt.Sprintf("sections=%v", tt.sections), func(t *testing.T) { + ctx := setup(t) + got := tt.input(ctx).ChunkSections(ctx, tt.dim, tt.sections...) + + for i := range got { + got[i] = got[i].Contiguous(ctx) + } + + ctx.Forward(got...).Compute(got...) + for i, want := range tt.want { + if diff := cmp.Diff(want(ctx), got[i], EquateTensors(ctx)); diff != "" { + t.Errorf("SplitSections() section %d mismatch (-want +got):\n%s", i, diff) + } + } + }) + } +} + +func TestChunk(t *testing.T) { + cases := []struct { + dim int + chunk int + input func(ml.Context) ml.Tensor + want []func(ml.Context) ml.Tensor + }{ + { + dim: 0, chunk: 1, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4) + }, + want: []func(ml.Context) ml.Tensor{ + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{0, 3, 6, 9}, 1, 4) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{1, 4, 7, 10}, 1, 4) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{2, 5, 8, 11}, 1, 4) + }, + }, + }, + { + dim: 1, chunk: 2, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4) + }, + want: []func(ml.Context) ml.Tensor{ + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 0, 1, 2, + 3, 4, 5, + }, 3, 2) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 6, 7, 8, + 9, 10, 11, + }, 3, 2) + }, + }, + }, + { + dim: 0, chunk: 2, + input: func(ctx ml.Context) ml.Tensor { + return ctx.Arange(0, 12, 1, ml.DTypeF32).Reshape(ctx, 3, 4) + }, + want: []func(ml.Context) ml.Tensor{ + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 0, 1, + 3, 4, + 6, 7, + 9, 10, + }, 2, 4) + }, + func(ctx ml.Context) ml.Tensor { + return ctx.FromFloats([]float32{ + 2, + 5, + 8, + 11, + }, 1, 4) + }, + }, + }, + } + + for _, tt := range cases { + t.Run(fmt.Sprintf("dim=%d,chunk=%d", tt.dim, tt.chunk), func(t *testing.T) { + ctx := setup(t) + got := tt.input(ctx).Chunk(ctx, tt.dim, tt.chunk) + + for i := range got { + got[i] = got[i].Contiguous(ctx) + } + + ctx.Forward(got...).Compute(got...) + for i, want := range tt.want { + if diff := cmp.Diff(want(ctx), got[i], EquateTensors(ctx)); diff != "" { + t.Errorf("Split() section %d mismatch (-want +got):\n%s", i, diff) + } + } + }) + } +}
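
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the new Slice, Chunk, and ChunkSections methods added to the ml.Tensor interface might be called from model code. The package name, the splitQKV and sliceAndChunk helpers, and the qkv/headDim/numHeads parameters are assumptions made for this example; only the tensor methods and their semantics come from the diff above.

package model

import "github.com/ollama/ollama/ml"

// splitQKV is a hypothetical helper: it splits a fused QKV tensor whose dim 0
// stores Q, K, and V back to back into three equal views using ChunkSections.
func splitQKV(ctx ml.Context, qkv ml.Tensor, headDim, numHeads int) (q, k, v ml.Tensor) {
	n := headDim * numHeads
	// The sections must sum to qkv.Dim(0); otherwise ChunkSections panics.
	parts := qkv.ChunkSections(ctx, 0, n, n, n)
	return parts[0], parts[1], parts[2]
}

// sliceAndChunk shows the other two helpers on an arbitrary tensor t.
func sliceAndChunk(ctx ml.Context, t ml.Tensor) {
	// Keep every second element of dim 0. For dim 0 with step>1 the result
	// is materialized as a contiguous copy rather than a view.
	_ = t.Slice(ctx, 0, 0, t.Dim(0), 2)

	// Split dim 1 into chunks of 32; the final chunk is smaller when
	// t.Dim(1) is not a multiple of 32.
	_ = t.Chunk(ctx, 1, 32)
}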