Prior to performing attention, we need to permute query, key and value. Currently we call Contiguous after each of these permutations, which is correct but expensive. Avoiding the three calls to Contiguous increases performance by over 20%.

The permutations of query and key do not violate the continuity rules for mulmat, so their Contiguous calls can simply be removed. Value requires a different permutation and does need to be made contiguous; however, the copy into the cache can perform this layout change without further overhead.

To support this, and to avoid exposing unexpected tensor shapes to models, we need tighter integration between attention, cache and backend. Future optimizations will likely need this structure as well - for example, flash attention has special padding requirements in the cache, and other backends may have their own needs. This also encapsulates the operations that go into attention so that these and other optimizations can be handled transparently. Models that have special requirements for attention can still implement their own version of it.
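As a hypothetical before/after sketch of the idea (the Permute and Contiguous method names mirror the ml.Tensor interface used below; the dimension orders and the cache.Put call are illustrative assumptions, not the exact model code):

	// Before: every permutation was immediately realized with Contiguous,
	// costing three extra tensor copies per attention call.
	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

	// After: query and key remain permuted views, which mulmat accepts
	// directly, so their Contiguous calls are dropped.
	query = query.Permute(ctx, 0, 2, 1, 3)
	key = key.Permute(ctx, 0, 2, 1, 3)

	// Value still needs its layout changed, but the copy that Put
	// performs when writing into the cache doubles as that layout
	// change, so no standalone Contiguous call is required.
	cache.Put(ctx, key, value)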
100 lines
2.2 KiB
Go
package kvcache

import (
	"math"

	"github.com/ollama/ollama/ml"
)

// Wrapper cache is a container for multiple types of caches,
// such as for the encoding and decoding portions of a model.
type WrapperCache struct {
	// caches we are wrapping
	caches []Cache

	// cache to be used for this layer
	curType int
}

func NewWrapperCache(caches ...Cache) *WrapperCache {
	return &WrapperCache{
		caches: caches,
	}
}

func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
	for _, cache := range c.caches {
		cache.Init(backend, dtype, capacity)
	}
}

func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
	for _, cache := range c.caches {
		cache.SetConfig(config)
	}
}

func (c *WrapperCache) Close() {
	for _, cache := range c.caches {
		cache.Close()
	}
}

func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
	for i, cache := range c.caches {
		err := cache.StartForward(ctx, positions, seqs)
		if err != nil {
			// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
			for j := i - 1; j >= 0; j-- {
				for k := range positions {
					_ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32)
				}
			}
			return err
		}
	}

	c.curType = 0
	return nil
}

func (c *WrapperCache) SetLayer(layer int) {
	for _, cache := range c.caches {
		cache.SetLayer(layer)
	}
}

func (c *WrapperCache) SetLayerType(layerType int) {
	c.curType = layerType
}

func (c *WrapperCache) UnderlyingCache() Cache {
	return c.caches[c.curType]
}

func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	return c.caches[c.curType].Get(ctx)
}

func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
	c.caches[c.curType].Put(ctx, key, value)
}

func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
	for _, cache := range c.caches {
		cache.CopyPrefix(srcSeq, dstSeq, len)
	}
}

func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
	// If one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
	for _, cache := range c.caches {
		err := cache.Remove(seq, beginIndex, endIndex)
		if err != nil {
			return err
		}
	}

	return nil
}
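For illustration, a hypothetical caller with two wrapped caches (say, one for self-attention layers and one for cross-attention layers) might drive the wrapper as below; numLayers and isCrossAttention are assumed helpers, not part of this package:

	// Hypothetical usage sketch: select the wrapped cache per layer.
	func forward(ctx ml.Context, c *WrapperCache, positions []int32, seqs []int) error {
		if err := c.StartForward(ctx, positions, seqs); err != nil {
			return err
		}

		for layer := 0; layer < numLayers; layer++ {
			c.SetLayer(layer)
			if isCrossAttention(layer) { // assumed helper
				c.SetLayerType(1) // use the second wrapped cache
			} else {
				c.SetLayerType(0) // use the first wrapped cache
			}

			// ... attention reads and writes then go through c.Get and c.Put ...
		}

		return nil
	}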