mirror of
https://github.com/ollama/ollama.git
synced 2025-03-26 01:32:06 +01:00
Currently the runner computes the kv size needed and creates a cache of that size. This is the context size times number of parallel sequences. Cache implementations can make better decisions about their memory usage, so instead pass in the required capacity, number of sequences and maximum batch size. For now, the causal cache just uses this to compute the size in the same way as before.
276 lines
6.8 KiB
Go
276 lines
6.8 KiB
Go
package ollamarunner
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"math"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/kvcache"
|
|
"github.com/ollama/ollama/ml"
|
|
"github.com/ollama/ollama/model"
|
|
"github.com/ollama/ollama/model/input"
|
|
)
|
|
|
|
type InputCache struct {
|
|
// context window size (per slot)
|
|
numCtx int32
|
|
|
|
// does the cache store data or do we need to always send the full input?
|
|
// note that when enabled is false the underlying cache may either be nil
|
|
// or a non-nil dummy that doesn't actually store anything
|
|
enabled bool
|
|
|
|
// individual KV caches
|
|
slots []InputCacheSlot
|
|
|
|
// optimize cache eviction for multiple users
|
|
multiUserCache bool
|
|
|
|
cache kvcache.Cache
|
|
}
|
|
|
|
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
|
|
numCtx := kvSize / int32(numSlots)
|
|
|
|
if numCtx < 1 {
|
|
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
|
|
}
|
|
|
|
slots := make([]InputCacheSlot, numSlots)
|
|
|
|
for i := range slots {
|
|
slots[i] = InputCacheSlot{Id: i}
|
|
}
|
|
|
|
cache := model.Config().Cache
|
|
if cache != nil {
|
|
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize)
|
|
}
|
|
|
|
return &InputCache{
|
|
numCtx: numCtx,
|
|
enabled: cache != nil,
|
|
slots: slots,
|
|
multiUserCache: multiUserCache,
|
|
cache: cache,
|
|
}, nil
|
|
}
|
|
|
|
func kvCacheTypeFromStr(s string) ml.DType {
|
|
switch s {
|
|
case "q8_0":
|
|
return ml.DTypeQ80
|
|
case "q4_0":
|
|
return ml.DTypeQ40
|
|
default:
|
|
return ml.DTypeF16
|
|
}
|
|
}
|
|
|
|
func (c *InputCache) Close() {
|
|
c.cache.Close()
|
|
}
|
|
|
|
// Locking: Operations on InputCacheSlot (including finding one
// through LoadCacheSlot) require a lock to be held that serializes
// these operations with each other and processBatch
|
|
|
|
type InputCacheSlot struct {
|
|
// Index in the KV cache
|
|
Id int
|
|
|
|
// Inputs that are stored in the KV cache
|
|
Inputs []input.Input
|
|
|
|
// is this cache actively being processed as part of a sequence?
|
|
InUse bool
|
|
|
|
// last time this cache was used (as of start of processing)
|
|
lastUsed time.Time
|
|
}
|
|
|
|
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
|
|
var slot *InputCacheSlot
|
|
var numPast int32
|
|
var err error
|
|
|
|
// In single-user scenarios, the longest cache slot works fine for getting good input
|
|
// cache hit rates and it keeps the footprint of the cache small, which improves throughput.
|
|
// For multiple users, the "best" cache slot produces better input cache hit rates
|
|
// at the cost of worse performance when we miss the input cache.
|
|
if !c.multiUserCache {
|
|
slot, numPast, err = c.findLongestCacheSlot(prompt)
|
|
} else {
|
|
slot, numPast, err = c.findBestCacheSlot(prompt)
|
|
}
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
slot.InUse = true
|
|
slot.lastUsed = time.Now()
|
|
|
|
if numPast == int32(len(prompt)) {
|
|
// Leave one input to sample so we can get a response
|
|
numPast--
|
|
}
|
|
|
|
if c.cache != nil {
|
|
err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)
|
|
if err != nil {
|
|
// Some models don't support partial erasure
|
|
err = c.cache.Remove(slot.Id, 0, math.MaxInt32)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
numPast = 0
|
|
}
|
|
}
|
|
|
|
slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
|
|
"used", numPast, "remaining", int32(len(prompt))-numPast)
|
|
|
|
prompt = prompt[numPast:]
|
|
slot.Inputs = slot.Inputs[:numPast]
|
|
|
|
return slot, prompt, nil
|
|
}
|
|
|
|
func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
|
|
longest := int32(-1)
|
|
var longestSlot *InputCacheSlot
|
|
|
|
for i, s := range c.slots {
|
|
if s.InUse {
|
|
continue
|
|
}
|
|
|
|
count := countCommonPrefix(s.Inputs, prompt)
|
|
if count > longest {
|
|
longest = count
|
|
longestSlot = &c.slots[i]
|
|
}
|
|
}
|
|
|
|
if longestSlot == nil {
|
|
return nil, 0, errors.New("no available cache slots")
|
|
}
|
|
|
|
return longestSlot, longest, nil
|
|
}
|
|
|
|
func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
|
|
oldest := time.Now()
|
|
var oldestSlot *InputCacheSlot
|
|
|
|
longest := int32(-1)
|
|
var longestSlot *InputCacheSlot
|
|
|
|
for i, s := range c.slots {
|
|
count := countCommonPrefix(s.Inputs, prompt)
|
|
if count > longest {
|
|
longest = count
|
|
longestSlot = &c.slots[i]
|
|
}
|
|
|
|
if s.lastUsed.Compare(oldest) < 0 && !s.InUse {
|
|
oldest = s.lastUsed
|
|
oldestSlot = &c.slots[i]
|
|
}
|
|
}
|
|
|
|
if longest == int32(len(longestSlot.Inputs)) && !longestSlot.InUse {
|
|
return longestSlot, longest, nil
|
|
}
|
|
|
|
if oldestSlot.InUse {
|
|
return nil, 0, errors.New("no available cache slots")
|
|
}
|
|
|
|
if len(oldestSlot.Inputs) != 0 {
|
|
slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs),
|
|
"used", oldestSlot.lastUsed)
|
|
}
|
|
|
|
if longest > 0 && longestSlot != oldestSlot {
|
|
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
|
|
len(longestSlot.Inputs))
|
|
oldestSlot.Inputs = make([]input.Input, longest)
|
|
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
|
|
if c.cache != nil {
|
|
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
|
|
}
|
|
}
|
|
|
|
return oldestSlot, longest, nil
|
|
}
|
|
|
|
func countCommonPrefix(a []input.Input, b []input.Input) int32 {
|
|
var count int32
|
|
|
|
for i := range a {
|
|
if i >= len(b) {
|
|
break
|
|
}
|
|
|
|
if a[i].Token != b[i].Token || a[i].MultimodalHash != b[i].MultimodalHash {
|
|
break
|
|
}
|
|
|
|
count++
|
|
}
|
|
|
|
return count
|
|
}
|
|
|
|
func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
|
|
targetFree := (c.numCtx - numKeep) / 2
|
|
targetFree = max(targetFree, 1)
|
|
|
|
currentFree := c.numCtx - inputLen
|
|
discard := targetFree - currentFree
|
|
|
|
if discard < 0 {
|
|
discard = 0
|
|
}
|
|
|
|
return discard
|
|
}
|
|
|
|
// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
|
// the newest half into that space (saving numKeep inputs at the beginning).
|
|
//
|
|
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
|
|
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
|
|
if numKeep >= c.numCtx {
|
|
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
|
|
}
|
|
|
|
inputLen := int32(len(slot.Inputs))
|
|
discard := c.ShiftDiscard(inputLen, numKeep)
|
|
|
|
if discard <= 0 {
|
|
return nil
|
|
}
|
|
|
|
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
|
|
"keep", numKeep, "discard", discard)
|
|
|
|
// TODO (jessegross): KV cache removal can fail for certain types of models
|
|
if c.cache != nil {
|
|
err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
|
|
}
|
|
}
|
|
|
|
for i := numKeep + discard; i < inputLen; i++ {
|
|
slot.Inputs[i-discard] = slot.Inputs[i]
|
|
}
|
|
slot.Inputs = slot.Inputs[:inputLen-discard]
|
|
|
|
return nil
|
|
}
|