mirror of
https://github.com/ollama/ollama.git
synced 2025-07-22 02:53:53 +02:00
use fast attention
This commit is contained in:
@ -36,18 +36,10 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
|
|||||||
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
|
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
|
||||||
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
|
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
|
||||||
|
|
||||||
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
||||||
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
||||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
|
||||||
|
|
||||||
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
|
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
|
||||||
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
|
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
|
||||||
|
|
||||||
scores := key.Mulmat(ctx, query)
|
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
|
||||||
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(opts.headDim)))
|
|
||||||
scores = scores.Softmax(ctx)
|
|
||||||
|
|
||||||
attention := value.Mulmat(ctx, scores).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
||||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||||
return sa.Output.Forward(ctx, attention)
|
return sa.Output.Forward(ctx, attention)
|
||||||
}
|
}
|
||||||
@ -166,6 +158,8 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
|||||||
|
|
||||||
positionEmbedding := m.positionalEmbedding(ctx, positionIDs)
|
positionEmbedding := m.positionalEmbedding(ctx, positionIDs)
|
||||||
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
|
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
|
||||||
|
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
|
||||||
|
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
|
||||||
|
|
||||||
for _, layer := range m.Layers {
|
for _, layer := range m.Layers {
|
||||||
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
|
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
|
||||||
|
Reference in New Issue
Block a user