ollama/kvcache/wrapper.go
Jesse Gross ed443a0393 Runner for Ollama engine
This provides integration between the new Ollama engine
(5824541 next ollama runner (#7913)) and the rest of the Ollama
infrastructure, such as the runner and the Ollama server.

In addition, it builds out the KV cache infrastructure to
support the requirements of how Ollama runs models, such as:
 - Parallel processing
 - Memory management for defragmentation and shifting
 - Multi-modal models

Both old and new engines continue to be supported. By default, only
the old engine is used. To enable the new engine:

Start the server with the OLLAMA_NEW_ENGINE environment variable set:
OLLAMA_NEW_ENGINE=1 ./ollama serve

Start a model that is supported by the Ollama engine. This one is Llama 3.1 8b Q4_K_M:
./ollama run jessegross/llama3.1
2025-02-13 17:09:26 -08:00

package kvcache

import (
	"math"

	"github.com/ollama/ollama/ml"
)

// Wrapper cache is a container for multiple types of caches,
// such as for the encoding and decoding portions of a model.
type WrapperCache struct {
	// caches we are wrapping
	caches []Cache

	// cache to be used for this layer
	curType int
}

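// NewWrapperCache creates a WrapperCache around the provided caches.
// Operations that store or retrieve data apply to the currently selected
// cache, while lifecycle and sequence operations apply to all of them.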
func NewWrapperCache(caches ...Cache) *WrapperCache {
	return &WrapperCache{
		caches: caches,
	}
}

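// Init initializes every wrapped cache with the same backend, data type,
// and capacity.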
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
	for _, cache := range c.caches {
		cache.Init(backend, dtype, capacity)
	}
}

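// Close closes all wrapped caches.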
func (c *WrapperCache) Close() {
	for _, cache := range c.caches {
		cache.Close()
	}
}

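// StartForward prepares every wrapped cache for a new batch of positions and
// sequences. If any cache fails, the entries already reserved in the caches
// that succeeded are removed before the error is returned. On success the
// selected cache type is reset to the first cache.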
func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
	for i, cache := range c.caches {
		err := cache.StartForward(ctx, positions, seqs)
		if err != nil {
			// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
			for j := i - 1; j >= 0; j-- {
				for k := range positions {
					_ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32)
				}
			}

			return err
		}
	}

	c.curType = 0
	return nil
}

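// SetLayer sets the active layer on all wrapped caches.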
func (c *WrapperCache) SetLayer(layer int) {
	for _, cache := range c.caches {
		cache.SetLayer(layer)
	}
}

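// SetLayerType selects which wrapped cache subsequent calls such as Get and
// Put operate on. layerType is an index into the caches passed to
// NewWrapperCache.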
func (c *WrapperCache) SetLayerType(layerType int) {
	c.curType = layerType
}

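// UnderlyingCache returns the wrapped cache currently selected via
// SetLayerType.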
func (c *WrapperCache) UnderlyingCache() Cache {
	return c.caches[c.curType]
}

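// Get forwards to the currently selected cache, returning its key, value,
// and mask tensors.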
func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
	return c.caches[c.curType].Get(ctx)
}

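// Put stores key and value tensors in the currently selected cache.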
func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
	c.caches[c.curType].Put(ctx, key, value)
}

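// CopyPrefix copies the first len entries of srcSeq into dstSeq in every
// wrapped cache.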
func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
	for _, cache := range c.caches {
		cache.CopyPrefix(srcSeq, dstSeq, len)
	}
}

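// Remove removes positions between beginIndex and endIndex for the given
// sequence from every wrapped cache. If a removal fails partway through, the
// caches may be left inconsistent, so the caller is expected to retry with
// endIndex set to math.MaxInt32, which should not fail.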
func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
	// If one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
	for _, cache := range c.caches {
		err := cache.Remove(seq, beginIndex, endIndex)
		if err != nil {
			return err
		}
	}

	return nil
}