mirror of
https://github.com/ollama/ollama.git
synced 2025-08-02 14:22:51 +02:00
attention: Remove unnecessary contiguous operations
Prior to performing attention, we need to permute query, key and value. Currently we call Contiguous after each of these permutations, which is correct but expensive. Avoiding the 3 calls to Contiguous increases performance by over 20%. The permutations of query and key do not violate the continuity rules for mulmat and the Contiguous call can be simply removed. Value requires a different permutation and does require Contiguous. However, we can use the copy into the cache as a way to perform this without further overhead. To support this and avoid unexpected tensor shapes that are seen by models, we need tighter integration between attention, cache and backend. Future optimization will also likely need this structure - for example, flash attention has special padding requirements in the cache and other backends may have their own needs. This further contains the operations that go into attention so that these and other optimizations can be handled transparently. Models that have special requirements for attention can still implement their own version of it.
This commit is contained in:
@@ -27,6 +27,27 @@ type Backend interface {
|
||||
SystemInfo() string
|
||||
}
|
||||
|
||||
// BackendCacheConfig should be implemented by backends that need special output
|
||||
// from the cache to meet specific requirements. It is frequently implemented in
|
||||
// conjunction with ScaledDotProductAttention.
|
||||
type BackendCacheConfig interface {
|
||||
CacheConfig() CacheConfig
|
||||
}
|
||||
|
||||
// CacheConfig controls optimizations (mostly backend-specific) that may transform
|
||||
// the output the cache to work better with specific kernels.
|
||||
type CacheConfig struct {
|
||||
// CachePadding specifies the multiple for the number of tokens of cache history
|
||||
// that will be returned from cache Get for k, v and mask. The capacity of the
|
||||
// cache itself will also be increased to a multiple of this size if needed.
|
||||
CachePadding int
|
||||
|
||||
// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
|
||||
// and return the permuted version via Get. This uses the cache copy operation
|
||||
// to avoid a Contiguous call on the permuted tensor.
|
||||
PermutedV bool
|
||||
}
|
||||
|
||||
// BackendParams controls how the backend loads and executes models
|
||||
type BackendParams struct {
|
||||
// NumThreads sets the number of threads to use if running on the CPU
|
||||
@@ -116,6 +137,10 @@ type Tensor interface {
|
||||
// operation equivalent to following code on a tensor named
|
||||
// query:
|
||||
//
|
||||
// query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
// key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
//
|
||||
// kq := key.MulmatFullPrec(ctx, query)
|
||||
//
|
||||
// kq = kq.Scale(ctx, scale)
|
||||
|
Reference in New Issue
Block a user