From 6544e1473525c381e89aba4778283900b3ad7145 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan
Date: Sat, 11 Oct 2025 16:06:14 -0700
Subject: [PATCH] Reapply "add truncate and shift parameters" (#12582)

---
 api/types.go                   |  16 ++++++
 llm/server.go                  |   6 +-
 runner/llamarunner/runner.go   |  22 +++++++
 runner/ollamarunner/runner.go  |  25 ++++++++
 server/prompt.go               |   4 +-
 server/prompt_test.go          | 102 +++++++++++++++++++++------------
 server/routes.go               |  79 ++++++++++++++++++++-----
 server/routes_generate_test.go | 101 ++++++++++++++++++++++++++++++++
 8 files changed, 298 insertions(+), 57 deletions(-)

diff --git a/api/types.go b/api/types.go
index d0669b90da..41b490b512 100644
--- a/api/types.go
+++ b/api/types.go
@@ -106,6 +106,14 @@ type GenerateRequest struct {
 	// before this option was introduced)
 	Think *ThinkValue `json:"think,omitempty"`
 
+	// Truncate is a boolean that, when set to true, truncates the chat history messages
+	// if the rendered prompt exceeds the context length limit.
+	Truncate *bool `json:"truncate,omitempty"`
+
+	// Shift is a boolean that, when set to true, shifts the chat history
+	// when hitting the context length limit instead of erroring.
+	Shift *bool `json:"shift,omitempty"`
+
 	// DebugRenderOnly is a debug option that, when set to true, returns the rendered
 	// template instead of calling the model.
 	DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
@@ -140,6 +148,14 @@ type ChatRequest struct {
 	// for supported models.
 	Think *ThinkValue `json:"think,omitempty"`
 
+	// Truncate is a boolean that, when set to true, truncates the chat history messages
+	// if the rendered prompt exceeds the context length limit.
+	Truncate *bool `json:"truncate,omitempty"`
+
+	// Shift is a boolean that, when set to true, shifts the chat history
+	// when hitting the context length limit instead of erroring.
+	Shift *bool `json:"shift,omitempty"`
+
 	// DebugRenderOnly is a debug option that, when set to true, returns the rendered
 	// template instead of calling the model.
 	DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
diff --git a/llm/server.go b/llm/server.go
index 079ecb2575..febcc1e65f 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -1379,7 +1379,9 @@ type CompletionRequest struct {
 	Images  []ImageData
 	Options *api.Options
 
-	Grammar string // set before sending the request to the subprocess
+	Grammar  string // set before sending the request to the subprocess
+	Shift    bool
+	Truncate bool
 }
 
 // DoneReason represents the reason why a completion response is done
@@ -1501,7 +1503,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			return fmt.Errorf("failed reading llm error response: %w", err)
 		}
 		log.Printf("llm predict error: %s", bodyBytes)
-		return fmt.Errorf("%s", bodyBytes)
+		return api.StatusError{StatusCode: res.StatusCode, ErrorMessage: strings.TrimSpace(string(bodyBytes))}
 	}
 
 	scanner := bufio.NewScanner(res.Body)
diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go
index 7ed7ebb2bd..a65040d1d7 100644
--- a/runner/llamarunner/runner.go
+++ b/runner/llamarunner/runner.go
@@ -79,6 +79,9 @@ type Sequence struct {
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
+	// shift if context window is exceeded
+	shift bool
+
 	doneReason llm.DoneReason
 
 	// Metrics
@@ -94,8 +97,12 @@ type NewSequenceParams struct {
 	numKeep        int
 	samplingParams *llama.SamplingParams
 	embedding      bool
+	shift          bool
+	truncate       bool
 }
 
+var errorInputTooLong = errors.New("the input length exceeds the context length")
+
 func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 	s.ready.Wait()
 
@@ -119,6 +126,10 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 
 	if len(inputs) > s.cache.numCtx {
 		discard := len(inputs) - s.cache.numCtx
+		if !params.truncate {
+			return nil, errorInputTooLong
+		}
+
 		newInputs := inputs[:params.numKeep]
 		newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
 
@@ -385,6 +396,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		for i, input := range seq.inputs {
 			if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
 				if len(seq.pendingInputs) == 0 {
+					if !seq.shift {
+						s.removeSequence(seqIdx, llm.DoneReasonLength)
+						break
+					}
+
 					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
 					if err != nil {
 						var reprocess *ErrReprocessInputs
@@ -583,8 +599,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		numKeep:        req.Options.NumKeep,
 		samplingParams: &samplingParams,
 		embedding:      false,
+		shift:          req.Shift,
+		truncate:       req.Truncate,
 	})
 	if err != nil {
+		if errors.Is(err, errorInputTooLong) {
+			http.Error(w, err.Error(), http.StatusBadRequest)
+			return
+		}
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
 		return
 	}
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 28d9d2c9c1..af212eceff 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -88,6 +88,9 @@ type Sequence struct {
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
+	// shift if context window is exceeded
+	shift bool
+
 	doneReason llm.DoneReason
 
 	// Metrics
@@ -104,8 +107,12 @@ type NewSequenceParams struct {
 	numKeep   int32
 	sampler   sample.Sampler
 	embedding bool
+	shift     bool
+	truncate  bool
 }
 
+var errorInputTooLong = errors.New("the input length exceeds the context length")
+
 func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
 	s.ready.Wait()
 
@@ -125,6 +132,11 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 
 	if int32(len(inputs)) > s.cache.numCtx {
 		discard := int32(len(inputs)) - s.cache.numCtx
+
+		if !params.truncate {
+			return nil, errorInputTooLong
+		}
+
 		promptStart := params.numKeep + discard
 
 		// If we need to truncate in the middle of a unbreakable batch, remove the entire batch
@@ -176,6 +188,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		embeddingOnly: params.embedding,
 		stop:          params.stop,
 		numKeep:       params.numKeep,
+		shift:         params.shift,
 	}, nil
 }
 
@@ -517,6 +530,12 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
 				break
 			}
 
+			if !seq.shift {
+				s.removeSequence(seqIdx, llm.DoneReasonLength)
+				nextBatch.seqs[seqIdx] = nil
+				break
+			}
+
 			err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
 			if err != nil {
 				var reprocess *ErrReprocessInputs
@@ -832,8 +851,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		numKeep:   int32(req.Options.NumKeep),
 		sampler:   sampler,
 		embedding: false,
+		shift:     req.Shift,
+		truncate:  req.Truncate,
 	})
 	if err != nil {
+		if errors.Is(err, errorInputTooLong) {
+			http.Error(w, err.Error(), http.StatusBadRequest)
+			return
+		}
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
 		return
 	}
diff --git a/server/prompt.go b/server/prompt.go
index 56bc63030b..2175919821 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -20,7 +20,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
 // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
 // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
 // latest message and 2) system messages
-func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (prompt string, images []llm.ImageData, _ error) {
+func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, _ error) {
 	var system []api.Message
 
 	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@@ -59,7 +59,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 			}
 		}
 
-		if ctxLen > opts.NumCtx {
+		if truncate && ctxLen > opts.NumCtx {
 			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
 			break
 		} else {
diff --git a/server/prompt_test.go b/server/prompt_test.go
index 659e64084c..3bd621152b 100644
--- a/server/prompt_test.go
+++ b/server/prompt_test.go
@@ -27,16 +27,18 @@ func TestChatPrompt(t *testing.T) {
 	visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
 
 	cases := []struct {
-		name  string
-		model Model
-		limit int
-		msgs  []api.Message
+		name     string
+		model    Model
+		limit    int
+		truncate bool
+		msgs     []api.Message
 		expect
 	}{
 		{
-			name:  "messages",
-			model: visionModel,
-			limit: 64,
+			name:     "messages",
+			model:    visionModel,
+			limit:    64,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "assistant", Content: "I-I'm a what?"},
@@ -47,9 +49,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "truncate messages",
-			model: visionModel,
-			limit: 1,
+			name:     "truncate messages",
+			model:    visionModel,
+			limit:    1,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "assistant", Content: "I-I'm a what?"},
@@ -60,9 +63,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "truncate messages with image",
-			model: visionModel,
-			limit: 64,
+			name:     "truncate messages with image",
+			model:    visionModel,
+			limit:    64,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "assistant", Content: "I-I'm a what?"},
@@ -76,9 +80,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "truncate messages with images",
-			model: visionModel,
-			limit: 64,
+			name:     "truncate messages with images",
+			model:    visionModel,
+			limit:    64,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
 				{Role: "assistant", Content: "I-I'm a what?"},
@@ -92,9 +97,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "messages with images",
-			model: visionModel,
-			limit: 2048,
+			name:     "messages with images",
+			model:    visionModel,
+			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
 				{Role: "assistant", Content: "I-I'm a what?"},
@@ -109,9 +115,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "message with image tag",
-			model: visionModel,
-			limit: 2048,
+			name:     "message with image tag",
+			model:    visionModel,
+			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
 				{Role: "assistant", Content: "I-I'm a what?"},
@@ -126,9 +133,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "messages with interleaved images",
-			model: visionModel,
-			limit: 2048,
+			name:     "messages with interleaved images",
+			model:    visionModel,
+			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "user", Images: []api.ImageData{[]byte("something")}},
@@ -145,9 +153,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "truncate message with interleaved images",
-			model: visionModel,
-			limit: 1024,
+			name:     "truncate message with interleaved images",
+			model:    visionModel,
+			limit:    1024,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "user", Images: []api.ImageData{[]byte("something")}},
@@ -163,9 +172,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "message with system prompt",
-			model: visionModel,
-			limit: 2048,
+			name:     "message with system prompt",
+			model:    visionModel,
+			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "system", Content: "You are the Test Who Lived."},
 				{Role: "user", Content: "You're a test, Harry!"},
@@ -177,9 +187,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "out of order system",
-			model: visionModel,
-			limit: 2048,
+			name:     "out of order system",
+			model:    visionModel,
+			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "You're a test, Harry!"},
 				{Role: "assistant", Content: "I-I'm a what?"},
@@ -191,9 +202,10 @@ func TestChatPrompt(t *testing.T) {
 			},
 		},
 		{
-			name:  "multiple images same prompt",
-			model: visionModel,
-			limit: 2048,
+			name:     "multiple images same prompt",
+			model:    visionModel,
+			limit:    2048,
+			truncate: true,
 			msgs: []api.Message{
 				{Role: "user", Content: "Compare these two pictures of hotdogs", Images: []api.ImageData{[]byte("one hotdog"), []byte("two hotdogs")}},
 			},
@@ -202,6 +214,20 @@ func TestChatPrompt(t *testing.T) {
 			expect: expect{
 				prompt: "[img-0][img-1] Compare these two pictures of hotdogs ",
 				images: [][]byte{[]byte("one hotdog"), []byte("two hotdogs")},
 			},
 		},
+		{
+			name:     "no truncate with limit exceeded",
+			model:    visionModel,
+			limit:    10,
+			truncate: false,
+			msgs: []api.Message{
+				{Role: "user", Content: "You're a test, Harry!"},
+				{Role: "assistant", Content: "I-I'm a what?"},
+				{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
+			},
+			expect: expect{
+				prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
+			},
+		},
 	}
 
 	for _, tt := range cases {
@@ -209,7 +235,7 @@ func TestChatPrompt(t *testing.T) {
 			model := tt.model
 			opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
 			think := false
-			prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think})
+			prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
 			if tt.error == nil && err != nil {
 				t.Fatal(err)
 			} else if tt.error != nil && err != tt.error {
diff --git a/server/routes.go b/server/routes.go
index 7b25b09c04..e65b0ed7d2 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -434,7 +434,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	// the real chat handler, but doing this as a stopgap to get renderer
 	// support for generate
 	if values.Messages != nil && values.Suffix == "" && req.Template == "" {
-		prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think)
+		prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
@@ -488,10 +488,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		var sb strings.Builder
 		defer close(ch)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Images:  images,
-			Format:  req.Format,
-			Options: opts,
+			Prompt:   prompt,
+			Images:   images,
+			Format:   req.Format,
+			Options:  opts,
+			Shift:    req.Shift == nil || *req.Shift,
+			Truncate: req.Truncate == nil || *req.Truncate,
 		}, func(cr llm.CompletionResponse) {
 			res := api.GenerateResponse{
 				Model:     req.Model,
@@ -553,7 +555,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 			ch <- res
 		}); err != nil {
-			ch <- gin.H{"error": err.Error()}
+			var serr api.StatusError
+			if errors.As(err, &serr) {
+				ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
+			} else {
+				ch <- gin.H{"error": err.Error()}
+			}
 		}
 	}()
 
@@ -573,7 +580,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 				msg = "unexpected error format in response"
 			}
 
-			c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+			status, ok := t["status"].(int)
+			if !ok {
+				status = http.StatusInternalServerError
+			}
+
+			c.JSON(status, gin.H{"error": msg})
 			return
 		default:
 			c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
@@ -1638,6 +1650,30 @@ func streamResponse(c *gin.Context, ch chan any) {
 			return false
 		}
 
+		// errors are provided as a gin.H with an "error" field and
+		// an optional "status" field. For errors that are streamed
+		// before any content, we need to set the status code and
+		// content type for the error.
+		if h, ok := val.(gin.H); ok {
+			if e, ok := h["error"].(string); ok {
+				status, ok := h["status"].(int)
+				if !ok {
+					status = http.StatusInternalServerError
+				}
+
+				if !c.Writer.Written() {
+					c.Header("Content-Type", "application/json")
+					c.JSON(status, gin.H{"error": e})
+				} else {
+					if err := json.NewEncoder(c.Writer).Encode(gin.H{"error": e}); err != nil {
+						slog.Error("streamResponse failed to encode json error", "error", err)
+					}
+				}
+
+				return false
+			}
+		}
+
 		bts, err := json.Marshal(val)
 		if err != nil {
 			slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
@@ -1957,7 +1993,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		}
 	}
 
-	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
+	truncate := req.Truncate == nil || *req.Truncate
+	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
 	if err != nil {
 		slog.Error("chat prompt error", "error", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -2034,10 +2071,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		// sets up new context given parent context per request
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		err := r.Completion(ctx, llm.CompletionRequest{
-			Prompt:  prompt,
-			Images:  images,
-			Format:  currentFormat,
-			Options: opts,
+			Prompt:   prompt,
+			Images:   images,
+			Format:   currentFormat,
+			Options:  opts,
+			Shift:    req.Shift == nil || *req.Shift,
+			Truncate: truncate,
 		}, func(r llm.CompletionResponse) {
 			res := api.ChatResponse{
 				Model:     req.Model,
@@ -2131,7 +2170,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			if structuredOutputsState == structuredOutputsState_ReadyToApply && strings.Contains(err.Error(), "context canceled") && c.Request.Context().Err() == nil {
 				// only ignores error if it's a context cancellation due to setting structured outputs
 			} else {
-				ch <- gin.H{"error": err.Error()}
+				var serr api.StatusError
+				if errors.As(err, &serr) {
+					ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
+				} else {
+					ch <- gin.H{"error": err.Error()}
+				}
 				return
 			}
 		}
@@ -2145,7 +2189,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 			}
 
 			msgs = append(msgs, msg)
-			prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
+			prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
 			if err != nil {
 				slog.Error("chat prompt error applying structured outputs", "error", err)
 				ch <- gin.H{"error": err.Error()}
@@ -2185,7 +2229,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				msg = "unexpected error format in response"
 			}
 
-			c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
+			status, ok := t["status"].(int)
+			if !ok {
+				status = http.StatusInternalServerError
+			}
+
+			c.JSON(status, gin.H{"error": msg})
 			return
 		default:
 			c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go
index f8dc70d64f..a86a70ba50 100644
--- a/server/routes_generate_test.go
+++ b/server/routes_generate_test.go
@@ -609,6 +609,58 @@ func TestGenerateChat(t *testing.T) {
 			t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
 		}
 	})
+
+	t.Run("status error non-streaming", func(t *testing.T) {
+		mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+			return api.StatusError{
+				StatusCode:   http.StatusServiceUnavailable,
+ Status: "Service Unavailable", + ErrorMessage: "model is overloaded", + } + } + + stream := false + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test", + Messages: []api.Message{ + {Role: "user", Content: "Hello!"}, + }, + Stream: &stream, + }) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("expected status 503, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"model is overloaded"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("status error streaming", func(t *testing.T) { + mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error { + return api.StatusError{ + StatusCode: http.StatusTooManyRequests, + Status: "Too Many Requests", + ErrorMessage: "rate limit exceeded", + } + } + + w := createRequest(t, s.ChatHandler, api.ChatRequest{ + Model: "test", + Messages: []api.Message{ + {Role: "user", Content: "Hello!"}, + }, + }) + + if w.Code != http.StatusTooManyRequests { + t.Errorf("expected status 429, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"rate limit exceeded"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) } func TestGenerate(t *testing.T) { @@ -983,6 +1035,55 @@ func TestGenerate(t *testing.T) { t.Errorf("mismatch (-got +want):\n%s", diff) } }) + + t.Run("status error non-streaming", func(t *testing.T) { + mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error { + return api.StatusError{ + StatusCode: http.StatusServiceUnavailable, + Status: "Service Unavailable", + ErrorMessage: "model is overloaded", + } + } + + streamRequest := false + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test", + Prompt: "Hello!", + Stream: &streamRequest, + }) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("expected status 503, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"model is overloaded"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) + + t.Run("status error streaming", func(t *testing.T) { + mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error { + return api.StatusError{ + StatusCode: http.StatusTooManyRequests, + Status: "Too Many Requests", + ErrorMessage: "rate limit exceeded", + } + } + + w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ + Model: "test", + Prompt: "Hello!", + Stream: &stream, + }) + + if w.Code != http.StatusTooManyRequests { + t.Errorf("expected status 429, got %d", w.Code) + } + + if diff := cmp.Diff(w.Body.String(), `{"error":"rate limit exceeded"}`); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + } + }) } func TestChatWithPromptEndingInThinkTag(t *testing.T) {