From 9c23f11850e4df183ed92750306c30ef0e480e13 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Mon, 10 Mar 2025 21:24:06 -0700 Subject: [PATCH] pr feedback --- runner/llamarunner/cache.go | 41 ++++++++++++++++++++--------------- runner/ollamarunner/cache.go | 1 - runner/ollamarunner/runner.go | 3 ++- 3 files changed, 26 insertions(+), 19 deletions(-) diff --git a/runner/llamarunner/cache.go b/runner/llamarunner/cache.go index e88aac5a8..55bceca33 100644 --- a/runner/llamarunner/cache.go +++ b/runner/llamarunner/cache.go @@ -230,7 +230,8 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error { return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx) } - discard := c.ShiftDiscard(len(slot.Inputs), numKeep) + inputLen := len(slot.Inputs) + discard := c.ShiftDiscard(inputLen, numKeep) if discard <= 0 { return nil @@ -239,37 +240,43 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error { slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs), "keep", numKeep, "discard", discard) + var shiftFailed bool + if c.lc.KvCacheCanShift() { + // For models that support shifting, attempt to shift the KV cache if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) { - return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard) + shiftFailed = true + slog.Debug("kv cache removal failed, clearing cache and returning inputs for reprocessing", "id", slot.Id) + } else { + c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, inputLen, -discard) } - c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard) - - for i := numKeep + discard; i < len(slot.Inputs); i++ { - slot.Inputs[i-discard] = slot.Inputs[i] - } - slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard] } else { - slog.Debug("kv cache cannot shift, clearing cache and truncating history") + // For models that don't support shifting + shiftFailed = true + slog.Debug("kv cache cannot shift, clearing cache and returning inputs for reprocessing", "id", slot.Id) + } + if shiftFailed { // Clear the entire KV cache - if !c.lc.KvCacheSeqRm(slot.Id, 0, -1) { - return fmt.Errorf("unable to remove kv cache entries (id: %v)", slot.Id) - } + _ = c.lc.KvCacheSeqRm(slot.Id, 0, -1) - // Update the slot.Inputs to match what would happen with a shift operation - // Keep the first numKeep tokens and the tokens after the discard - newInputs := make([]input, numKeep+len(slot.Inputs)-(numKeep+discard)) + // Create new input slice with preserved tokens (numKeep + remaining tokens after discard) + newInputs := make([]input, numKeep+inputLen-(numKeep+discard)) copy(newInputs[:numKeep], slot.Inputs[:numKeep]) copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:]) // Reset the slot inputs since we've cleared the cache slot.Inputs = []input{} - // Return the inputs that need to be reprocessed - // The caller will need to prepend these to the sequence's inputs queue + // Return error with inputs that need to be reprocessed return &ErrReprocessInputs{Inputs: newInputs} } + // Standard shift succeeded - update input array + for i := numKeep + discard; i < inputLen; i++ { + slot.Inputs[i-discard] = slot.Inputs[i] + } + slot.Inputs = slot.Inputs[:inputLen-discard] + return nil } diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index c0496a547..b6a011cf6 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -268,7 +268,6 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error { slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs), "keep", numKeep, "discard", discard) - // TODO (jessegross): KV cache removal can fail for certain types of models if c.cache != nil { err := c.cache.Remove(slot.Id, numKeep, numKeep+discard) if err != nil { diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 11eeb4153..6d86c1511 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -360,7 +360,8 @@ func (s *Server) processBatch() error { if errors.As(err, &reprocess) { // Prepend these inputs to the sequence's inputs queue for reprocessing seq.inputs = append(reprocess.Inputs, seq.inputs...) - // Continue processing as normal + // Return early to restart processing with the new inputs at the beginning + return nil } else { return err }