kvcache: Use Cast instead of Copy for flash attention masks
Flash attention kernels require the KV cache mask to be F16 rather than F32. We can use the GGML operation ggml_cast to do this conversion rather than doing it ourselves, which allows reuse of a preallocated buffer in the graph rather than allocating a new one for each batch. This improves token generation performance with flash attention by 10-30% (with gpt-oss) and, as expected, makes flash attention faster than running without it.
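In buildMask, the practical difference looks roughly like this (a condensed sketch of the two code paths taken from the hunks below, not a standalone program): the copy path materializes a second input tensor plus an explicit copy node, while the cast path emits a single ggml_cast node whose output can live in the graph's preallocated buffer.

	// Old path: allocate a destination tensor in the target dtype, then copy.
	// Every batch allocates a fresh input buffer for "out".
	out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
	ctx.Forward(maskTensor.Copy(ctx, out))
	maskTensor = out

	// New path: one cast node; the backend places its output in the graph's
	// preallocated buffer instead of a new per-batch tensor.
	maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)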
@@ -378,9 +378,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
 	maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize)
 
 	if c.config.MaskDType != ml.DTypeF32 {
-		out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
-		ctx.Forward(maskTensor.Copy(ctx, out))
-		maskTensor = out
+		maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
 	}
 
 	return maskTensor
@@ -396,6 +396,7 @@ type Tensor interface {
 
 	Shape() []int
 	DType() DType
+	Cast(ctx Context, dtype DType) Tensor
 
 	Bytes() []byte
 	Floats() []float32
@@ -843,23 +843,7 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
 		panic("set Input or Layer before creating tensors")
 	}
 
-	var cdtype uint32
-	switch dtype {
-	case ml.DTypeF32:
-		cdtype = C.GGML_TYPE_F32
-	case ml.DTypeF16:
-		cdtype = C.GGML_TYPE_F16
-	case ml.DTypeQ80:
-		cdtype = C.GGML_TYPE_Q8_0
-	case ml.DTypeQ40:
-		cdtype = C.GGML_TYPE_Q4_0
-	case ml.DTypeI32:
-		cdtype = C.GGML_TYPE_I32
-	case ml.DTypeMXFP4:
-		cdtype = C.GGML_TYPE_MXFP4
-	default:
-		panic("unsupported dtype")
-	}
+	cdtype := ggmlDType(dtype)
 
 	if len(shape) < 1 || shape[0] == 0 {
 		var shape C.int64_t = 0
@@ -1056,6 +1040,32 @@ func (t *Tensor) DType() ml.DType {
 	}
 }
 
+func ggmlDType(dtype ml.DType) uint32 {
+	switch dtype {
+	case ml.DTypeF32:
+		return C.GGML_TYPE_F32
+	case ml.DTypeF16:
+		return C.GGML_TYPE_F16
+	case ml.DTypeQ80:
+		return C.GGML_TYPE_Q8_0
+	case ml.DTypeQ40:
+		return C.GGML_TYPE_Q4_0
+	case ml.DTypeI32:
+		return C.GGML_TYPE_I32
+	case ml.DTypeMXFP4:
+		return C.GGML_TYPE_MXFP4
+	default:
+		panic("unsupported dtype")
+	}
+}
+
+func (t *Tensor) Cast(ctx ml.Context, dtype ml.DType) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_cast(ctx.(*Context).ctx, t.t, ggmlDType(dtype)),
+	}
+}
+
 func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
 	return &Tensor{
 		b: t.b,
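Because Cast is now part of the backend-agnostic Tensor interface, callers such as the kv cache never need to know which conversion kernel runs underneath. A hedged usage sketch (the surrounding context/compute plumbing is illustrative, not taken from this change):

	// Illustrative only: build an F32 input tensor, cast it to F16, and
	// schedule the cast node so it runs when the graph is computed.
	x := ctx.Input().FromFloatSlice([]float32{1, 2, 3}, 3)
	h := x.Cast(ctx, ml.DTypeF16) // adds a ggml_cast node to the graph
	ctx.Forward(h)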