Mirror of https://github.com/ollama/ollama.git (synced 2025-03-19 14:21:57 +01:00)
Prior to performing attention, we need to permute query, key and value. Currently we call Contiguous after each of these permutations, which is correct but expensive. Avoiding the 3 calls to Contiguous increases performance by over 20%.

The permutations of query and key do not violate the contiguity rules for mulmat, so the Contiguous call can simply be removed.

Value requires a different permutation and does require Contiguous. However, we can use the copy into the cache as a way to perform this without further overhead.

To support this, and to avoid models seeing unexpected tensor shapes, we need tighter integration between attention, cache and backend. Future optimizations will also likely need this structure - for example, flash attention has special padding requirements in the cache, and other backends may have their own needs. This further encapsulates the operations that go into attention so that these and other optimizations can be handled transparently. Models that have special requirements for attention can still implement their own version of it.
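To make the layout change concrete, here is a minimal sketch of the permutation pattern described above, written against the same ml.Tensor API and assuming the package and imports of the file shown below. The helper name is illustrative, and the "previously" lines are reconstructed for illustration rather than quoted from the earlier code:

// permuteForAttention is an illustrative sketch, not part of the real file.
func permuteForAttention(ctx ml.Context, query, key, value ml.Tensor) (q, k, v ml.Tensor) {
	// Previously (reconstructed): each permutation was followed by Contiguous,
	// i.e. three full tensor copies per attention call:
	//   query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	//   key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	//   value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

	// Query and key can simply stay non-contiguous: their permuted layout is
	// still acceptable to mulmat.
	q = query.Permute(ctx, 0, 2, 1, 3)
	k = key.Permute(ctx, 0, 2, 1, 3)

	// Value genuinely needs a contiguous layout after its different
	// permutation. On the cached path the copy performed by cache.Put/Get
	// provides that layout without extra overhead; the explicit Contiguous
	// here is only what the cache-less fallback still pays for.
	v = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
	return q, k, v
}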
71 lines · 2.3 KiB · Go
package nn

import (
	"fmt"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
)

// Attention implements scaled dot-product attention for transformer models:
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
// Parameters:
//   - ctx: Context for tensor operations
//   - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
//   - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
//   - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
//   - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
//   - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
//
// Returns:
//
//	Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
	if key != nil && value != nil {
		if query.Dim(0) != key.Dim(0) {
			panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
		}

		if key.Dim(1) != value.Dim(1) {
			panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
		}

		if key.Dim(2) != value.Dim(2) {
			panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
		}

		if cache != nil {
			cache.Put(ctx, key, value)
		}
	} else if cache == nil {
		panic("key & value tensors must be provided if cache is nil")
	}

	var mask ml.Tensor
	if cache != nil {
		key, value, mask = cache.Get(ctx)
	}

	// Only use the fast SDPA implementation if we have a cache, since that's what
	// will do any expected backend-specific transformations for us
	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
	} else {
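		// Cache-less fallback: perform the permutations here. Query and key can
		// stay non-contiguous (mulmat accepts their permuted layout); value needs
		// an explicit Contiguous because no cache copy produces its layout.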
		query = query.Permute(ctx, 0, 2, 1, 3)
		key = key.Permute(ctx, 0, 2, 1, 3)
		value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

		kq := key.MulmatFullPrec(ctx, query)

		kq = kq.Scale(ctx, scale)
		if mask != nil {
			kq = kq.Add(ctx, mask)
		}
		kq = kq.Softmax(ctx)

		kqv := value.Mulmat(ctx, kq)
		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	}
}
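For reference, a caller would typically use this helper roughly as follows. This is a hedged sketch, not code from any particular ollama model: the SelfAttention struct, its field names and gguf tags, the head-count parameters, and the ml/nn import path are illustrative assumptions, and positional encodings such as RoPE are omitted for brevity.

package model // hypothetical package for this sketch, not part of the file above

import (
	"math"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

// SelfAttention shows the intended call pattern: project, reshape to the
// [d_k, heads, seq_len] layout documented on Attention, and let Attention
// handle the cache, mask and backend-specific layout work.
type SelfAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}

func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, cache kvcache.Cache, headDim, heads, kvHeads int) ml.Tensor {
	seqLen := hiddenState.Dim(1)

	q := sa.Query.Forward(ctx, hiddenState).Reshape(ctx, headDim, heads, seqLen)
	k := sa.Key.Forward(ctx, hiddenState).Reshape(ctx, headDim, kvHeads, seqLen)
	v := sa.Value.Forward(ctx, hiddenState).Reshape(ctx, headDim, kvHeads, seqLen)

	// Attention stores k/v in the cache, retrieves the full history plus the
	// mask, and uses the backend's fused ScaledDotProductAttention when it is
	// available, so the model code never permutes or calls Contiguous itself.
	attn := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)

	// Fold the heads back into the hidden dimension for the output projection.
	attn = attn.Reshape(ctx, headDim*heads, seqLen)

	return sa.Output.Forward(ctx, attn)
}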