Mirror of https://github.com/ollama/ollama.git
Prior to performing attention, we need to permute query, key and value. Currently we call Contiguous after each of these permutations, which is correct but expensive. Avoiding the 3 calls to Contiguous increases performance by over 20%.

The permutations of query and key do not violate the continuity rules for mulmat, so the Contiguous call can simply be removed. Value requires a different permutation and does require Contiguous; however, we can use the copy into the cache to perform this without further overhead.

To support this and avoid unexpected tensor shapes being seen by models, we need tighter integration between attention, cache and backend. Future optimizations will also likely need this structure - for example, flash attention has special padding requirements in the cache, and other backends may have their own needs.

This further contains the operations that go into attention so that these and other optimizations can be handled transparently. Models that have special requirements for attention can still implement their own version of it.
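As a rough model-side illustration of that flow, here is a minimal sketch (not code from the repository: the package name, function signature and import path are assumptions, and the attention math itself is elided) of a layer storing its key/value through Put and reading the history back through Get, so that the copy into the cache absorbs the permutation that previously needed Contiguous:

package model

import (
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
)

// selfAttention is a hypothetical sketch: only the kvcache.Cache methods are
// taken from the interface in the file below; everything else is placeholder.
func selfAttention(ctx ml.Context, cache kvcache.Cache, layer int, query, key, value ml.Tensor) ml.Tensor {
	cache.SetLayer(layer)

	// The copy performed by Put can also apply the permutation that value
	// needs, so no separate Contiguous call is required at this point.
	cache.Put(ctx, key, value)

	// Get returns the full history (including the batch just stored) plus a
	// mask, already shaped for whatever kernel the cache config selected.
	keys, values, mask := cache.Get(ctx)

	// ... compute attention over query/keys/values with the mask applied ...
	_, _, _ = keys, values, mask
	return query
}

A model that implements its own attention and needs untransformed output can instead pass an empty ml.CacheConfig to SetConfig, as documented in the interface.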
66 lines
2.1 KiB
Go
package kvcache

import (
	"errors"

	"github.com/ollama/ollama/ml"
)

var (
	ErrKvCacheFull  = errors.New("could not find a kv cache slot")
	ErrNotSupported = errors.New("model does not support operation")
)

type Cache interface {
	// ** used by model implementations **

	// SetLayer sets the active layer of the cache
	SetLayer(layer int)

	// Get returns the history of key and value tensors plus a mask
	//
	// The shape of the tensors is documented in the specific
	// cache implementation used.
	Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)

	// Put stores a batch of key and value in the cache
	//
	// The shape of the tensors is documented in the specific
	// cache implementation used.
	Put(ctx ml.Context, key, value ml.Tensor)

	// SetConfig controls optimizations (mostly backend-specific) that may transform
	// the output of the cache to work better with specific kernels. If not called,
	// the backend settings will be used. This works well when calling Attention.
	//
	// The config can be overridden by models, especially if they require vanilla
	// output when implementing their own version of attention. To do this, pass
	// an empty ml.CacheConfig.
	//
	// Most models will not need to use this.
	SetConfig(ml.CacheConfig)

	// ** cache management **

	// Init sets up runtime parameters
	Init(backend ml.Backend, dtype ml.DType, capacity int32)

	// Close closes the cache and frees resources associated with it
	Close()

	// StartForward is called before the start of the model's forward pass.
	// For each token in the coming batch, there must be a corresponding
	// entry in positions and seqs.
	StartForward(ctx ml.Context, positions []int32, seqs []int) error

	// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
	CopyPrefix(srcSeq, dstSeq int, len int32)

	// Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
	// endIndex to math.MaxInt32 to remove everything starting at beginIndex.
	//
	// If an error occurs, the entire context for the sequence should be
	// removed by calling Remove(seq, 0, math.MaxInt32)
	Remove(seq int, beginIndex, endIndex int32) error
}
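For the cache-management side of the interface, here is a hedged usage sketch under the same caveats (the package name, function name, capacity and sequence values are assumptions; only the method names and exported errors come from the interface above). It drives the cache through a single setup / forward / teardown cycle for illustration:

package runner

import (
	"errors"
	"math"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
)

// exampleLifecycle sketches how a runner might drive the cache around one
// forward pass; in practice Init and Close happen once, not per batch.
func exampleLifecycle(backend ml.Backend, ctx ml.Context, cache kvcache.Cache, dtype ml.DType, positions []int32, seqs []int) error {
	// Runtime setup: data type of the cached tensors and total slot capacity.
	cache.Init(backend, dtype, 2048)
	defer cache.Close()

	// Reserve a cache slot for every token in the incoming batch; positions
	// and seqs must each have one entry per token.
	if err := cache.StartForward(ctx, positions, seqs); err != nil {
		if errors.Is(err, kvcache.ErrKvCacheFull) {
			return err // no free slot: the caller could evict sequences or shrink the batch
		}
		return err
	}

	// ... the model's forward pass runs here, calling SetLayer/Put/Get per layer ...

	// Example prefix reuse: share the first 16 cached tokens of sequence 0
	// with sequence 1.
	cache.CopyPrefix(0, 1, 16)

	// If an error occurs mid-sequence, the documented recovery is to drop
	// that sequence's entire context.
	if err := cache.Remove(0, 0, math.MaxInt32); err != nil {
		return err
	}

	return nil
}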