2025-02-14 20:51:44 -08:00
|
|
|
package nn
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
|
2025-02-22 21:34:10 -08:00
|
|
|
"github.com/ollama/ollama/kvcache"
|
2025-02-14 20:51:44 -08:00
|
|
|
"github.com/ollama/ollama/ml"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Attention implements scaled dot-product attention for transformer models:
|
|
|
|
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
|
|
|
|
//
|
|
|
|
// Parameters:
|
|
|
|
// - ctx: Context for tensor operations
|
2025-02-22 21:34:10 -08:00
|
|
|
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
|
|
|
|
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
|
|
|
|
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
|
2025-02-14 20:51:44 -08:00
|
|
|
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
|
2025-02-22 21:34:10 -08:00
|
|
|
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
|
2025-02-14 20:51:44 -08:00
|
|
|
//
|
|
|
|
// Returns:
|
|
|
|
//
|
|
|
|
// Attention output with shape [d_v, heads, seq_len_q]
|
2025-02-22 21:34:10 -08:00
|
|
|
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
|
|
|
if key != nil && value != nil {
|
|
|
|
if query.Dim(0) != key.Dim(0) {
|
|
|
|
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
|
|
|
}
|
2025-02-14 20:51:44 -08:00
|
|
|
|
2025-02-22 21:34:10 -08:00
|
|
|
if key.Dim(1) != value.Dim(1) {
|
|
|
|
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
|
|
|
|
}
|
2025-02-14 20:51:44 -08:00
|
|
|
|
2025-02-22 21:34:10 -08:00
|
|
|
if key.Dim(2) != value.Dim(2) {
|
|
|
|
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
|
|
|
|
}
|
2025-02-14 20:51:44 -08:00
|
|
|
|
2025-02-22 21:34:10 -08:00
|
|
|
if cache != nil {
|
|
|
|
cache.Put(ctx, key, value)
|
|
|
|
}
|
|
|
|
} else if cache == nil {
|
|
|
|
panic("key & value tensors must be provided if cache is nil")
|
2025-02-14 20:51:44 -08:00
|
|
|
}
|
|
|
|
|
2025-02-22 21:34:10 -08:00
|
|
|
var mask ml.Tensor
|
|
|
|
if cache != nil {
|
|
|
|
key, value, mask = cache.Get(ctx)
|
2025-02-14 20:51:44 -08:00
|
|
|
}
|
|
|
|
|
2025-02-22 21:34:10 -08:00
|
|
|
// Only use the fast SDPA implementation if we have a cache, since that's what
|
|
|
|
// will do any expected backend-specific transformations for us
|
|
|
|
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
|
2025-02-14 20:51:44 -08:00
|
|
|
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
|
|
|
|
} else {
|
2025-02-22 21:34:10 -08:00
|
|
|
query = query.Permute(ctx, 0, 2, 1, 3)
|
|
|
|
key = key.Permute(ctx, 0, 2, 1, 3)
|
|
|
|
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
|
|
|
|
2025-02-14 20:51:44 -08:00
|
|
|
kq := key.MulmatFullPrec(ctx, query)
|
|
|
|
|
|
|
|
kq = kq.Scale(ctx, scale)
|
|
|
|
if mask != nil {
|
|
|
|
kq = kq.Add(ctx, mask)
|
|
|
|
}
|
|
|
|
kq = kq.Softmax(ctx)
|
|
|
|
|
|
|
|
kqv := value.Mulmat(ctx, kq)
|
|
|
|
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
|
|
}
|
|
|
|
}
|