From 771c88b3ad9d06053d4eb7fcbe106f1aca5ef36a Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Wed, 19 Mar 2025 10:41:51 -0700
Subject: [PATCH] use done reason enum

---
 llm/server.go                 | 35 ++++++++++++++++++++++++++---------
 runner/llamarunner/runner.go  |  2 +-
 runner/ollamarunner/runner.go |  2 +-
 server/routes.go              |  4 ++--
 4 files changed, 30 insertions(+), 13 deletions(-)

diff --git a/llm/server.go b/llm/server.go
index bdb33759b..7bc13f083 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -675,9 +675,34 @@ type CompletionRequest struct {
 	Grammar string // set before sending the request to the subprocess
 }
 
+// DoneReason represents the reason why a completion response is done
+type DoneReason string
+
+const (
+	// DoneReasonStop indicates the completion stopped naturally
+	DoneReasonStop DoneReason = "stop"
+	// DoneReasonLength indicates the completion stopped due to length limits
+	DoneReasonLength DoneReason = "length"
+)
+
+func (d DoneReason) String() string {
+	return string(d)
+}
+
+// ParseDoneReason converts a string to a DoneReason type
+// If the string doesn't match any known reason, it defaults to DoneReasonStop
+func ParseDoneReason(reason string) DoneReason {
+	switch reason {
+	case "limit", "length":
+		return DoneReasonLength
+	default:
+		return DoneReasonStop
+	}
+}
+
 type CompletionResponse struct {
 	Content            string        `json:"content"`
-	DoneReason         string        `json:"done_reason"`
+	DoneReason         DoneReason    `json:"done_reason"`
 	Done               bool          `json:"done"`
 	PromptEvalCount    int           `json:"prompt_eval_count"`
 	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
@@ -786,7 +811,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 				continue
 			}
 
-			// slog.Debug("got line", "line", string(line))
 			evt, ok := bytes.CutPrefix(line, []byte("data: "))
 			if !ok {
 				evt = line
@@ -796,13 +820,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			if err := json.Unmarshal(evt, &c); err != nil {
 				return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
 			}
-			// convert internal done reason to one of our standard api format done reasons
-			switch c.DoneReason {
-			case "limit":
-				c.DoneReason = "length"
-			default:
-				c.DoneReason = "stop"
-			}
 
 			switch {
 			case strings.TrimSpace(c.Content) == lastToken:
diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go
index b5f59ae14..4f87d3f36 100644
--- a/runner/llamarunner/runner.go
+++ b/runner/llamarunner/runner.go
@@ -649,7 +649,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			} else {
 				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 					Done:               true,
-					DoneReason:         seq.doneReason,
+					DoneReason:         llm.ParseDoneReason(seq.doneReason),
 					PromptEvalCount:    seq.numPromptInputs,
 					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
 					EvalCount:          seq.numDecoded,
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 5a1a14caf..5ad248ab1 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -629,7 +629,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			} else {
 				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
 					Done:               true,
-					DoneReason:         seq.doneReason,
+					DoneReason:         llm.ParseDoneReason(seq.doneReason),
 					PromptEvalCount:    seq.numPromptInputs,
 					PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
 					EvalCount:          seq.numPredicted,
diff --git a/server/routes.go b/server/routes.go
index 059936249..d91d32535 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -312,7 +312,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 				CreatedAt:  time.Now().UTC(),
 				Response:   cr.Content,
 				Done:       cr.Done,
-				DoneReason: cr.DoneReason,
+				DoneReason: cr.DoneReason.String(),
 				Metrics: api.Metrics{
 					PromptEvalCount:    cr.PromptEvalCount,
 					PromptEvalDuration: cr.PromptEvalDuration,
@@ -1536,7 +1536,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				CreatedAt:  time.Now().UTC(),
 				Message:    api.Message{Role: "assistant", Content: r.Content},
 				Done:       r.Done,
-				DoneReason: r.DoneReason,
+				DoneReason: r.DoneReason.String(),
 				Metrics: api.Metrics{
 					PromptEvalCount:    r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,