backend: API to support full precision matmul

Most tensor backends try to optimize performance by using a lower
precision for matmuls. However, some operations (such as the KQ
attention-score matmul) in some models are sensitive to this and
require full precision.
Jesse Gross authored on 2025-02-13 10:01:14 -08:00; committed by Jesse Gross
commit d773b7d671
parent 4d4463b2bd
4 changed files with 12 additions and 2 deletions
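For orientation, here is a minimal sketch of how a model built on this interface can opt into the new call for just the attention-score product while leaving every other matmul on the default path. The attentionScores helper, its fullPrec parameter, and the import path are illustrative assumptions, not code from this commit:

package model

import (
	"math"

	"github.com/ollama/ollama/ml"
)

// attentionScores computes softmax(K^T·Q / sqrt(headDim)), optionally forcing
// float32 accumulation for the precision-sensitive KQ product.
func attentionScores(ctx ml.Context, k, q ml.Tensor, headDim int, fullPrec bool) ml.Tensor {
	var kq ml.Tensor
	if fullPrec {
		kq = k.MulmatFullPrec(ctx, q) // new interface method added by this commit
	} else {
		kq = k.Mulmat(ctx, q) // backend may use reduced precision here
	}

	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
	return kq.Softmax(ctx)
}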

@@ -66,6 +66,7 @@ type Tensor interface {
 	Add(ctx Context, t2 Tensor) Tensor
 	Mul(ctx Context, t2 Tensor) Tensor
 	Mulmat(ctx Context, t2 Tensor) Tensor
+	MulmatFullPrec(ctx Context, t2 Tensor) Tensor
 	Softmax(ctx Context) Tensor
 	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor

@@ -421,6 +421,15 @@ func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	}
 }
 
+func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
+	mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
+	C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
+
+	return &Tensor{
+		t: mul,
+	}
+}
+
 func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
 	tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
 	if b != nil {
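In the ggml backend the new method builds the mul_mat node exactly as Mulmat does and then marks it with ggml_mul_mat_set_prec(..., GGML_PREC_F32), which requests float32 accumulation for that single node; backends that honor the flag keep their faster reduced-precision path for every other matmul in the graph.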

@@ -80,7 +80,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
-	kq := k.Mulmat(ctx, q)
+	kq := k.MulmatFullPrec(ctx, q)
 	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
 	kq = kq.Softmax(ctx)

@@ -37,7 +37,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas
 	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
-	scores := key.Mulmat(ctx, query)
+	scores := key.MulmatFullPrec(ctx, query)
 	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
 	if mask != nil {