routes: structured outputs for gpt-oss (#12460)

Parth Sareen
2025-10-08 19:13:38 -07:00
committed by GitHub
parent 1b91d4dda1
commit 77060d462c
2 changed files with 377 additions and 64 deletions
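
This change makes `format` (structured outputs) usable together with `think` on thinking models such as gpt-oss: the server runs a first, unconstrained pass so the model can think, then restarts the completion with the grammar constraint applied to the final message. A minimal sketch of exercising this from a client, assuming the Go API client in github.com/ollama/ollama/api (the model name is illustrative):

	package main

	import (
		"context"
		"encoding/json"
		"fmt"
		"log"

		"github.com/ollama/ollama/api"
	)

	func main() {
		client, err := api.ClientFromEnvironment()
		if err != nil {
			log.Fatal(err)
		}

		think := true
		stream := false
		req := &api.ChatRequest{
			Model:    "gpt-oss", // illustrative model name
			Messages: []api.Message{{Role: "user", Content: "Please respond in JSON."}},
			Think:    &api.ThinkValue{Value: think},
			Stream:   &stream,
			// JSON schema the final message must conform to
			Format: json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`),
		}

		err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
			// Thinking arrives unconstrained; Content is constrained to the schema.
			fmt.Println("thinking:", resp.Message.Thinking)
			fmt.Println("content:", resp.Message.Content)
			return nil
		})
		if err != nil {
			log.Fatal(err)
		}
	}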


@@ -1979,88 +1979,167 @@ func (s *Server) ChatHandler(c *gin.Context) {
		toolParser = tools.NewParser(m.Template.Template, req.Tools)
	}

	type structuredOutputsState int
	const (
		structuredOutputsState_None structuredOutputsState = iota
		structuredOutputsState_ReadyToApply
		structuredOutputsState_Applying
	)

	ch := make(chan any)
	go func() {
		defer close(ch)

		structuredOutputsState := structuredOutputsState_None

		for {
			var tb strings.Builder

			currentFormat := req.Format
			// structured outputs via double request is enabled when:
			// 1. the model supports the thinking capability and
			// 2. it uses a built-in parser or our generic thinking parser

			// Note that the current approach does not work for (potential future)
			// non-thinking models that emit anything before actual content. This
			// current approach uses the transition from parsed thinking content to
			// parsed non-thinking content as the signal to turn constraining on
			if req.Format != nil && structuredOutputsState == structuredOutputsState_None && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
				currentFormat = nil
			}

			// sets up new context given parent context per request
			ctx, cancel := context.WithCancel(c.Request.Context())
			err := r.Completion(ctx, llm.CompletionRequest{
				Prompt:  prompt,
				Images:  images,
				Format:  currentFormat,
				Options: opts,
			}, func(r llm.CompletionResponse) {
				res := api.ChatResponse{
					Model:     req.Model,
					CreatedAt: time.Now().UTC(),
					Message:   api.Message{Role: "assistant", Content: r.Content},
					Done:      r.Done,
					Metrics: api.Metrics{
						PromptEvalCount:    r.PromptEvalCount,
						PromptEvalDuration: r.PromptEvalDuration,
						EvalCount:          r.EvalCount,
						EvalDuration:       r.EvalDuration,
					},
				}

				if r.Done {
					res.DoneReason = r.DoneReason.String()
					res.TotalDuration = time.Since(checkpointStart)
					res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
				}

				if builtinParser != nil {
					slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)

					content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
					if err != nil {
						ch <- gin.H{"error": err.Error()}
						return
					}

					res.Message.Content = content
					res.Message.Thinking = thinking
					res.Message.ToolCalls = toolCalls

					tb.WriteString(thinking)
					// we are now receiving content from the model - we should start applying structured outputs
					if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && res.Message.Content != "" {
						structuredOutputsState = structuredOutputsState_ReadyToApply
						cancel()
						return
					}

					if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
						slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
						ch <- res
					} else {
						slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
					}

					return
				}

				if thinkingState != nil {
					thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
					if thinkingContent == "" && remainingContent == "" && !r.Done {
						// need to accumulate more to decide what to send
						return
					}

					res.Message.Thinking = thinkingContent
					tb.WriteString(thinkingContent)

					// emit the collected thinking text before restarting with structured outputs and clear unstructured content
					// to avoid leaking mixed tokens like "</think>Hello"
					if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && remainingContent != "" {
						structuredOutputsState = structuredOutputsState_ReadyToApply
						res.Message.Content = ""
						ch <- res
						cancel()
						return
					}

					res.Message.Content = remainingContent
				}

				if len(req.Tools) > 0 {
					toolCalls, content := toolParser.Add(res.Message.Content)
					if len(content) > 0 {
						res.Message.Content = content
					} else if len(toolCalls) > 0 {
						res.Message.ToolCalls = toolCalls
						res.Message.Content = ""
					} else if res.Message.Thinking != "" {
						// don't return
					} else {
						if r.Done {
							res.Message.Content = toolParser.Content()
							ch <- res
						}
						return
					}
				}

				ch <- res
			})
			if err != nil {
				if structuredOutputsState == structuredOutputsState_ReadyToApply && strings.Contains(err.Error(), "context canceled") && c.Request.Context().Err() == nil {
					// only ignore the error if it's a context cancellation caused by switching to structured outputs
				} else {
					ch <- gin.H{"error": err.Error()}
					return
				}
			}

			// an ignored structured outputs cancellation falls through to here: start a new request with the structured outputs and updated prompt
			if structuredOutputsState == structuredOutputsState_ReadyToApply {
				structuredOutputsState = structuredOutputsState_Applying
				msg := api.Message{
					Role:     "assistant",
					Thinking: tb.String(),
				}
				msgs = append(msgs, msg)
				prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
				if err != nil {
					slog.Error("chat prompt error applying structured outputs", "error", err)
					ch <- gin.H{"error": err.Error()}
					return
				}
				// force constraining by terminating the thinking header; the parser is already at this state.
				// when the last message is thinking, the renderer for gpt-oss cannot disambiguate between having the
				// model continue thinking or ending thinking and outputting the final message.
				// TODO(parthsareen): consider adding prefill disambiguation logic to the renderer for structured outputs.
				if shouldUseHarmony(m) || (builtinParser != nil && m.Config.Parser == "harmony") {
					prompt += "<|end|><|start|>assistant<|channel|>final<|message|>"
				}
				continue
			}
			break
		}
	}()
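
Distilled, the handler is now a two-pass loop: pass one strips the format and buffers parsed thinking, the first parsed non-thinking content cancels that pass, the prompt is re-rendered with the buffered thinking appended as an assistant message, and pass two reruns with the format applied. A sketch of just that control flow, with stand-in `render` and `run` hooks (hypothetical names, not the server's actual plumbing):

	package sketch

	import (
		"context"
		"encoding/json"
		"strings"
	)

	type soState int

	const (
		stateNone soState = iota
		stateReadyToApply
		stateApplying
	)

	// completer runs one model pass, calling emit per chunk; it must honor ctx cancellation.
	type completer func(ctx context.Context, prompt string, format json.RawMessage, emit func(thinking, content string)) error

	// twoPassStructured mirrors the loop above: pass one unconstrained, then on the
	// thinking -> content transition cancel and rerun pass two with the format applied.
	func twoPassStructured(parent context.Context, render func(thinking string) string, run completer, format json.RawMessage) error {
		var tb strings.Builder
		state := stateNone
		prompt := render("")

		for {
			currentFormat := format
			if state == stateNone {
				currentFormat = nil // first pass: let the model think unconstrained
			}

			ctx, cancel := context.WithCancel(parent)
			err := run(ctx, prompt, currentFormat, func(thinking, content string) {
				tb.WriteString(thinking)
				// first parsed non-thinking content: switch constraining on
				if state == stateNone && format != nil && tb.Len() > 0 && content != "" {
					state = stateReadyToApply
					cancel()
				}
			})
			cancel()

			// swallow only our own cancellation, not the client's
			if err != nil && !(state == stateReadyToApply && parent.Err() == nil && strings.Contains(err.Error(), "context canceled")) {
				return err
			}

			if state == stateReadyToApply {
				state = stateApplying
				// re-render with the buffered thinking appended as an assistant message
				prompt = render(tb.String())
				continue
			}
			return nil
		}
	}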


@@ -1191,4 +1191,238 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
t.Errorf("expected content %q, got %q", "Based on my analysis, the solution is straightforward.", got) t.Errorf("expected content %q, got %q", "Based on my analysis, the solution is straightforward.", got)
} }
}) })
t.Run("structured outputs restart non-stream", func(t *testing.T) {
var (
requestsMu sync.Mutex
requests []llm.CompletionRequest
wg sync.WaitGroup
)
wg.Add(2)
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
defer wg.Done()
requestsMu.Lock()
requests = append(requests, r)
callNum := len(requests)
requestsMu.Unlock()
switch callNum {
case 1:
fn(llm.CompletionResponse{
Content: " I am thinking through this problem. </think> {\"answer\":\"42\"}",
Done: false,
PromptEvalCount: 1,
PromptEvalDuration: 1,
})
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(time.Second):
t.Fatalf("timeout waiting for structured outputs cancellation")
return nil
}
case 2:
fn(llm.CompletionResponse{
Content: `{"answer":"42"}`,
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
})
return nil
default:
t.Fatalf("unexpected number of completion calls: %d", callNum)
return nil
}
}
think := true
streamRequest := false
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-thinking",
Messages: []api.Message{{Role: "user", Content: "Please respond in JSON."}},
Think: &api.ThinkValue{Value: think},
Stream: &streamRequest,
Format: format,
})
wg.Wait()
mock.CompletionFn = nil
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
if len(requests) != 2 {
t.Fatalf("expected two completion calls, got %d", len(requests))
}
if requests[0].Format != nil {
t.Errorf("expected first completion format to be nil, got %q", requests[0].Format)
}
if !bytes.Equal([]byte(format), []byte(requests[1].Format)) {
t.Errorf("expected second completion format to match original format")
}
var resp api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.Message.Thinking != "I am thinking through this problem. " {
t.Errorf("expected thinking %q, got %q", "I am thinking through this problem. ", resp.Message.Thinking)
}
if resp.Message.Content != `{"answer":"42"}` {
t.Errorf("expected content %q, got %q", `{"answer":"42"}`, resp.Message.Content)
}
if !resp.Done {
t.Errorf("expected response to be done")
}
if resp.DoneReason != "stop" {
t.Errorf("expected done reason stop, got %s", resp.DoneReason)
}
})
t.Run("structured outputs restart streaming", func(t *testing.T) {
var (
requestsMu sync.Mutex
requests []llm.CompletionRequest
wg sync.WaitGroup
)
wg.Add(2)
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
defer wg.Done()
requestsMu.Lock()
requests = append(requests, r)
callNum := len(requests)
requestsMu.Unlock()
switch callNum {
case 1:
fn(llm.CompletionResponse{
Content: " I am thinking through this problem. </think> {\"answer\":\"42\"}",
Done: false,
PromptEvalCount: 1,
PromptEvalDuration: 1,
})
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(time.Second):
t.Fatalf("timeout waiting for structured outputs cancellation")
return nil
}
case 2:
fn(llm.CompletionResponse{
Content: `{"answer":"42"}`,
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
})
return nil
default:
t.Fatalf("unexpected number of completion calls: %d", callNum)
return nil
}
}
think := true
streamRequest := true
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-thinking",
Messages: []api.Message{{Role: "user", Content: "Please respond in JSON."}},
Think: &api.ThinkValue{Value: think},
Stream: &streamRequest,
Format: format,
})
wg.Wait()
mock.CompletionFn = nil
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
if len(requests) != 2 {
t.Fatalf("expected two completion calls, got %d", len(requests))
}
if requests[0].Format != nil {
t.Errorf("expected first completion format to be nil, got %q", requests[0].Format)
}
if !bytes.Equal([]byte(format), []byte(requests[1].Format)) {
t.Errorf("expected second completion format to match original format")
}
decoder := json.NewDecoder(w.Body)
var events []api.ChatResponse
for {
var event api.ChatResponse
if err := decoder.Decode(&event); err == io.EOF {
break
} else if err != nil {
t.Fatal(err)
}
events = append(events, event)
if event.Done {
break
}
}
if len(events) < 2 {
t.Fatalf("expected at least two streaming events, got %d", len(events))
}
first := events[0]
if first.Message.Thinking != "I am thinking through this problem. " {
t.Errorf("expected first event thinking %q, got %q", "I am thinking through this problem. ", first.Message.Thinking)
}
if first.Message.Content != "" {
t.Errorf("expected first event content to be empty, got %q", first.Message.Content)
}
if first.Done {
t.Error("expected first event to be non-terminal")
}
last := events[len(events)-1]
if last.Message.Thinking != "" {
t.Errorf("expected final event thinking to be empty, got %q", last.Message.Thinking)
}
if last.Message.Content != `{"answer":"42"}` {
t.Errorf("expected final event content %q, got %q", `{"answer":"42"}`, last.Message.Content)
}
if !last.Done {
t.Error("expected final event to be done")
}
if last.DoneReason != "stop" {
t.Errorf("expected final done reason stop, got %s", last.DoneReason)
}
})
} }
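
For harmony-format models, the appended prefill means the second request's prompt closes the open analysis (thinking) message and opens the final channel, so the constrained pass starts directly inside the final message. A hypothetical rendered prompt tail for the scenario in these tests (harmony token layout assumed):

	...<|start|>assistant<|channel|>analysis<|message|> I am thinking through this problem. <|end|><|start|>assistant<|channel|>final<|message|>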