package kvcache

import (
	"math"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// WrapperCache is a container for multiple types of caches,
// such as for the encoding and decoding portions of a model.
type WrapperCache struct {
	// caches we are wrapping
	caches []Cache

	// cache to be used for this layer
	curType int
}

func NewWrapperCache(caches ...Cache) *WrapperCache {
	return &WrapperCache{
		caches: caches,
	}
}
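
// Illustrative usage (a sketch only, not part of this package's API): a model with
// separate encoder and decoder sections might wrap one cache per section and pass
// the wrapper everywhere a Cache is expected. The constructors named below are
// assumed to be the other Cache implementations in this package; substitute
// whatever caches the model actually needs.
//
//	cache := NewWrapperCache(
//		NewCausalCache(shift), // e.g. decoder self-attention
//		NewEncoderCache(),     // e.g. cross-attention over encoder output
//	)
//	cache.Init(backend, dtype, capacity)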

func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
	for _, cache := range c.caches {
		cache.Init(backend, dtype, capacity)
	}
}

func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
	for _, cache := range c.caches {
		cache.SetConfig(config)
	}
}

func (c *WrapperCache) Close() {
	for _, cache := range c.caches {
		cache.Close()
	}
}

func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
	for i, cache := range c.caches {
		err := cache.StartForward(ctx, opts)
		if err != nil {
			// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
			for j := i - 1; j >= 0; j-- {
				for k := range opts.Positions {
					_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
				}
			}
			return err
		}
	}

	c.curType = 0
	return nil
}

func (c *WrapperCache) SetLayer(layer int) {
	for _, cache := range c.caches {
		cache.SetLayer(layer)
	}
}

func (c *WrapperCache) SetLayerType(layerType int) {
	c.curType = layerType
}
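
// Sketch of how the per-layer selection might be driven (hypothetical model code,
// not part of this package): before building a layer's attention, the model
// selects which wrapped cache applies, and subsequent Get/Put calls are routed to
// that cache. layerIsCrossAttention is an assumed helper, for illustration only.
//
//	cache.SetLayer(i)
//	if layerIsCrossAttention(i) {
//		cache.SetLayerType(1) // index of the wrapped cache for this layer type
//	} else {
//		cache.SetLayerType(0)
//	}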

func (c *WrapperCache) UnderlyingCache() Cache {
	return c.caches[c.curType]
}

func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	return c.caches[c.curType].Get(ctx)
}

func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
	c.caches[c.curType].Put(ctx, key, value)
}

func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
	for _, cache := range c.caches {
		cache.CopyPrefix(srcSeq, dstSeq, len)
	}
}

func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
	// If one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
	for _, cache := range c.caches {
		err := cache.Remove(seq, beginIndex, endIndex)
		if err != nil {
			return err
		}
	}

	return nil
}