backend: API to support full precision matmul

Most tensor backends try to optimize performance by using a lower
precision for matmuls. However, some operations (such as the attention
kq score matmul) on some models are sensitive to this and require full
precision.
Authored and committed by Jesse Gross on 2025-02-13 10:01:14 -08:00
parent 4d4463b2bd
commit d773b7d671
4 changed files with 12 additions and 2 deletions
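
The interface-level change itself is not visible in the hunks below, but based on the new MulmatFullPrec call sites it plausibly looks like the following sketch. The Context placeholder and the exact method set of the tensor interface are assumptions for illustration, not the commit's actual declarations.

package ml

// Context stands in for the existing graph-building context type;
// its real definition is not part of this sketch.
type Context interface{}

// Tensor sketches only the slice of the backend tensor interface that
// this commit touches.
type Tensor interface {
	// Mulmat performs matrix multiplication at whatever precision the
	// backend considers fastest (often a reduced precision).
	Mulmat(ctx Context, t2 Tensor) Tensor

	// MulmatFullPrec computes the same product but asks the backend to
	// keep the computation in full (float32) precision, for
	// precision-sensitive operations such as the attention kq matmul.
	MulmatFullPrec(ctx Context, t2 Tensor) Tensor
}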

@@ -80,7 +80,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-	kq := k.Mulmat(ctx, q)
+	kq := k.MulmatFullPrec(ctx, q)
 	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
 	kq = kq.Softmax(ctx)

@@ -37,7 +37,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas
 	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-	scores := key.Mulmat(ctx, query)
+	scores := key.MulmatFullPrec(ctx, query)
 	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
 	if mask != nil {
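
On the backend side, one way a GGML-based backend could honor the full-precision request is to build the usual mul_mat node and then raise its compute precision with ggml_mul_mat_set_prec (an existing GGML API for exactly this). The sketch below is not the commit's actual backend code: the package name and the Tensor/Context wrapper structs are assumptions, and the ml-interface plumbing is elided.

package ggml

// #include "ggml.h"
import "C"

// Tensor and Context are assumed thin wrappers around the GGML C structs;
// the real backend types carry more state than shown here.
type Tensor struct {
	t *C.struct_ggml_tensor
}

type Context struct {
	ctx *C.struct_ggml_context
}

// MulmatFullPrec builds the same mul_mat node that a plain Mulmat would,
// then marks it so GGML computes it in float32 rather than the backend's
// default (possibly reduced) precision.
func (t *Tensor) MulmatFullPrec(ctx *Context, t2 *Tensor) *Tensor {
	mm := C.ggml_mul_mat(ctx.ctx, t.t, t2.t)
	C.ggml_mul_mat_set_prec(mm, C.GGML_PREC_F32)
	return &Tensor{t: mm}
}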