diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index af48ff22e..30292f641 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -225,6 +225,8 @@ func countCommonPrefix(a []input.Input, b []input.Input) int32 { return count } +// TODO(jessegross): If we need to reprocess the inputs we should ensure that +// we don't split up a SameBatch func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 { targetFree := (c.numCtx - numKeep) / 2 targetFree = max(targetFree, 1) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 458387184..f3286abae 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -115,16 +115,41 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe params.numKeep = int32(len(inputs)) } - // TODO(jessegross): We should ensure that we always leave minBatch of context space to shift, - // otherwise we might truncate or split the batch against the model's wishes - // Ensure that at least 1 input can be discarded during shift params.numKeep = min(params.numKeep, s.cache.numCtx-1) if int32(len(inputs)) > s.cache.numCtx { discard := int32(len(inputs)) - s.cache.numCtx + promptStart := params.numKeep + discard + + // If we need to truncate in the middle of an unbreakable batch, remove the entire batch + sameBatch := 0 + for i, inp := range inputs { + if sameBatch > 0 { + sameBatch-- + + if promptStart == int32(i) { + promptStart++ + } + } else if promptStart == int32(i) { + break + } + + if inp.SameBatch != 0 { + if int32(i) < params.numKeep { + return nil, fmt.Errorf("SameBatch may not be specified within numKeep (index: %v numKeep: %v SameBatch: %v)", i, params.numKeep, inp.SameBatch) + } + + sameBatch = inp.SameBatch + } + } + + if promptStart >= int32(len(inputs)) { + return nil, errors.New("entire prompt removed by truncation") + } + newInputs := inputs[:params.numKeep] - newInputs = append(newInputs, 
inputs[params.numKeep+discard:]...) + newInputs = append(newInputs, inputs[promptStart:]...) slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs)) inputs = newInputs