From 05ccb17c6e101637ef0fa8fbd88952da2e7e2ca4 Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Tue, 19 Aug 2025 09:52:18 -0700
Subject: [PATCH] kvcache: Use Cast instead of Copy for flash attention masks

Flash attention kernels require the mask of the KV cache to be F16
rather than F32. We can use the GGML operation ggml_cast to do this
rather than doing it ourselves, which allows reuse of a preallocated
buffer in the graph rather than allocating a new one for each batch.
This improves token generation performance with flash attention by
10-30% (with gpt-oss). This also makes performance with flash attention
better than without it, as expected.
---
 kvcache/causal.go       |  4 +---
 ml/backend.go           |  1 +
 ml/backend/ggml/ggml.go | 44 +++++++++++++++++++++++++++++-----------------
 3 files changed, 29 insertions(+), 20 deletions(-)

diff --git a/kvcache/causal.go b/kvcache/causal.go
index 96d8067eb4..31f5523310 100644
--- a/kvcache/causal.go
+++ b/kvcache/causal.go
@@ -378,9 +378,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
 	maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize)
 
 	if c.config.MaskDType != ml.DTypeF32 {
-		out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
-		ctx.Forward(maskTensor.Copy(ctx, out))
-		maskTensor = out
+		maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
 	}
 
 	return maskTensor
diff --git a/ml/backend.go b/ml/backend.go
index 638a05d144..705724821b 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -396,6 +396,7 @@ type Tensor interface {
 
 	Shape() []int
 	DType() DType
+	Cast(ctx Context, dtype DType) Tensor
 	Bytes() []byte
 	Floats() []float32
 
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index f8121cc530..13d898aadd 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -843,23 +843,7 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
 		panic("set Input or Layer before creating tensors")
 	}
 
-	var cdtype uint32
-	switch dtype {
-	case ml.DTypeF32:
-		cdtype = C.GGML_TYPE_F32
-	case ml.DTypeF16:
-		cdtype = C.GGML_TYPE_F16
-	case ml.DTypeQ80:
-		cdtype = C.GGML_TYPE_Q8_0
-	case ml.DTypeQ40:
-		cdtype = C.GGML_TYPE_Q4_0
-	case ml.DTypeI32:
-		cdtype = C.GGML_TYPE_I32
-	case ml.DTypeMXFP4:
-		cdtype = C.GGML_TYPE_MXFP4
-	default:
-		panic("unsupported dtype")
-	}
+	cdtype := ggmlDType(dtype)
 
 	if len(shape) < 1 || shape[0] == 0 {
 		var shape C.int64_t = 0
@@ -1056,6 +1040,32 @@ func (t *Tensor) DType() ml.DType {
 	}
 }
 
+func ggmlDType(dtype ml.DType) uint32 {
+	switch dtype {
+	case ml.DTypeF32:
+		return C.GGML_TYPE_F32
+	case ml.DTypeF16:
+		return C.GGML_TYPE_F16
+	case ml.DTypeQ80:
+		return C.GGML_TYPE_Q8_0
+	case ml.DTypeQ40:
+		return C.GGML_TYPE_Q4_0
+	case ml.DTypeI32:
+		return C.GGML_TYPE_I32
+	case ml.DTypeMXFP4:
+		return C.GGML_TYPE_MXFP4
+	default:
+		panic("unsupported dtype")
+	}
+}
+
+func (t *Tensor) Cast(ctx ml.Context, dtype ml.DType) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_cast(ctx.(*Context).ctx, t.t, ggmlDType(dtype)),
+	}
+}
+
 func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
 	return &Tensor{
 		b: t.b,
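
For illustration only (not part of the patch): a minimal Go sketch of the pattern the kvcache/causal.go hunk switches to. The helper name castMask, the package name, and the import path are assumptions; Cast, DType, Copy, Empty, Forward, and the DType constants come from the interfaces shown above.

// Illustrative sketch only; castMask and the import path are assumptions,
// not part of this patch.
package example

import "github.com/ollama/ollama/ml"

// castMask mirrors the buildMask change: rather than allocating an empty
// tensor and scheduling a Copy into it on every batch, it emits a single
// Cast (ggml_cast) node whose output buffer the graph allocator can
// preallocate and reuse across batches.
func castMask(ctx ml.Context, mask ml.Tensor, want ml.DType) ml.Tensor {
	if mask.DType() == want {
		return mask // already the dtype the flash attention kernel expects
	}

	// Old approach (removed above): allocate a destination and copy per batch.
	//   out := ctx.Input().Empty(want, mask.Shape()...)
	//   ctx.Forward(mask.Copy(ctx, out))
	//   return out

	// New approach: a single cast node in the graph.
	return mask.Cast(ctx, want)
}

In buildMask this conversion is guarded by c.config.MaskDType != ml.DTypeF32, since an F32 mask needs no conversion.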