diff --git a/server/internal/daemon/daemon.go b/server/internal/daemon/daemon.go index e5a35cccc..974f6e4dd 100644 --- a/server/internal/daemon/daemon.go +++ b/server/internal/daemon/daemon.go @@ -91,6 +91,7 @@ type Daemon struct { reregisterNextAttempt map[string]time.Time // workspace_id -> earliest time the next re-register attempt may run cancelFunc context.CancelFunc // set by Run(); called by triggerRestart + rootCtx context.Context // set by Run(); used by long-running recoveries that must survive per-runtime ctx cancellation restartBinary string // non-empty after a successful update; path to the new binary updating atomic.Bool // prevents concurrent update attempts activeTasks atomic.Int64 // number of tasks currently in handleTask; exposed via /health @@ -170,14 +171,22 @@ const reregisterFailureBackoff = 60 * time.Second // // - keys an in-flight set on runtimeID to drop concurrent calls for the same // ID after the first one is already cleaning up; and -// - keys a per-workspace next-attempt timestamp on workspaceID so the second -// stale runtime in the same workspace skips registerRuntimesForWorkspace -// when the first one's re-register is still inside the coalesce window. +// - keys a per-workspace next-attempt timestamp on workspaceID so that +// concurrent recoveries triggered by the SAME initial event coalesce to a +// single registerRuntimesForWorkspace call. The slot is cleared on success +// so a later distinct runtime deletion in the same workspace can trigger +// its own recovery without waiting for the coalesce window to expire. // // On failure of the underlying re-register, the next-attempt timestamp is // extended by reregisterFailureBackoff so we don't replace a server-side log -// flood with a daemon-side register flood. -func (d *Daemon) handleRuntimeGone(ctx context.Context, runtimeID string) { +// flood with a daemon-side register flood. 
workspaceSyncLoop will retry +// independently every DefaultWorkspaceSyncInterval as a safety net. +// +// The recovery HTTP call uses the daemon root context, not the caller's. The +// heartbeat path's per-runtime ctx is cancelled by notifyRuntimeSetChanged the +// moment we prune the dead UUID, and if we forwarded that ctx the in-flight +// register would self-cancel mid-flight. +func (d *Daemon) handleRuntimeGone(runtimeID string) { if runtimeID == "" { return } @@ -208,8 +217,11 @@ func (d *Daemon) handleRuntimeGone(ctx context.Context, runtimeID string) { d.notifyRuntimeSetChanged() // Per-workspace coalescing: claim the slot atomically. The first caller - // past this check is the only one that will run registerRuntimesForWorkspace - // for this workspace inside the coalesce window. + // past this check is the only one that will run + // registerRuntimesForWorkspace while the coalesce window is open. We + // clear the slot on success so a separate later deletion in the same + // workspace is NOT suppressed; the inflight set above is what keeps two + // callers from racing the same recovery. now := time.Now() d.runtimeGoneMu.Lock() if next, ok := d.reregisterNextAttempt[workspaceID]; ok && now.Before(next) { @@ -221,7 +233,7 @@ func (d *Daemon) handleRuntimeGone(ctx context.Context, runtimeID string) { d.reregisterNextAttempt[workspaceID] = now.Add(reregisterCoalesceWindow) d.runtimeGoneMu.Unlock() - if err := d.reregisterWorkspaceAfterRuntimeGone(ctx, workspaceID); err != nil { + if err := d.reregisterWorkspaceAfterRuntimeGone(d.recoveryContext(), workspaceID); err != nil { d.runtimeGoneMu.Lock() d.reregisterNextAttempt[workspaceID] = time.Now().Add(reregisterFailureBackoff) d.runtimeGoneMu.Unlock() @@ -230,7 +242,25 @@ func (d *Daemon) handleRuntimeGone(ctx context.Context, runtimeID string) { // failure here is not a stuck state — just an extra wait. 
d.logger.Warn("re-register after runtime gone failed", "workspace_id", workspaceID, "error", err) + return } + // Success: clear the coalesce slot so a future distinct runtime deletion + // in this workspace can trigger its own recovery immediately. The + // inflight set on runtimeID still prevents same-event stampedes. + d.runtimeGoneMu.Lock() + delete(d.reregisterNextAttempt, workspaceID) + d.runtimeGoneMu.Unlock() +} + +// recoveryContext returns the daemon root context for long-running recovery +// HTTP calls (re-register, recover-orphans) that must survive the heartbeat +// loop tearing down a per-runtime context. Falls back to Background when the +// daemon was not started via Run(), e.g. unit-test fixtures. +func (d *Daemon) recoveryContext() context.Context { + if d.rootCtx != nil { + return d.rootCtx + } + return context.Background() } // removeStaleRuntime drops a runtime ID from its owning workspace's runtimeIDs @@ -290,27 +320,51 @@ func (d *Daemon) workspaceNeedsRuntimeRecovery(workspaceID string) bool { } // reregisterWorkspaceAfterRuntimeGone calls registerRuntimesForWorkspace and -// merges the resulting runtime IDs into the existing workspaceState. The -// workspaceState pointer is NEVER replaced (see syncWorkspacesFromAPI's -// invariant about repoRefreshMu). +// updates the existing workspaceState in place. The register response is +// authoritative for this workspace's runtime set — every configured provider +// is included, with UpsertAgentRuntime returning the same row ID for surviving +// providers and a fresh ID for any that were deleted server-side. Replacing +// (rather than appending) is required: a partial recovery, where only one +// runtime in a multi-provider workspace was deleted, would otherwise produce +// duplicates for every provider that wasn't deleted. +// +// The workspaceState pointer is NEVER replaced (see syncWorkspacesFromAPI's +// invariant about repoRefreshMu). Only fields are mutated. 
func (d *Daemon) reregisterWorkspaceAfterRuntimeGone(ctx context.Context, workspaceID string) error { resp, err := d.registerRuntimesForWorkspace(ctx, workspaceID) if err != nil { return fmt.Errorf("register runtimes: %w", err) } - runtimeIDs := make([]string, 0, len(resp.Runtimes)) + newIDs := make([]string, 0, len(resp.Runtimes)) + newIDSet := make(map[string]struct{}, len(resp.Runtimes)) + for _, rt := range resp.Runtimes { + newIDs = append(newIDs, rt.ID) + newIDSet[rt.ID] = struct{}{} + } + d.mu.Lock() ws, ok := d.workspaces[workspaceID] if !ok { d.mu.Unlock() return fmt.Errorf("workspace %s no longer tracked", workspaceID) } + // Drop runtimeIndex entries for prior runtime IDs that the server did not + // return — typically there are none for upsert-on-existing-provider, but + // a daemon config change (provider removed) would leak entries otherwise. + for _, oldID := range ws.runtimeIDs { + if _, kept := newIDSet[oldID]; !kept { + delete(d.runtimeIndex, oldID) + } + } for _, rt := range resp.Runtimes { d.runtimeIndex[rt.ID] = rt - runtimeIDs = append(runtimeIDs, rt.ID) } - ws.runtimeIDs = append(ws.runtimeIDs, runtimeIDs...) + // Response is authoritative — replace, do not append. Replacing also + // catches the rare case where UpsertAgentRuntime returns a different ID + // for a surviving provider (e.g. schema change); the daemon converges on + // what the server says without leaving stale heartbeat goroutines. 
+ ws.runtimeIDs = newIDs if resp.ReposVersion != "" { ws.reposVersion = resp.ReposVersion ws.allowedRepoURLs = repoAllowlist(resp.Repos) @@ -320,7 +374,7 @@ func (d *Daemon) reregisterWorkspaceAfterRuntimeGone(ctx context.Context, worksp } d.mu.Unlock() - for _, rid := range runtimeIDs { + for _, rid := range newIDs { d.logger.Info("re-registered runtime after server-side deletion", "workspace_id", workspaceID, "runtime_id", rid) } @@ -328,7 +382,7 @@ func (d *Daemon) reregisterWorkspaceAfterRuntimeGone(ctx context.Context, worksp // Tell the server about any tasks the previous (now-deleted) runtime // was working on, mirroring the registration path's recover-orphans call. - for _, rid := range runtimeIDs { + for _, rid := range newIDs { if err := d.client.RecoverOrphans(ctx, rid); err != nil { d.logger.Warn("recover-orphans after re-register failed", "runtime_id", rid, "error", err) @@ -431,6 +485,7 @@ func (d *Daemon) Run(ctx context.Context) error { // Wrap context so handleUpdate can cancel the daemon for restart. ctx, cancel := context.WithCancel(ctx) d.cancelFunc = cancel + d.rootCtx = ctx // Bind health port early to detect another running daemon. healthLn, err := d.listenHealth() @@ -1072,9 +1127,10 @@ func (d *Daemon) runHeartbeatTick(ctx context.Context, rid string) { if isRuntimeNotFoundError(err) { // Server says this runtime is gone — recover instead of // looping on the dead UUID. handleRuntimeGone coalesces - // concurrent callers, so this is safe to call from every - // heartbeat tick. - go d.handleRuntimeGone(ctx, rid) + // concurrent callers and runs the recovery HTTP call under + // the daemon root context so notifyRuntimeSetChanged + // tearing down this heartbeat goroutine cannot abort it. 
+ go d.handleRuntimeGone(rid) return } d.logger.Warn("heartbeat failed", "runtime_id", rid, "error", err) @@ -1085,7 +1141,7 @@ func (d *Daemon) runHeartbeatTick(ctx context.Context, rid string) { // The WS path returns a successful ack with RuntimeGone=true for the // same scenario; treat it the same way here in case HTTP starts // surfacing this signal too. - go d.handleRuntimeGone(ctx, rid) + go d.handleRuntimeGone(rid) return } d.handleHeartbeatActions(ctx, rid, resp) @@ -1609,7 +1665,7 @@ func (d *Daemon) runRuntimePoller( // the poller; the runtime-set watcher will tear this // goroutine down via pollerCtx once the workspace is // re-registered with a new runtime ID. - go d.handleRuntimeGone(parentCtx, rid) + go d.handleRuntimeGone(rid) return } d.logger.Warn("claim task failed", "runtime_id", rid, "error", err) diff --git a/server/internal/daemon/runtime_gone_test.go b/server/internal/daemon/runtime_gone_test.go index 6f6207e5d..500f8f4b2 100644 --- a/server/internal/daemon/runtime_gone_test.go +++ b/server/internal/daemon/runtime_gone_test.go @@ -3,6 +3,7 @@ package daemon import ( "context" "encoding/json" + "fmt" "log/slog" "net/http" "net/http/httptest" @@ -176,7 +177,7 @@ func TestHandleRuntimeGone_PrunesAndReregisters(t *testing.T) { d.runtimeIndex["rt-old"] = Runtime{ID: "rt-old"} d.wsHBLastAck["rt-old"] = time.Now() - d.handleRuntimeGone(context.Background(), "rt-old") + d.handleRuntimeGone("rt-old") if got := d.runtimeIndex["rt-old"]; got.ID != "" { t.Fatalf("rt-old still present in runtimeIndex: %+v", got) @@ -215,7 +216,7 @@ func TestHandleRuntimeGone_CoalescesConcurrentCallers(t *testing.T) { wg.Add(1) go func(id string) { defer wg.Done() - d.handleRuntimeGone(context.Background(), id) + d.handleRuntimeGone(id) }(rid) } wg.Wait() @@ -253,8 +254,8 @@ func TestHandleRuntimeGone_BackoffOnFailure(t *testing.T) { d.runtimeIndex["rt-1"] = Runtime{ID: "rt-1"} d.runtimeIndex["rt-2"] = Runtime{ID: "rt-2"} - d.handleRuntimeGone(context.Background(), 
"rt-1") - d.handleRuntimeGone(context.Background(), "rt-2") + d.handleRuntimeGone("rt-1") + d.handleRuntimeGone("rt-2") if got := registerCount.Load(); got != 1 { t.Fatalf("register endpoint called %d times on failure path, want 1 (second call should be coalesced)", got) @@ -347,3 +348,305 @@ func TestWorkspaceNeedsRuntimeRecovery(t *testing.T) { t.Fatalf("untracked workspace should NOT need recovery") } } + +// multiProviderRegisterFixture mirrors handleRuntimeGoneFixture but speaks the +// upsert semantics of UpsertAgentRuntime: surviving providers keep their +// runtime IDs across re-registers, deleted ones get a fresh ID. The fake +// server is the source of truth and rewrites its own knowledge of which +// providers are alive each time a runtime is deleted. +// +// markDeleted(rid) emulates a UI Delete by removing the row server-side and +// returning a brand-new ID for that provider on the next register call. +type multiProviderRegisterFixture struct { + daemon *Daemon + server *httptest.Server + registerCount *atomic.Int64 + mu sync.Mutex + // providerToID maps provider -> current server-side runtime ID. The fake + // register handler reads/mutates this so the test reflects realistic + // upsert behavior. 
+ providerToID map[string]string + idCounter int +} + +func newMultiProviderRegisterFixture(t *testing.T, providers map[string]string) *multiProviderRegisterFixture { + t.Helper() + + fx := &multiProviderRegisterFixture{ + providerToID: make(map[string]string, len(providers)), + } + for p, id := range providers { + fx.providerToID[p] = id + } + + var registerCount atomic.Int64 + fx.registerCount = &registerCount + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/api/daemon/register": + registerCount.Add(1) + fx.mu.Lock() + runtimes := make([]Runtime, 0, len(fx.providerToID)) + for provider, id := range fx.providerToID { + if id == "" { + // Provider was marked deleted; mint a fresh ID + // (the UpsertAgentRuntime INSERT branch). + fx.idCounter++ + id = fmt.Sprintf("%s-new-%d", provider, fx.idCounter) + fx.providerToID[provider] = id + } + runtimes = append(runtimes, Runtime{ + ID: id, Name: provider, Provider: provider, Status: "online", + }) + } + fx.mu.Unlock() + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(RegisterResponse{ + Runtimes: runtimes, + Repos: []RepoData{}, + }) + case strings.HasSuffix(r.URL.Path, "/recover-orphans"): + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusOK) + } + })) + t.Cleanup(srv.Close) + + d := freshDaemon(srv.URL) + d.cfg.Agents = make(map[string]AgentEntry, len(providers)) + for p := range providers { + d.cfg.Agents[p] = AgentEntry{Path: "/usr/bin/true"} + } + t.Cleanup(stubAgentVersion(t)) + fx.daemon = d + fx.server = srv + return fx +} + +// markDeleted simulates server-side runtime deletion: the next register call +// will mint a new ID for this provider, matching the UI Delete + re-register +// path's UpsertAgentRuntime INSERT branch. 
+func (fx *multiProviderRegisterFixture) markDeleted(provider string) { + fx.mu.Lock() + defer fx.mu.Unlock() + fx.providerToID[provider] = "" +} + +func TestHandleRuntimeGone_PartialWorkspaceRecoveryKeepsSibling(t *testing.T) { + // Workspace has two providers, only one runtime is deleted. The siblings + // must NOT end up duplicated in workspaceState.runtimeIDs — that would + // leak through allRuntimeIDs(), deregister(), and re-recovery state. + // This is the regression test for Finding #3 (register response is + // authoritative for the workspace's runtime set, not an append). + fx := newMultiProviderRegisterFixture(t, map[string]string{ + "claude": "rt-claude-1", + "codex": "rt-codex-1", + }) + d := fx.daemon + d.workspaces["ws-1"] = &workspaceState{ + workspaceID: "ws-1", + runtimeIDs: []string{"rt-claude-1", "rt-codex-1"}, + } + d.runtimeIndex["rt-claude-1"] = Runtime{ID: "rt-claude-1", Provider: "claude"} + d.runtimeIndex["rt-codex-1"] = Runtime{ID: "rt-codex-1", Provider: "codex"} + + // Only the claude runtime gets deleted server-side. + fx.markDeleted("claude") + d.handleRuntimeGone("rt-claude-1") + + got := append([]string(nil), d.workspaces["ws-1"].runtimeIDs...) + if len(got) != 2 { + t.Fatalf("workspace runtimeIDs has %d entries after partial recovery; want 2; got %v", len(got), got) + } + // Set comparison: must contain rt-codex-1 (surviving) and a freshly + // minted claude id, with NO duplicates. 
+ seen := make(map[string]int, len(got)) + for _, id := range got { + seen[id]++ + } + for id, count := range seen { + if count != 1 { + t.Fatalf("duplicate runtime id %q (count=%d) after partial recovery: %v", id, count, got) + } + } + if _, ok := seen["rt-codex-1"]; !ok { + t.Fatalf("surviving codex runtime missing from workspace state after recovery: %v", got) + } + if _, ok := seen["rt-claude-1"]; ok { + t.Fatalf("deleted claude runtime should not be in workspace state: %v", got) + } + // And the runtimeIndex must reflect the same: codex kept, claude-1 dropped. + if _, ok := d.runtimeIndex["rt-claude-1"]; ok { + t.Fatalf("rt-claude-1 still in runtimeIndex after deletion") + } + if _, ok := d.runtimeIndex["rt-codex-1"]; !ok { + t.Fatalf("rt-codex-1 dropped from runtimeIndex during partial recovery") + } +} + +func TestHandleRuntimeGone_DistinctDeletionsWithinCoalesceWindowBothRecover(t *testing.T) { + // Two sequential, distinct runtime deletions in the same workspace fired + // within the 30s coalesce window. Each deletion must trigger its own + // re-register: success on call #1 must NOT suppress call #2. Regression + // for Finding #2 (success-case clear of reregisterNextAttempt). + fx := newMultiProviderRegisterFixture(t, map[string]string{ + "claude": "rt-claude-1", + "codex": "rt-codex-1", + }) + d := fx.daemon + d.workspaces["ws-1"] = &workspaceState{ + workspaceID: "ws-1", + runtimeIDs: []string{"rt-claude-1", "rt-codex-1"}, + } + d.runtimeIndex["rt-claude-1"] = Runtime{ID: "rt-claude-1", Provider: "claude"} + d.runtimeIndex["rt-codex-1"] = Runtime{ID: "rt-codex-1", Provider: "codex"} + + // Sequential, NOT concurrent: the first call fully completes before the + // second starts, so the in-flight set never collides. 
+ fx.markDeleted("claude") + d.handleRuntimeGone("rt-claude-1") + + if got := fx.registerCount.Load(); got != 1 { + t.Fatalf("after first deletion: register called %d times, want 1", got) + } + // Inspect the new claude id the fake assigned, so we can detect that + // the second recovery actually ran register again. + fx.mu.Lock() + claudeIDAfterFirst := fx.providerToID["claude"] + fx.mu.Unlock() + + // Now delete codex within the coalesce window (effectively t<1s after + // the first recovery), simulating a user deleting a second runtime + // shortly after the first. + fx.markDeleted("codex") + d.handleRuntimeGone("rt-codex-1") + + if got := fx.registerCount.Load(); got != 2 { + t.Fatalf("after second distinct deletion: register called %d times, want 2 (coalesce window must clear on success)", got) + } + got := append([]string(nil), d.workspaces["ws-1"].runtimeIDs...) + if len(got) != 2 { + t.Fatalf("workspace runtimeIDs after both recoveries = %v, want 2 entries", got) + } + seen := make(map[string]int, len(got)) + for _, id := range got { + seen[id]++ + } + for id, count := range seen { + if count != 1 { + t.Fatalf("duplicate runtime id %q after sequential recoveries: %v", id, got) + } + } + if _, ok := seen[claudeIDAfterFirst]; !ok { + t.Fatalf("claude id from first recovery missing after second deletion of codex: have %v, expected to keep %q", got, claudeIDAfterFirst) + } +} + +func TestHandleRuntimeGone_RecoveryContextSurvivesCallerCancellation(t *testing.T) { + // Regression for Finding #1: handleRuntimeGone must not use the per- + // runtime heartbeat ctx for the register HTTP call. notifyRuntimeSetChanged + // tears that ctx down as soon as we prune the dead runtime, so forwarding + // it would self-cancel the in-flight register. + // + // We assert by inspecting the register handler's request context: it + // must not be Done when the daemon's rootCtx is alive, regardless of what + // upstream contexts (heartbeat, poller, WS) are doing. 
+ var observedCancelled atomic.Bool + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/daemon/register" { + // Inspect the inbound request ctx. If handleRuntimeGone had + // forwarded a cancelled caller ctx, this would be Done. + select { + case <-r.Context().Done(): + observedCancelled.Store(true) + default: + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(RegisterResponse{ + Runtimes: []Runtime{{ID: "rt-new", Name: "claude", Provider: "claude", Status: "online"}}, + }) + return + } + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(srv.Close) + + d := freshDaemon(srv.URL) + d.cfg.Agents = map[string]AgentEntry{"claude": {Path: "/usr/bin/true"}} + t.Cleanup(stubAgentVersion(t)) + + // rootCtx is what handleRuntimeGone uses for recovery. We keep it alive. + rootCtx, rootCancel := context.WithCancel(context.Background()) + defer rootCancel() + d.rootCtx = rootCtx + + d.workspaces["ws-1"] = &workspaceState{workspaceID: "ws-1", runtimeIDs: []string{"rt-old"}} + d.runtimeIndex["rt-old"] = Runtime{ID: "rt-old"} + + d.handleRuntimeGone("rt-old") + + if observedCancelled.Load() { + t.Fatalf("register HTTP call ran with a cancelled context — recovery would self-cancel under runtime-set churn") + } + if got := d.workspaces["ws-1"].runtimeIDs; len(got) != 1 || got[0] != "rt-new" { + t.Fatalf("workspace runtimeIDs after recovery = %v, want [rt-new]", got) + } +} + +func TestHandleRuntimeGone_RecoveryContextStopsOnDaemonShutdown(t *testing.T) { + // Companion to RecoveryContextSurvivesCallerCancellation: when the daemon + // IS shutting down, recovery must abort promptly instead of holding the + // HTTP call open until its 30s client timeout. We bound the server + // handler with a short safety timeout so test cleanup never hangs on a + // stuck connection — the assertion is on the daemon-side return time, + // not on server-side context propagation. 
+ registerEntered := make(chan struct{}, 1) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/daemon/register" { + select { + case registerEntered <- struct{}{}: + default: + } + select { + case <-r.Context().Done(): + case <-time.After(2 * time.Second): + } + return + } + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(srv.Close) + + d := freshDaemon(srv.URL) + d.cfg.Agents = map[string]AgentEntry{"claude": {Path: "/usr/bin/true"}} + t.Cleanup(stubAgentVersion(t)) + + rootCtx, rootCancel := context.WithCancel(context.Background()) + t.Cleanup(rootCancel) + d.rootCtx = rootCtx + + d.workspaces["ws-1"] = &workspaceState{workspaceID: "ws-1", runtimeIDs: []string{"rt-old"}} + d.runtimeIndex["rt-old"] = Runtime{ID: "rt-old"} + + done := make(chan struct{}) + go func() { + d.handleRuntimeGone("rt-old") + close(done) + }() + + select { + case <-registerEntered: + case <-time.After(2 * time.Second): + t.Fatalf("register endpoint was never reached") + } + + rootCancel() + + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatalf("handleRuntimeGone did not abort after daemon root context cancellation") + } +} diff --git a/server/internal/daemon/wakeup.go b/server/internal/daemon/wakeup.go index 5defe0895..ad62c754c 100644 --- a/server/internal/daemon/wakeup.go +++ b/server/internal/daemon/wakeup.go @@ -240,12 +240,15 @@ func marshalRaw(v any) json.RawMessage { // route it through the same self-heal entry point as the HTTP path and do // NOT record a heartbeat freshness mark — pretending the runtime is alive // would let HTTP keep skipping its own heartbeat against the dead UUID. +// +// handleRuntimeGone uses the daemon root context for its register call, so +// this function can safely pass any caller context here. 
func (d *Daemon) handleWSHeartbeatAck(ctx context.Context, ack *HeartbeatResponse) { if ack == nil || ack.RuntimeID == "" { return } if ack.RuntimeGone { - go d.handleRuntimeGone(ctx, ack.RuntimeID) + go d.handleRuntimeGone(ack.RuntimeID) return } d.recordWSHeartbeatAck(ack.RuntimeID)