mirror of https://github.com/ollama/ollama.git
commit 1ca8eb5c05
parent 8aec3e1374

    use fast attention
```diff
@@ -36,18 +36,10 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
 	key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
 	value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
 
-	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
 	query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
 	key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
 
-	scores := key.Mulmat(ctx, query)
-	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(opts.headDim)))
-	scores = scores.Softmax(ctx)
-
-	attention := value.Mulmat(ctx, scores).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
 	return sa.Output.Forward(ctx, attention)
 }
```
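For reference, the deleted lines computed attention by hand: a Kᵀ·Q matmul, a 1/√headDim scale, a softmax, then a matmul against the values, with Permute/Contiguous calls to put each operand into the layout those matmuls expect. nn.Attention takes the unpermuted query, key, and value plus the same 1/√headDim scale and a nil mask; going by the commit title, the point is presumably that the backend can then choose a fused scaled-dot-product ("fast attention") kernel where one is available instead of materialising the score matrix op by op. The sketch below restates the removed math over plain float64 slices for a single head; it is an illustrative stand-in, not ollama's ml.Tensor code.

```go
package main

import (
	"fmt"
	"math"
)

// sdpa computes softmax(Q·Kᵀ / sqrt(headDim)) · V for a single head.
// q, k, v are [seqLen][headDim] slices. This mirrors the Mulmat/Scale/Softmax
// pipeline the commit deletes; it is an illustrative sketch, not ollama code.
func sdpa(q, k, v [][]float64, headDim int) [][]float64 {
	n := len(q)
	scale := 1.0 / math.Sqrt(float64(headDim))

	out := make([][]float64, n)
	for i := 0; i < n; i++ {
		// scores[j] = <q_i, k_j> * scale  (the Mulmat + Scale steps)
		scores := make([]float64, n)
		for j := 0; j < n; j++ {
			var dot float64
			for d := 0; d < headDim; d++ {
				dot += q[i][d] * k[j][d]
			}
			scores[j] = dot * scale
		}

		// numerically stable softmax over the row (the Softmax step)
		maxScore := math.Inf(-1)
		for _, s := range scores {
			maxScore = math.Max(maxScore, s)
		}
		var sum float64
		for j, s := range scores {
			scores[j] = math.Exp(s - maxScore)
			sum += scores[j]
		}
		for j := range scores {
			scores[j] /= sum
		}

		// weighted sum of value rows (the final Mulmat with V)
		out[i] = make([]float64, headDim)
		for j := 0; j < n; j++ {
			for d := 0; d < headDim; d++ {
				out[i][d] += scores[j] * v[j][d]
			}
		}
	}
	return out
}

func main() {
	q := [][]float64{{1, 0}, {0, 1}}
	k := [][]float64{{1, 0}, {0, 1}}
	v := [][]float64{{1, 2}, {3, 4}}
	fmt.Println(sdpa(q, k, v, 2)) // each row is a softmax-weighted mix of the rows of v
}
```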
```diff
@@ -166,6 +158,8 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 
 	positionEmbedding := m.positionalEmbedding(ctx, positionIDs)
 	cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
+	cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
+	sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
 
 	for _, layer := range m.Layers {
 		hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
```
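The two added Reshape calls give the cos/sin tables a length-1 middle axis. A plausible reading, inferred from the shapes rather than stated in the commit: with the head-major Permute gone from the self-attention layer, query and key keep a (headDim, numHeads, seqLen, ...) layout when RoPE is applied, so a table shaped (headDim, seqLen) has to become (headDim, 1, seqLen) for the same per-position rotation values to broadcast across every head. The toy sketch below shows that kind of singleton-axis broadcast with plain slices; broadcastMul is a hypothetical helper, and the real broadcasting happens inside the ml.Tensor operations.

```go
package main

import "fmt"

// broadcastMul multiplies x, laid out here as [dim][heads][seq], element-wise
// by a table laid out as [dim][1][seq]: the singleton heads axis is broadcast,
// so every head sees the same per-position value. Hypothetical helper for
// illustration only.
func broadcastMul(x [][][]float64, table [][][]float64) [][][]float64 {
	out := make([][][]float64, len(x))
	for d := range x {
		out[d] = make([][]float64, len(x[d]))
		for h := range x[d] {
			out[d][h] = make([]float64, len(x[d][h]))
			for s := range x[d][h] {
				out[d][h][s] = x[d][h][s] * table[d][0][s] // heads axis broadcast
			}
		}
	}
	return out
}

func main() {
	x := [][][]float64{{{1, 2}, {3, 4}}} // dim=1, heads=2, seq=2
	table := [][][]float64{{{10, 100}}}  // dim=1, 1, seq=2
	fmt.Println(broadcastMul(x, table))  // [[[10 200] [30 400]]]
}
```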