use done reason enum

This commit is contained in:
Bruce MacDonald 2025-03-19 10:41:51 -07:00
parent 22f2f6e229
commit 771c88b3ad
4 changed files with 30 additions and 13 deletions

View File

@ -675,9 +675,34 @@ type CompletionRequest struct {
Grammar string // set before sending the request to the subprocess
}
// DoneReason represents the reason why a completion response is done
type DoneReason string
const (
// DoneReasonStop indicates the completion stopped naturally
DoneReasonStop DoneReason = "stop"
// DoneReasonLength indicates the completion stopped due to length limits
DoneReasonLength DoneReason = "length"
)
func (d DoneReason) String() string {
return string(d)
}
// ParseDoneReason converts a string to a DoneReason type
// If the string doesn't match any known reason, it defaults to DoneReasonStop
func ParseDoneReason(reason string) DoneReason {
switch reason {
case "limit", "length":
return DoneReasonLength
default:
return DoneReasonStop
}
}
type CompletionResponse struct {
Content string `json:"content"`
DoneReason string `json:"done_reason"`
DoneReason DoneReason `json:"done_reason"`
Done bool `json:"done"`
PromptEvalCount int `json:"prompt_eval_count"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
@ -786,7 +811,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
continue
}
// slog.Debug("got line", "line", string(line))
evt, ok := bytes.CutPrefix(line, []byte("data: "))
if !ok {
evt = line
@ -796,13 +820,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if err := json.Unmarshal(evt, &c); err != nil {
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
}
// convert internal done reason to one of our standard api format done reasons
switch c.DoneReason {
case "limit":
c.DoneReason = "length"
default:
c.DoneReason = "stop"
}
switch {
case strings.TrimSpace(c.Content) == lastToken:

View File

@ -649,7 +649,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
} else {
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Done: true,
DoneReason: seq.doneReason,
DoneReason: llm.ParseDoneReason(seq.doneReason),
PromptEvalCount: seq.numPromptInputs,
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
EvalCount: seq.numDecoded,

View File

@ -629,7 +629,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
} else {
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Done: true,
DoneReason: seq.doneReason,
DoneReason: llm.ParseDoneReason(seq.doneReason),
PromptEvalCount: seq.numPromptInputs,
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
EvalCount: seq.numPredicted,

View File

@ -312,7 +312,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
CreatedAt: time.Now().UTC(),
Response: cr.Content,
Done: cr.Done,
DoneReason: cr.DoneReason,
DoneReason: cr.DoneReason.String(),
Metrics: api.Metrics{
PromptEvalCount: cr.PromptEvalCount,
PromptEvalDuration: cr.PromptEvalDuration,
@ -1536,7 +1536,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done,
DoneReason: r.DoneReason,
DoneReason: r.DoneReason.String(),
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,