From 1ca8eb5c05e95f109e158868712ffbeeb84c5a48 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Tue, 1 Apr 2025 13:57:51 -0700
Subject: [PATCH] use fast attention

---
 model/models/mistral3/model_vision.go | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go
index 26120afcc..9ed51632c 100644
--- a/model/models/mistral3/model_vision.go
+++ b/model/models/mistral3/model_vision.go
@@ -36,18 +36,10 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
 	key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
 	value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
 
-	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
 	query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
 	key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
 
-	scores := key.Mulmat(ctx, query)
-	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(opts.headDim)))
-	scores = scores.Softmax(ctx)
-
-	attention := value.Mulmat(ctx, scores).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
 	return sa.Output.Forward(ctx, attention)
 }
@@ -166,6 +158,8 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 	positionEmbedding := m.positionalEmbedding(ctx, positionIDs)
 
 	cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
+	cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
+	sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
 
 	for _, layer := range m.Layers {
 		hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
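
Note: the following is a minimal illustrative sketch, not part of the patch. It places the deleted manual attention path next to the fused call that replaces it, using only the tensor operations visible in the diff. The package name, function names, and import paths are assumptions based on this repository's layout; the permute/Mulmat/Scale/Softmax sequence is the standard scaled dot-product attention that nn.Attention now computes in a single call.

// Sketch only: hypothetical helpers, not code from the patch.
package mistral3sketch

import (
	"math"

	"github.com/ollama/ollama/ml"    // assumed import path
	"github.com/ollama/ollama/ml/nn" // assumed import path
)

// manualAttention mirrors the lines removed by the patch: permute the
// heads into place, take QK^T, scale by 1/sqrt(headDim), softmax, then
// apply the result to the values and permute back.
func manualAttention(ctx ml.Context, query, key, value ml.Tensor, headDim int) ml.Tensor {
	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

	scores := key.Mulmat(ctx, query)
	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
	scores = scores.Softmax(ctx)

	return value.Mulmat(ctx, scores).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}

// fusedAttention mirrors the added line: nn.Attention takes the
// unpermuted query/key/value plus the scale, handles the permutes and
// softmax internally, and is passed nil here because the vision encoder
// uses no KV cache.
func fusedAttention(ctx ml.Context, query, key, value ml.Tensor, headDim int) ml.Tensor {
	return nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
}

Because the fused path applies rotary embeddings to the still-unpermuted query and key, the second hunk presumably reshapes cos and sin to add a broadcast dimension so they line up with the (headDim, numHeads, seq) layout.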