Compare commits

...

1 Commits

Author SHA1 Message Date
Jiang Bohan
09397f70b2 fix(agent): resume codex thread across tasks on the same issue
Every other backend (Claude, Gemini, OpenCode, OpenClaw, Hermes) honors
ExecOptions.ResumeSessionID — only Codex didn't. That's why users on
the Codex runtime saw each new comment on an issue start a fresh Codex
conversation: the daemon persists Result.SessionID per (agent, issue)
and passes it back as PriorSessionID, but codex.go always called
thread/start and never populated SessionID, so the value round-tripped
as empty.

Wire the missing half:

- Extract startOrResumeThread on codexClient. When ResumeSessionID is
  set, call thread/resume (per the Codex app-server protocol), passing
  only cwd / model / developerInstructions overrides so the thread
  keeps its persisted model and reasoning effort. If resume fails
  (unknown thread, schema drift, transport error) fall back to
  thread/start so the task still runs on a fresh thread.
- Surface the live threadID as Result.SessionID on the final emit so
  the daemon stores it and feeds it back into ResumeSessionID on the
  next claim.

Tests drive the new helper through the fake stdin harness, covering:
fresh start, successful resume, fallback on resume error, fallback
when resume returns no thread ID, and surfacing of thread/start
failures.
2026-04-16 17:54:09 +08:00
2 changed files with 301 additions and 26 deletions

View File

@@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"log/slog"
"os"
"os/exec"
"path/filepath"
@@ -150,38 +151,22 @@ func (b *codexBackend) Execute(ctx context.Context, prompt string, opts ExecOpti
}
c.notify("initialized")
// 2. Start thread
threadResult, err := c.request(runCtx, "thread/start", map[string]any{
"model": nilIfEmpty(opts.Model),
"modelProvider": nil,
"profile": nil,
"cwd": opts.Cwd,
"approvalPolicy": nil,
"sandbox": nil,
"config": nil,
"baseInstructions": nil,
"developerInstructions": nilIfEmpty(opts.SystemPrompt),
"compactPrompt": nil,
"includeApplyPatchTool": nil,
"experimentalRawEvents": false,
"persistExtendedHistory": true,
})
// 2. Start a new thread, or resume the prior one for this issue. When
// resume fails (thread GCed on the server, schema drift, etc.) we fall
// back to a fresh thread so the task still makes progress.
threadID, resumed, err := c.startOrResumeThread(runCtx, opts, b.cfg.Logger)
if err != nil {
finalStatus = "failed"
finalError = fmt.Sprintf("codex thread/start failed: %v", err)
resCh <- Result{Status: finalStatus, Error: finalError, DurationMs: time.Since(startTime).Milliseconds()}
return
}
threadID := extractThreadID(threadResult)
if threadID == "" {
finalStatus = "failed"
finalError = "codex thread/start returned no thread ID"
finalError = err.Error()
resCh <- Result{Status: finalStatus, Error: finalError, DurationMs: time.Since(startTime).Milliseconds()}
return
}
c.threadID = threadID
b.cfg.Logger.Info("codex thread started", "thread_id", threadID)
if resumed {
b.cfg.Logger.Info("codex thread resumed", "thread_id", threadID)
} else {
b.cfg.Logger.Info("codex thread started", "thread_id", threadID)
}
// 3. Send turn and wait for completion
_, err = c.request(runCtx, "turn/start", map[string]any{
@@ -266,6 +251,7 @@ func (b *codexBackend) Execute(ctx context.Context, prompt string, opts ExecOpti
Status: finalStatus,
Output: finalOutput,
Error: finalError,
SessionID: threadID,
DurationMs: duration.Milliseconds(),
Usage: usageMap,
}
@@ -274,6 +260,58 @@ func (b *codexBackend) Execute(ctx context.Context, prompt string, opts ExecOpti
return &Session{Messages: msgCh, Result: resCh}, nil
}
// startOrResumeThread returns the ID of the Codex thread that subsequent
// turn/start requests should reference.
//
// When opts.ResumeSessionID is non-empty, thread/resume is attempted first,
// sending only the prior thread ID plus the cwd, model and
// developerInstructions overrides. A resume request that errors, or whose
// result carries no thread ID, is logged as a warning and the method continues
// with thread/start so the task still runs on a fresh thread.
//
// The bool result is true only when the returned ID came from a successful
// thread/resume (callers use it for logging). A non-nil error is returned only
// when thread/start itself fails or yields no thread ID.
func (c *codexClient) startOrResumeThread(ctx context.Context, opts ExecOptions, logger *slog.Logger) (string, bool, error) {
	model := nilIfEmpty(opts.Model)
	instructions := nilIfEmpty(opts.SystemPrompt)

	if prior := opts.ResumeSessionID; prior != "" {
		// Keep the resume payload minimal: only the fields the daemon
		// wants to override are sent alongside the thread ID.
		resp, resumeErr := c.request(ctx, "thread/resume", map[string]any{
			"threadId":              prior,
			"cwd":                   opts.Cwd,
			"model":                 model,
			"developerInstructions": instructions,
		})
		switch {
		case resumeErr != nil:
			logger.Warn("codex thread/resume failed; falling back to thread/start", "prior_thread_id", prior, "error", resumeErr)
		default:
			if id := extractThreadID(resp); id != "" {
				return id, true, nil
			}
			logger.Warn("codex thread/resume returned no thread ID; falling back to thread/start", "prior_thread_id", prior)
		}
	}

	// Either no prior thread was requested or resuming it did not work out:
	// open a brand-new thread.
	resp, err := c.request(ctx, "thread/start", map[string]any{
		"model":                  model,
		"modelProvider":          nil,
		"profile":                nil,
		"cwd":                    opts.Cwd,
		"approvalPolicy":         nil,
		"sandbox":                nil,
		"config":                 nil,
		"baseInstructions":       nil,
		"developerInstructions":  instructions,
		"compactPrompt":          nil,
		"includeApplyPatchTool":  nil,
		"experimentalRawEvents":  false,
		"persistExtendedHistory": true,
	})
	if err != nil {
		return "", false, fmt.Errorf("codex thread/start failed: %w", err)
	}
	if id := extractThreadID(resp); id != "" {
		return id, false, nil
	}
	return "", false, fmt.Errorf("codex thread/start returned no thread ID")
}
// ── codexClient: JSON-RPC 2.0 transport ──
type codexClient struct {

View File

@@ -1,11 +1,14 @@
package agent
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strings"
"sync"
"testing"
"time"
)
func newTestCodexClient(t *testing.T) (*codexClient, *fakeStdin, []Message) {
@@ -592,6 +595,240 @@ func TestNilIfEmpty(t *testing.T) {
}
}
// rpcResponse is one scripted step consumed by drainRPCScript: it names the
// request method the client is expected to send next and describes the
// JSON-RPC reply (success result or error object) to inject for that request.
type rpcResponse struct {
method string // expected request method
result json.RawMessage // success result body (mutually exclusive with errMsg)
errMsg string // non-empty → respond with JSON-RPC error object
errCode int // JSON-RPC error code when errMsg is set
// assertFn, when non-nil, is invoked with the request's decoded params
// before the scripted reply is injected.
assertFn func(t *testing.T, params map[string]any)
}
// drainRPCScript starts a goroutine that polls fs.Lines() for outbound
// requests and answers the i-th request with script[i], injecting the JSON-RPC
// response through c.handleLine. An undecodable request, a method that does
// not match the script, or the script not being exhausted within 2 seconds is
// reported with t.Errorf and stops the goroutine.
//
// When a script step has an assertFn, the request params are decoded into a
// map and passed to it. A params decode failure is reported with t.Errorf
// (the step's reply is still injected so the caller is not left waiting).
//
// The returned function blocks until the goroutine has exited and fails the
// test if that does not happen within 3 seconds.
func drainRPCScript(t *testing.T, c *codexClient, fs *fakeStdin, script []rpcResponse) func() {
	t.Helper()
	done := make(chan struct{})
	go func() {
		defer close(done)
		seen := 0
		deadline := time.Now().Add(2 * time.Second)
		for seen < len(script) {
			lines := fs.Lines()
			for seen < len(lines) && seen < len(script) {
				var req struct {
					ID     int             `json:"id"`
					Method string          `json:"method"`
					Params json.RawMessage `json:"params"`
				}
				if err := json.Unmarshal([]byte(lines[seen]), &req); err != nil {
					t.Errorf("drainRPCScript: unmarshal request %d: %v", seen, err)
					return
				}
				expected := script[seen]
				if req.Method != expected.method {
					t.Errorf("drainRPCScript: call %d method = %q, want %q", seen, req.Method, expected.method)
					return
				}
				if expected.assertFn != nil {
					var params map[string]any
					// Requests without a params member leave req.Params
					// empty; only decode when there is something to decode,
					// and surface malformed params instead of hiding them.
					if len(req.Params) > 0 {
						if err := json.Unmarshal(req.Params, &params); err != nil {
							t.Errorf("drainRPCScript: unmarshal params of call %d: %v", seen, err)
						}
					}
					expected.assertFn(t, params)
				}
				var resp string
				if expected.errMsg != "" {
					resp = fmt.Sprintf(`{"jsonrpc":"2.0","id":%d,"error":{"code":%d,"message":%q}}`, req.ID, expected.errCode, expected.errMsg)
				} else {
					resp = fmt.Sprintf(`{"jsonrpc":"2.0","id":%d,"result":%s}`, req.ID, string(expected.result))
				}
				c.handleLine(resp)
				seen++
			}
			if seen < len(script) {
				if time.Now().After(deadline) {
					t.Errorf("drainRPCScript: timed out after %d/%d responses", seen, len(script))
					return
				}
				// Nothing new yet; back off briefly before polling again.
				time.Sleep(5 * time.Millisecond)
			}
		}
	}()
	return func() {
		select {
		case <-done:
		case <-time.After(3 * time.Second):
			t.Fatal("drainRPCScript did not finish")
		}
	}
}
// TestCodexStartOrResumeThreadStartsFresh checks that, with no
// ResumeSessionID set, the helper issues a single thread/start carrying the
// requested cwd and persistExtendedHistory=true, and reports resumed=false.
func TestCodexStartOrResumeThreadStartsFresh(t *testing.T) {
	t.Parallel()

	client, stdin, _ := newTestCodexClient(t)

	checkStartParams := func(t *testing.T, params map[string]any) {
		if got := params["cwd"]; got != "/work" {
			t.Errorf("cwd = %v, want /work", got)
		}
		if params["persistExtendedHistory"] != true {
			t.Error("expected persistExtendedHistory=true on thread/start")
		}
	}
	script := []rpcResponse{
		{
			method:   "thread/start",
			result:   json.RawMessage(`{"thread":{"id":"thr_fresh"}}`),
			assertFn: checkStartParams,
		},
	}
	finish := drainRPCScript(t, client, stdin, script)
	defer finish()

	opts := ExecOptions{Cwd: "/work"}
	gotID, gotResumed, err := client.startOrResumeThread(context.Background(), opts, slog.Default())
	if err != nil {
		t.Fatalf("startOrResumeThread: %v", err)
	}
	if gotID != "thr_fresh" {
		t.Errorf("threadID = %q, want thr_fresh", gotID)
	}
	if gotResumed {
		t.Error("resumed should be false when no prior session is provided")
	}
}
// TestCodexStartOrResumeThreadResumesPriorThread checks that a non-empty
// ResumeSessionID produces a thread/resume request carrying that thread ID and
// the cwd, and that a successful resume is reported with resumed=true.
func TestCodexStartOrResumeThreadResumesPriorThread(t *testing.T) {
	t.Parallel()

	client, stdin, _ := newTestCodexClient(t)

	checkResumeParams := func(t *testing.T, params map[string]any) {
		if got := params["threadId"]; got != "thr_prior" {
			t.Errorf("threadId = %v, want thr_prior", got)
		}
		if got := params["cwd"]; got != "/work" {
			t.Errorf("cwd = %v, want /work", got)
		}
	}
	script := []rpcResponse{
		{
			method:   "thread/resume",
			result:   json.RawMessage(`{"thread":{"id":"thr_prior"}}`),
			assertFn: checkResumeParams,
		},
	}
	finish := drainRPCScript(t, client, stdin, script)
	defer finish()

	opts := ExecOptions{Cwd: "/work", ResumeSessionID: "thr_prior"}
	gotID, gotResumed, err := client.startOrResumeThread(context.Background(), opts, slog.Default())
	if err != nil {
		t.Fatalf("startOrResumeThread: %v", err)
	}
	if gotID != "thr_prior" {
		t.Errorf("threadID = %q, want thr_prior", gotID)
	}
	if !gotResumed {
		t.Error("expected resumed=true when thread/resume succeeded")
	}
}
// TestCodexStartOrResumeThreadFallsBackOnResumeError checks that a JSON-RPC
// error reply to thread/resume is followed by a thread/start request, and that
// the fresh thread's ID is returned with resumed=false.
func TestCodexStartOrResumeThreadFallsBackOnResumeError(t *testing.T) {
	t.Parallel()

	client, stdin, _ := newTestCodexClient(t)

	script := []rpcResponse{
		{method: "thread/resume", errMsg: "unknown thread", errCode: -32602},
		{method: "thread/start", result: json.RawMessage(`{"thread":{"id":"thr_new"}}`)},
	}
	finish := drainRPCScript(t, client, stdin, script)
	defer finish()

	opts := ExecOptions{Cwd: "/work", ResumeSessionID: "thr_stale"}
	gotID, gotResumed, err := client.startOrResumeThread(context.Background(), opts, slog.Default())
	if err != nil {
		t.Fatalf("startOrResumeThread: %v", err)
	}
	if gotID != "thr_new" {
		t.Errorf("threadID = %q, want thr_new (fresh thread after fallback)", gotID)
	}
	if gotResumed {
		t.Error("expected resumed=false after falling back to thread/start")
	}
}
// TestCodexStartOrResumeThreadFallsBackWhenResumeReturnsNoID checks that a
// successful thread/resume reply lacking a thread ID is treated like a failed
// resume: thread/start is sent next and its ID is returned with resumed=false.
func TestCodexStartOrResumeThreadFallsBackWhenResumeReturnsNoID(t *testing.T) {
	t.Parallel()

	client, stdin, _ := newTestCodexClient(t)

	script := []rpcResponse{
		{method: "thread/resume", result: json.RawMessage(`{"thread":{}}`)},
		{method: "thread/start", result: json.RawMessage(`{"thread":{"id":"thr_new"}}`)},
	}
	finish := drainRPCScript(t, client, stdin, script)
	defer finish()

	opts := ExecOptions{ResumeSessionID: "thr_prior"}
	gotID, gotResumed, err := client.startOrResumeThread(context.Background(), opts, slog.Default())
	if err != nil {
		t.Fatalf("startOrResumeThread: %v", err)
	}
	if gotID != "thr_new" {
		t.Errorf("threadID = %q, want thr_new", gotID)
	}
	if gotResumed {
		t.Error("expected resumed=false when resume yielded no thread ID")
	}
}
// TestCodexStartOrResumeThreadStartFailureSurfaces checks that a JSON-RPC
// error reply to thread/start is returned to the caller as an error whose
// message mentions "thread/start".
func TestCodexStartOrResumeThreadStartFailureSurfaces(t *testing.T) {
	t.Parallel()

	client, stdin, _ := newTestCodexClient(t)

	script := []rpcResponse{
		{method: "thread/start", errMsg: "boom", errCode: -32000},
	}
	finish := drainRPCScript(t, client, stdin, script)
	defer finish()

	_, _, err := client.startOrResumeThread(context.Background(), ExecOptions{}, slog.Default())
	if err == nil {
		t.Fatal("expected error when thread/start fails")
	}
	if msg := err.Error(); !strings.Contains(msg, "thread/start") {
		t.Errorf("error should mention thread/start, got %v", err)
	}
}
func TestCodexProtocolDetectionLegacyBlocksRaw(t *testing.T) {
t.Parallel()