From e53b3cbd0c3f08eb692a318c8eaf687a01c2e8c0 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Thu, 3 Apr 2025 10:19:24 -0700
Subject: [PATCH] llm: set done reason at server level (#9830)

No functional change. Many different done reasons can be set at the runner
level, so rather than obscuring them we should return them to the server
process and let it choose what to do with the done reason. This separates
the API concerns from the runner.
---
 llm/server.go                  | 26 ++++++++++++++++++++++++--
 runner/llamarunner/runner.go   | 21 ++++++++-------------
 runner/ollamarunner/runner.go  | 21 ++++++++-------------
 server/routes.go               | 20 ++++++++++----------
 server/routes_generate_test.go |  8 ++++----
 5 files changed, 54 insertions(+), 42 deletions(-)

diff --git a/llm/server.go b/llm/server.go
index e6046db60..a2bc1548f 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -675,9 +675,32 @@ type CompletionRequest struct {
 	Grammar string // set before sending the request to the subprocess
 }
 
+// DoneReason represents the reason why a completion response is done
+type DoneReason int
+
+const (
+	// DoneReasonStop indicates the completion stopped naturally
+	DoneReasonStop DoneReason = iota
+	// DoneReasonLength indicates the completion stopped due to length limits
+	DoneReasonLength
+	// DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed
+	DoneReasonConnectionClosed
+)
+
+func (d DoneReason) String() string {
+	switch d {
+	case DoneReasonLength:
+		return "length"
+	case DoneReasonStop:
+		return "stop"
+	default:
+		return "" // closed
+	}
+}
+
 type CompletionResponse struct {
 	Content            string        `json:"content"`
-	DoneReason         string        `json:"done_reason"`
+	DoneReason         DoneReason    `json:"done_reason"`
 	Done               bool          `json:"done"`
 	PromptEvalCount    int           `json:"prompt_eval_count"`
 	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
@@ -786,7 +809,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			continue
 		}
 
-		// slog.Debug("got line", "line", string(line))
 		evt, ok := bytes.CutPrefix(line, []byte("data: "))
 		if !ok {
 			evt = line
diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go
index a4264f5fc..d8169be40 100644
--- a/runner/llamarunner/runner.go
+++ b/runner/llamarunner/runner.go
@@ -83,7 +83,7 @@ type Sequence struct {
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
-	doneReason string
+	doneReason llm.DoneReason
 
 	// Metrics
 	startProcessingTime time.Time
@@ -301,7 +301,7 @@ func flushPending(seq *Sequence) bool {
 	}
 }
 
-func (s *Server) removeSequence(seqIndex int, reason string) {
+func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
 	seq := s.seqs[seqIndex]
 
 	flushPending(seq)
@@ -380,7 +380,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-			s.removeSequence(seqIdx, "limit")
+			s.removeSequence(seqIdx, llm.DoneReasonLength)
 			continue
 		}
 
@@ -482,7 +482,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			}
 
 			seq.embedding <- embed
-			s.removeSequence(i, "")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
 
@@ -499,7 +499,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			// as it's important for the /api/generate context
 			// seq.responses <- piece
 
-			s.removeSequence(i, "stop")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
 
@@ -530,7 +530,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 				}
 				seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
 
-				s.removeSequence(i, "stop")
+				s.removeSequence(i, llm.DoneReasonStop)
 				continue
 			}
 
@@ -543,7 +543,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		}
 
 		if !flushPending(seq) {
-			s.removeSequence(i, "connection")
+			s.removeSequence(i, llm.DoneReasonConnectionClosed)
 		}
 	}
 
@@ -657,14 +657,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 
 			flusher.Flush()
 		} else {
-			// Send the final response
-			doneReason := "stop"
-			if seq.doneReason == "limit" {
-				doneReason = "length"
-			}
 			if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 				Done:               true,
-				DoneReason:         doneReason,
+				DoneReason:         seq.doneReason,
 				PromptEvalCount:    seq.numPromptInputs,
 				PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
 				EvalCount:          seq.numDecoded,
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index f3286abae..7b7e09402 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -82,7 +82,7 @@ type Sequence struct {
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
-	doneReason string
+	doneReason llm.DoneReason
 
 	// Metrics
 	startProcessingTime time.Time
@@ -341,7 +341,7 @@ func flushPending(seq *Sequence) bool {
 	}
 }
 
-func (s *Server) removeSequence(seqIndex int, reason string) {
+func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
 	seq := s.seqs[seqIndex]
 
 	flushPending(seq)
@@ -391,7 +391,7 @@ func (s *Server) processBatch() error {
 
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-			s.removeSequence(seqIdx, "limit")
+			s.removeSequence(seqIdx, llm.DoneReasonLength)
 			continue
 		}
 
@@ -510,7 +510,7 @@ func (s *Server) processBatch() error {
 		if seq.embeddingOnly {
 			// TODO(jessegross): Embedding support
 			slog.Warn("generation of embedding outputs not yet supported")
-			s.removeSequence(i, "")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
 
@@ -528,7 +528,7 @@ func (s *Server) processBatch() error {
 			// as it's important for the /api/generate context
 			// seq.responses <- piece
 
-			s.removeSequence(i, "stop")
+			s.removeSequence(i, llm.DoneReasonStop)
 			continue
 		}
 
@@ -564,7 +564,7 @@ func (s *Server) processBatch() error {
 				}
 				seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
 
-				s.removeSequence(i, "stop")
+				s.removeSequence(i, llm.DoneReasonStop)
 				continue
 			}
 
@@ -577,7 +577,7 @@ func (s *Server) processBatch() error {
 		}
 
 		if !flushPending(seq) {
-			s.removeSequence(i, "connection")
+			s.removeSequence(i, llm.DoneReasonConnectionClosed)
 		}
 	}
 
@@ -690,14 +690,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 
 			flusher.Flush()
 		} else {
-			// Send the final response
-			doneReason := "stop"
-			if seq.doneReason == "limit" {
-				doneReason = "length"
-			}
 			if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 				Done:               true,
-				DoneReason:         doneReason,
+				DoneReason:         seq.doneReason,
 				PromptEvalCount:    seq.numPromptInputs,
 				PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
 				EvalCount:          seq.numPredicted,
diff --git a/server/routes.go b/server/routes.go
index eee34033e..906426b18 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -308,11 +308,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		Options: opts,
 	}, func(cr llm.CompletionResponse) {
 		res := api.GenerateResponse{
-			Model:      req.Model,
-			CreatedAt:  time.Now().UTC(),
-			Response:   cr.Content,
-			Done:       cr.Done,
-			DoneReason: cr.DoneReason,
+			Model:     req.Model,
+			CreatedAt: time.Now().UTC(),
+			Response:  cr.Content,
+			Done:      cr.Done,
 			Metrics: api.Metrics{
 				PromptEvalCount:    cr.PromptEvalCount,
 				PromptEvalDuration: cr.PromptEvalDuration,
@@ -326,6 +325,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		}
 
 		if cr.Done {
+			res.DoneReason = cr.DoneReason.String()
 			res.TotalDuration = time.Since(checkpointStart)
 			res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 
@@ -1533,11 +1533,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		Options: opts,
 	}, func(r llm.CompletionResponse) {
 		res := api.ChatResponse{
-			Model:      req.Model,
-			CreatedAt:  time.Now().UTC(),
-			Message:    api.Message{Role: "assistant", Content: r.Content},
-			Done:       r.Done,
-			DoneReason: r.DoneReason,
+			Model:     req.Model,
+			CreatedAt: time.Now().UTC(),
+			Message:   api.Message{Role: "assistant", Content: r.Content},
+			Done:      r.Done,
 			Metrics: api.Metrics{
 				PromptEvalCount:    r.PromptEvalCount,
 				PromptEvalDuration: r.PromptEvalDuration,
@@ -1547,6 +1546,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		}
 
 		if r.Done {
+			res.DoneReason = r.DoneReason.String()
 			res.TotalDuration = time.Since(checkpointStart)
 			res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 		}
diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go
index aa263bf97..f219387c3 100644
--- a/server/routes_generate_test.go
+++ b/server/routes_generate_test.go
@@ -58,7 +58,7 @@ func TestGenerateChat(t *testing.T) {
 	mock := mockRunner{
 		CompletionResponse: llm.CompletionResponse{
 			Done:               true,
-			DoneReason:         "stop",
+			DoneReason:         llm.DoneReasonStop,
 			PromptEvalCount:    1,
 			PromptEvalDuration: 1,
 			EvalCount:          1,
@@ -401,7 +401,7 @@ func TestGenerateChat(t *testing.T) {
 		mock.CompletionResponse = llm.CompletionResponse{
 			Content:            `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
 			Done:               true,
-			DoneReason:         "done",
+			DoneReason:         llm.DoneReasonStop,
 			PromptEvalCount:    1,
 			PromptEvalDuration: 1,
 			EvalCount:          1,
@@ -519,7 +519,7 @@ func TestGenerateChat(t *testing.T) {
 			{
 				Content:            `, WA","unit":"celsius"}}`,
 				Done:               true,
-				DoneReason:         "tool_call",
+				DoneReason:         llm.DoneReasonStop,
 				PromptEvalCount:    3,
 				PromptEvalDuration: 1,
 			},
@@ -594,7 +594,7 @@ func TestGenerate(t *testing.T) {
 	mock := mockRunner{
 		CompletionResponse: llm.CompletionResponse{
 			Done:               true,
-			DoneReason:         "stop",
+			DoneReason:         llm.DoneReasonStop,
 			PromptEvalCount:    1,
 			PromptEvalDuration: 1,
 			EvalCount:          1,
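
For context, a minimal self-contained sketch of the contract this patch introduces (not part of the patch itself). The DoneReason type and its String method are copied from the diff above; the package main wrapper, the trimmed completionResponse struct, and the printed output are illustrative assumptions only. It also shows the wire-level consequence of the type change: between runner and server, done_reason now serializes as an integer, and the human-readable string ("stop", "length", or empty for a closed connection) is produced only when the server builds the public API response via DoneReason.String().

package main

import (
	"encoding/json"
	"fmt"
)

// DoneReason mirrors the type added to llm/server.go by this patch.
type DoneReason int

const (
	DoneReasonStop             DoneReason = iota // completion stopped naturally
	DoneReasonLength                             // hit the num predict limit
	DoneReasonConnectionClosed                   // client connection went away
)

// String maps the runner-level reason to the string exposed by the public API.
func (d DoneReason) String() string {
	switch d {
	case DoneReasonLength:
		return "length"
	case DoneReasonStop:
		return "stop"
	default:
		return "" // connection closed: nothing is reported
	}
}

// completionResponse is a trimmed, hypothetical stand-in for llm.CompletionResponse.
type completionResponse struct {
	Done       bool       `json:"done"`
	DoneReason DoneReason `json:"done_reason"`
}

func main() {
	// Between runner and server the reason now travels as an integer...
	wire, _ := json.Marshal(completionResponse{Done: true, DoneReason: DoneReasonLength})
	fmt.Println(string(wire)) // {"done":true,"done_reason":1}

	// ...and it is converted to the API string only on the final response,
	// mirroring res.DoneReason = cr.DoneReason.String() in server/routes.go.
	fmt.Println(DoneReasonLength.String()) // length
}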