llamarunner: clear cache when shift is not possible

Clear the KV cache when the shift operation is not supported by the model.
Add a KvCacheCanShift() check to handle models that cannot perform cache shifts.
For such models, fall back to clearing the entire cache while preserving the
logical token history, so that behavior stays as expected when the context window fills up.
This commit is contained in:
Bruce MacDonald 2025-02-28 16:24:48 -08:00
parent d8a5d96b98
commit eb92a726c9
2 changed files with 33 additions and 10 deletions

View File

@ -166,6 +166,10 @@ func (c *Context) KvCacheDefrag() {
C.llama_kv_cache_defrag(c.c)
}
// KvCacheCanShift reports the result of llama_kv_cache_can_shift for this
// context, converted to a Go bool. Callers use it to decide whether a cache
// shift may be attempted or whether another strategy is needed.
func (c *Context) KvCacheCanShift() bool {
	canShift := C.llama_kv_cache_can_shift(c.c)
	return bool(canShift)
}
// Get the embeddings for a sequence id
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))

View File

@ -213,8 +213,8 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
return discard
}
// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
// ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history
// and shifting the newest half into that space (saving numKeep inputs at the beginning).
//
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
@ -231,16 +231,35 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
"keep", numKeep, "discard", discard)
// TODO (jessegross): KV cache removal can fail for certain types of models
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
}
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
if c.lc.KvCacheCanShift() {
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
}
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
for i := numKeep + discard; i < len(slot.Inputs); i++ {
slot.Inputs[i-discard] = slot.Inputs[i]
for i := numKeep + discard; i < len(slot.Inputs); i++ {
slot.Inputs[i-discard] = slot.Inputs[i]
}
slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
} else {
slog.Debug("kv cache cannot shift, clearing cache and truncating history")
// Clear the entire KV cache
if !c.lc.KvCacheSeqRm(slot.Id, 0, -1) {
return fmt.Errorf("unable to remove kv cache entries (id: %v)", slot.Id)
}
// Update the slot.Inputs to match what would happen with a shift operation
// Keep the first numKeep tokens and the tokens after the discard
keepInputs := make([]input, numKeep)
copy(keepInputs, slot.Inputs[:numKeep])
afterDiscardInputs := make([]input, len(slot.Inputs)-(numKeep+discard))
copy(afterDiscardInputs, slot.Inputs[numKeep+discard:])
// Update the inputs to match what would happen after a shift
slot.Inputs = append(keepInputs, afterDiscardInputs...)
}
slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
return nil
}