From 55e5776c44659a153fa1b9a0316ec1b6a834a4b8 Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Thu, 27 Feb 2025 14:52:39 -0800
Subject: [PATCH] ggml-backend: Store parent backend as part of tensor

It can be important for a tensor to know what backend it came from -
for example, to know if flash attention is enabled.
---
 ml/backend/ggml/ggml.go | 42 ++++++++++++++++++++++++++++++++++-------
 1 file changed, 35 insertions(+), 7 deletions(-)

diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index bddaad463..24943111c 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -219,7 +219,7 @@ func (b *Backend) Get(name string) ml.Tensor {
 
 	for _, c := range append(b.gpus, b.cpus...) {
 		if t := C.ggml_get_tensor(c.ctx, cname); t != nil {
-			return &Tensor{t: t}
+			return &Tensor{b: b, t: t}
 		}
 	}
 
@@ -330,7 +330,7 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
 	b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t))
 	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
 	C.ggml_set_zero(t)
-	return &Tensor{t: t}
+	return &Tensor{b: c.b, t: t}
 }
 
 func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
@@ -339,7 +339,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
 	if n == 0 {
 		var shape C.int64_t = 0
 		t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
-		return &Tensor{t: t}, nil
+		return &Tensor{b: ctx.b, t: t}, nil
 	}
 
 	for _, v := range shape {
@@ -354,7 +354,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
 	b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
 	C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
 	C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
-	return &Tensor{t: t}, nil
+	return &Tensor{b: ctx.b, t: t}, nil
 }
 
 func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
@@ -372,6 +372,7 @@ func (c *Context) Close() {
 }
 
 type Tensor struct {
+	b    *Backend
 	t    *C.struct_ggml_tensor
 	sync func()
 }
@@ -438,6 +439,7 @@ func (t *Tensor) DType() ml.DType {
 
 func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return &Tensor{
+		b: t.b,
 		t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
 	}
 }
@@ -452,24 +454,28 @@ func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
 
 func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
 	return &Tensor{
+		b: t.b,
 		t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
 	}
 }
 
 func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
 	return &Tensor{
+		b: t.b,
 		t: C.ggml_cont(ctx.(*Context).ctx, t.t),
 	}
 }
 
 func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return &Tensor{
+		b: t.b,
 		t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
 	}
 }
 
 func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return &Tensor{
+		b: t.b,
 		t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
 	}
 }
@@ -479,12 +485,13 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
 
 	return &Tensor{
+		b: t.b,
 		t: mul,
 	}
 }
 
 func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
-	tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
+	tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
 	if b != nil {
 		tt = tt.Add(ctx, b)
 	}
@@ -493,7 +500,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
 }
 
 func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
-	return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
+	return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
 }
 
 func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
@@ -502,6 +509,7 @@ func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
 	}
 
 	return &Tensor{
+		b: t.b,
 		t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
 	}
 }
@@ -512,18 +520,21 @@ func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
 	}
 
 	return &Tensor{
+		b: t.b,
 		t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
 	}
 }
 
 func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return &Tensor{
+		b: t.b,
 		t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
 	}
 }
 
 func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	return &Tensor{
+		b: t.b,
 		t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
 	}
 }
@@ -532,18 +543,22 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
 	switch len(shape) {
 	case 1:
 		return &Tensor{
+			b: t.b,
 			t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
 		}
 	case 2:
 		return &Tensor{
+			b: t.b,
 			t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
 		}
 	case 3:
 		return &Tensor{
+			b: t.b,
 			t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
 		}
 	case 4:
 		return &Tensor{
+			b: t.b,
 			t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
 		}
 	default:
@@ -553,18 +568,21 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
 
 func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
 	return &Tensor{
+		b: t.b,
 		t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
 	}
 }
 
 func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
 	return &Tensor{
+		b: t.b,
 		t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
 	}
 }
 
 func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
 	return &Tensor{
+		b: t.b,
 		t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
 	}
 }
@@ -575,6 +593,7 @@ func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
 	}
 
 	return &Tensor{
+		b: t.b,
 		t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
 	}
 }
@@ -583,10 +602,12 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	switch len(shape) {
 	case 1:
 		return &Tensor{
+			b: t.b,
 			t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
 		}
 	case 3:
 		return &Tensor{
+			b: t.b,
 			t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
 				C.int64_t(shape[0]), C.int64_t(shape[2]),
 				C.size_t(shape[1]),
@@ -594,6 +615,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 		}
 	case 5:
 		return &Tensor{
+			b: t.b,
 			t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
 				C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
 				C.size_t(shape[1]), C.size_t(shape[3]),
@@ -601,6 +623,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 		}
 	case 7:
 		return &Tensor{
+			b: t.b,
 			t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
 				C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
 				C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
@@ -617,7 +640,7 @@ const (
 
 func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
 	if ropeFactors == nil {
-		ropeFactors = &Tensor{}
+		ropeFactors = &Tensor{b: t.b}
 	}
 
 	dequant := t.t
@@ -626,6 +649,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
 	}
 
 	return &Tensor{
+		b: t.b,
 		t: C.ggml_rope_ext(
 			ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
 			C.int(ropeDim),
@@ -643,18 +667,21 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
 
 func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
 	return &Tensor{
+		b: t.b,
 		t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
 	}
 }
 
 func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
 	return &Tensor{
+		b: t.b,
 		t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
 	}
 }
 
 func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
 	return &Tensor{
+		b: t.b,
 		t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
 	}
 }
@@ -670,6 +697,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
 	kq := key.MulmatFullPrec(ctx, query)
 
 	kq = &Tensor{
+		b: t.b,
 		t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
 	}
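
Editor's note (not part of the patch): the change is mechanical, so the pattern is easy to miss. Below is a minimal, self-contained Go sketch of the idea under simplified assumptions; the Backend, Tensor, and flashAttention names here are hypothetical stand-ins for illustration, not the real ml/backend/ggml API. A backend constructor stores the creating backend in each tensor, every derived tensor copies b: t.b, and an operation can then consult its own backend's configuration instead of having callers thread it through.

	package main

	import "fmt"

	// Backend stands in for the ggml Backend; the flashAttention flag is a
	// hypothetical example of per-backend state an operation might need.
	type Backend struct {
		flashAttention bool
	}

	// Tensor mirrors the patched struct: every tensor carries a pointer to
	// the backend that created it.
	type Tensor struct {
		b    *Backend
		data []float32
	}

	// NewTensor shows the constructor side of the pattern, matching
	// `return &Tensor{b: b, t: t}` in the patch.
	func (b *Backend) NewTensor(data []float32) *Tensor {
		return &Tensor{b: b, data: data}
	}

	// Scale shows the propagation rule: a derived tensor copies t.b, the
	// way every op in the patch adds `b: t.b` to its result.
	func (t *Tensor) Scale(s float32) *Tensor {
		out := make([]float32, len(t.data))
		for i, v := range t.data {
			out[i] = v * s
		}
		return &Tensor{b: t.b, data: out}
	}

	// AttentionPath can now ask the tensor's own backend about its
	// configuration - the motivation given in the commit message.
	func (t *Tensor) AttentionPath() string {
		if t.b.flashAttention {
			return "flash attention path"
		}
		return "standard attention path"
	}

	func main() {
		b := &Backend{flashAttention: true}
		q := b.NewTensor([]float32{1, 2, 3}).Scale(0.5)
		fmt.Println(q.AttentionPath()) // derived tensor still knows its backend
	}

Because the parent pointer rides along through every derived tensor, a leaf operation such as ScaledDotProductAttention can choose a code path from backend state without any new parameters in the ml.Tensor interface.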