diff --git a/ml/backend.go b/ml/backend.go index 41679f3b3..acfb67637 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -66,6 +66,7 @@ type Tensor interface { Add(ctx Context, t2 Tensor) Tensor Mul(ctx Context, t2 Tensor) Tensor Mulmat(ctx Context, t2 Tensor) Tensor + MulmatFullPrec(ctx Context, t2 Tensor) Tensor Softmax(ctx Context) Tensor LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index d039a3ea3..d1b2d646e 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -421,6 +421,15 @@ func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor { } } +func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t) + C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32) + + return &Tensor{ + t: mul, + } +} + func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor { tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) if b != nil { diff --git a/model/llama/model.go b/model/llama/model.go index b151ea5d2..294661740 100644 --- a/model/llama/model.go +++ b/model/llama/model.go @@ -80,7 +80,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - kq := k.Mulmat(ctx, q) + kq := k.MulmatFullPrec(ctx, q) kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim))) kq = kq.Softmax(ctx) diff --git a/model/mllama/model_text.go b/model/mllama/model_text.go index 51b4fe918..2b05a60ea 100644 --- a/model/mllama/model_text.go +++ b/model/mllama/model_text.go @@ -37,7 +37,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - scores := key.Mulmat(ctx, query) + scores := key.MulmatFullPrec(ctx, query) scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim))) if mask != nil {