diff --git a/ml/backend.go b/ml/backend.go
index 3cc18f2b6..6e3f0516f 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -111,6 +111,26 @@ type Tensor interface {
 	Copy(ctx Context, t2 Tensor) Tensor
 }
 
+// ScaledDotProductAttention implements a fused attention
+// operation equivalent to the following code on a tensor named
+// query:
+//
+// kq := key.MulmatFullPrec(ctx, query)
+//
+// kq = kq.Scale(ctx, scale)
+//
+// if mask != nil {
+//	kq = kq.Add(ctx, mask)
+// }
+//
+// kq = kq.Softmax(ctx)
+//
+// kqv := value.Mulmat(ctx, kq)
+// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+type ScaledDotProductAttention interface {
+	ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
+}
+
 type number interface {
 	~int | ~int8 | ~int16 | ~int32 | ~int64 |
 		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 2b7b91894..2d7cf340e 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -651,6 +651,21 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
 	}
 }
 
+func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
+	var kqMask *C.struct_ggml_tensor
+	if mask != nil {
+		kqMask = mask.(*Tensor).t
+	}
+
+	kq := key.MulmatFullPrec(ctx, t)
+	kq = &Tensor{
+		t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
+	}
+
+	kqv := value.Mulmat(ctx, kq)
+	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+}
+
 func (b *Backend) SystemInfo() string {
 	var compiler string
 	switch C.get_compiler() {
diff --git a/ml/nn/attention.go b/ml/nn/attention.go
new file mode 100644
index 000000000..4f0c9fa14
--- /dev/null
+++ b/ml/nn/attention.go
@@ -0,0 +1,59 @@
+package nn
+
+import (
+	"fmt"
+
+	"github.com/ollama/ollama/ml"
+)
+
+// Attention implements scaled dot-product attention for transformer models:
+// Attention(Q, K, V) = softmax(QK^T/√d_k)V
+//
+// Parameters:
+//   - ctx: Context for tensor operations
+//   - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
+//   - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
+//   - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
+//   - mask: Optional attention mask that is added to the attention score. If
+//     provided, should broadcast to [seq_len_k, seq_len_q, heads]
+//   - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
+//
+// Returns:
+//
+//	Attention output with shape [d_v, heads, seq_len_q]
+func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
+	if query.Dim(0) != key.Dim(0) {
+		panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
+	}
+
+	if mask != nil && query.Dim(1) != mask.Dim(1) {
+		panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
+	}
+
+	if key.Dim(1) != value.Dim(0) {
+		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
+	}
+
+	if mask != nil && key.Dim(1) != mask.Dim(0) {
+		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
+	}
+
+	if key.Dim(2) != value.Dim(2) {
+		panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
+	}
+
+	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
+		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
+	} else {
+		kq := key.MulmatFullPrec(ctx, query)
+
+		kq = kq.Scale(ctx, scale)
+		if mask != nil {
+			kq = kq.Add(ctx, mask)
+		}
+		kq = kq.Softmax(ctx)
+
+		kqv := value.Mulmat(ctx, kq)
+		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	}
+}
diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index e90631fb8..4fe029993 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -86,13 +86,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
-	kq := k.MulmatFullPrec(ctx, q)
-	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	kq = kq.Add(ctx, mask)
-	kq = kq.Softmax(ctx)
-
-	kqv := v.Mulmat(ctx, kq)
-	kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
+	kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
 	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return sa.Output.Forward(ctx, kqv)
diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go
index 8ad804cf9..003bf9cbf 100644
--- a/model/models/mllama/model_text.go
+++ b/model/models/mllama/model_text.go
@@ -38,13 +38,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
-	scores := key.MulmatFullPrec(ctx, query)
-	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	scores = scores.Add(ctx, mask)
-	scores = scores.Softmax(ctx)
-
-	attention := value.Mulmat(ctx, scores)
-	attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
+	attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
 	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return sa.Output.Forward(ctx, attention)
@@ -112,7 +107,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
 	query = ca.QueryNorm.Forward(ctx, query, opts.eps)
 
-	var key, value ml.Tensor
+	var key, value, mask ml.Tensor
 	if crossAttentionStates != nil {
 		numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
 
@@ -125,19 +120,15 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
 
 		cache.Put(ctx, key, value)
 	} else {
-		key, value, _ = cache.Get(ctx)
+		key, value, mask = cache.Get(ctx)
 	}
 
 	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
 
-	scores := key.Mulmat(ctx, query)
-	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	scores = scores.Softmax(ctx)
-
-	attention := value.Mulmat(ctx, scores)
-	attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
+	attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
 	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return ca.Output.Forward(ctx, attention)
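
For reference, below is a minimal caller-side sketch of how a self-attention block might use the new nn.Attention helper, modeled on the llama SelfAttention.Forward change above. The layer struct, its field names, and the option names (hiddenSize, numHeads, numKVHeads) are illustrative assumptions, and the RoPE and KV-cache handling of the real models is omitted; only the nn.Attention signature and the tensor shapes documented in ml/nn/attention.go are taken from the diff.

package sketch

// Illustrative only: this layer, its field names, and the option names are
// hypothetical; they mirror the llama SelfAttention changes in the diff but
// omit RoPE and KV-cache handling.

import (
	"math"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

type selfAttention struct {
	Query  *nn.Linear
	Key    *nn.Linear
	Value  *nn.Linear
	Output *nn.Linear
}

type options struct {
	hiddenSize, numHeads, numKVHeads int
}

func (sa *selfAttention) forward(ctx ml.Context, hiddenState, mask ml.Tensor, opts *options) ml.Tensor {
	batchSize := hiddenState.Dim(1)
	headDim := opts.hiddenSize / opts.numHeads

	// Project, then reshape so that dim 0 is headDim, matching the shapes
	// documented on nn.Attention: query [d_k, seq_len_q, heads].
	q := sa.Query.Forward(ctx, hiddenState).Reshape(ctx, headDim, opts.numHeads, batchSize)
	k := sa.Key.Forward(ctx, hiddenState).Reshape(ctx, headDim, opts.numKVHeads, batchSize)
	v := sa.Value.Forward(ctx, hiddenState).Reshape(ctx, headDim, opts.numKVHeads, batchSize)

	// Permute as llama/mllama do before calling nn.Attention: heads move to
	// the outermost dimension, and value is laid out with seq_len_k first.
	q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

	// nn.Attention validates the shapes, then uses the backend's fused
	// ScaledDotProductAttention if available, else the Mulmat/Softmax path.
	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
	kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)

	return sa.Output.Forward(ctx, kqv)
}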
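
As a sanity reference for what the fused operation computes, the following self-contained sketch evaluates softmax(QK^T * scale + mask)V for a single head using plain Go slices, mirroring the fallback path in nn.Attention (MulmatFullPrec, Scale, Add, Softmax, Mulmat). It is purely illustrative: the helper name, the row-major [][]float64 layout, and the example values are assumptions and do not correspond to ggml's tensor layout or the dimension order documented above.

package main

import (
	"fmt"
	"math"
)

// attention computes softmax(Q K^T * scale + mask) V for a single head using
// plain slices. Reference sketch only: q is [seqLenQ][dK], k is [seqLenK][dK],
// v is [seqLenK][dV], mask is [seqLenQ][seqLenK] or nil.
func attention(q, k, v, mask [][]float64, scale float64) [][]float64 {
	seqLenQ, seqLenK, dV := len(q), len(k), len(v[0])
	out := make([][]float64, seqLenQ)

	for i := 0; i < seqLenQ; i++ {
		// Scaled dot products of query i with every key, plus optional mask.
		scores := make([]float64, seqLenK)
		for j := 0; j < seqLenK; j++ {
			var dot float64
			for d := range q[i] {
				dot += q[i][d] * k[j][d]
			}
			scores[j] = dot * scale
			if mask != nil {
				scores[j] += mask[i][j]
			}
		}

		// Softmax over the key dimension (max-subtracted for stability).
		maxScore := math.Inf(-1)
		for _, s := range scores {
			maxScore = math.Max(maxScore, s)
		}
		var sum float64
		for j, s := range scores {
			scores[j] = math.Exp(s - maxScore)
			sum += scores[j]
		}

		// Attention output for query i: weighted sum of the value rows.
		out[i] = make([]float64, dV)
		for j := 0; j < seqLenK; j++ {
			w := scores[j] / sum
			for d := 0; d < dV; d++ {
				out[i][d] += w * v[j][d]
			}
		}
	}

	return out
}

func main() {
	q := [][]float64{{1, 0}, {0, 1}}
	k := [][]float64{{1, 0}, {0, 1}, {1, 1}}
	v := [][]float64{{1, 2}, {3, 4}, {5, 6}}
	scale := 1 / math.Sqrt(2) // 1/√d_k with d_k = 2
	fmt.Println(attention(q, k, v, nil, scale))
}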