use fast attention

Michael Yang 2025-04-01 13:57:51 -07:00
parent 8aec3e1374
commit 1ca8eb5c05


@@ -36,18 +36,10 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
 	key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
 	value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
-	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 	query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
 	key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
-	scores := key.Mulmat(ctx, query)
-	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(opts.headDim)))
-	scores = scores.Softmax(ctx)
-	attention := value.Mulmat(ctx, scores).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
 	return sa.Output.Forward(ctx, attention)
 }
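
Note (not part of the commit): the single nn.Attention call replaces the hand-written scaled dot-product attention that the removed lines spelled out, and presumably dispatches to the backend's fused "fast" attention kernel while handling the per-head layout internally, which is why the explicit Permute/Contiguous calls on query, key and value are dropped. A minimal sketch of the manual path being replaced, reusing only the tensor operations visible in the old code; the helper name is illustrative, not taken from the repository:

// Sketch only: the manual scaled-dot-product path that nn.Attention now
// stands in for, written with the same ops as the removed lines above.
func manualVisionAttention(ctx ml.Context, query, key, value ml.Tensor, headDim int) ml.Tensor {
	scores := key.Mulmat(ctx, query)                            // raw query-key scores per head
	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim))) // scale by 1/sqrt(headDim)
	scores = scores.Softmax(ctx)                                // normalize into attention weights
	// apply the weights to the values and restore the layout expected by the output projection
	return value.Mulmat(ctx, scores).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}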
@@ -166,6 +158,8 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 	positionEmbedding := m.positionalEmbedding(ctx, positionIDs)
 	cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
+	cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
+	sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
 	for _, layer := range m.Layers {
 		hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
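
Note (not part of the commit): the second hunk appears to follow from the first. With the explicit permutes gone, rotary embeddings are applied while query and key are still shaped (headDim, numHeads, seqLen, batch), so cos and sin are given a singleton middle axis to broadcast across the heads dimension. A hedged shape sketch; the concrete dimension names are assumptions, not taken from the commit:

// assumed: cos and sin start as (headDim, seqLen)
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1)) // -> (headDim, 1, seqLen), broadcasts over numHeads
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1)) // same reshape for sin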