Compare commits

...

1 Commits

Author SHA1 Message Date
Jiang Bohan
d01891de9b refactor(daemon): consolidate task workspace resolver + regression test
Follow-up to #1249. Two small follow-ups requested in review:

1. `resolveTaskWorkspaceID` was duplicated between `handler/daemon.go` and
   `service/task.go`. #1249 fixed the handler copy but left both in place,
   meaning any future branch (e.g. a fourth task link type) still needs
   to be added in two files. Promote the service method to the exported
   `TaskService.ResolveTaskWorkspaceID` and delete the handler copy.
   Handler's `requireDaemonTaskAccess` and `ListTaskMessagesByUser` now
   call through `h.TaskService`.

2. Add a regression test `TestStartTask_AutopilotRunOnlyTask_ResolvesWorkspace`
   covering the exact scenario from #1224: a task linked only via
   `AutopilotRunID` must resolve to the autopilot's workspace. The test
   asserts 404 for a cross-workspace daemon token and 200 (with status
   transitioning to `running`) for the correct-workspace token.
2026-04-17 14:47:34 +08:00
3 changed files with 93 additions and 30 deletions

View File

@@ -67,7 +67,7 @@ func (h *Handler) requireDaemonTaskAccess(w http.ResponseWriter, r *http.Request
return db.AgentTaskQueue{}, false
}
wsID := h.resolveTaskWorkspaceID(r, task)
wsID := h.TaskService.ResolveTaskWorkspaceID(r.Context(), task)
if wsID == "" {
writeError(w, http.StatusNotFound, "task not found")
return db.AgentTaskQueue{}, false
@@ -96,29 +96,6 @@ func (h *Handler) verifyDaemonWorkspaceAccess(r *http.Request, workspaceID strin
return err == nil
}
// resolveTaskWorkspaceID derives the workspace ID from a task's linked entity
// (issue, chat session, or autopilot run).
func (h *Handler) resolveTaskWorkspaceID(r *http.Request, task db.AgentTaskQueue) string {
if task.IssueID.Valid {
if issue, err := h.Queries.GetIssue(r.Context(), task.IssueID); err == nil {
return uuidToString(issue.WorkspaceID)
}
}
if task.ChatSessionID.Valid {
if cs, err := h.Queries.GetChatSession(r.Context(), task.ChatSessionID); err == nil {
return uuidToString(cs.WorkspaceID)
}
}
if task.AutopilotRunID.Valid {
if run, err := h.Queries.GetAutopilotRun(r.Context(), task.AutopilotRunID); err == nil {
if ap, err := h.Queries.GetAutopilot(r.Context(), run.AutopilotID); err == nil {
return uuidToString(ap.WorkspaceID)
}
}
}
return ""
}
// ---------------------------------------------------------------------------
// Daemon Registration & Heartbeat
// ---------------------------------------------------------------------------
@@ -992,7 +969,7 @@ func (h *Handler) ListTaskMessagesByUser(w http.ResponseWriter, r *http.Request)
}
// Verify the task belongs to the caller's workspace.
wsID := h.resolveTaskWorkspaceID(r, task)
wsID := h.TaskService.ResolveTaskWorkspaceID(r.Context(), task)
if wsID == "" || wsID != middleware.WorkspaceIDFromContext(r.Context()) {
writeError(w, http.StatusNotFound, "task not found")
return

View File

@@ -644,3 +644,88 @@ func TestGetDaemonWorkspaceRepos_VersionIgnoresOrderAndDescription(t *testing.T)
t.Fatalf("expected repos_version to change when URL set changes, got %s", version3)
}
}
// Regression test for #1224: tasks linked only via AutopilotRunID (run_only
// autopilots) must resolve to the autopilot's workspace. Before the fix,
// resolveTaskWorkspaceID fell through and every StartTask call returned 404.
func TestStartTask_AutopilotRunOnlyTask_ResolvesWorkspace(t *testing.T) {
if testHandler == nil {
t.Skip("database not available")
}
ctx := context.Background()
var agentID, runtimeID string
if err := testPool.QueryRow(ctx, `
SELECT a.id, a.runtime_id FROM agent a WHERE a.workspace_id = $1 LIMIT 1
`, testWorkspaceID).Scan(&agentID, &runtimeID); err != nil {
t.Fatalf("setup: get agent: %v", err)
}
var autopilotID string
if err := testPool.QueryRow(ctx, `
INSERT INTO autopilot (
workspace_id, title, assignee_id, execution_mode,
created_by_type, created_by_id
)
VALUES ($1, 'run_only fixture', $2, 'run_only', 'member', $3)
RETURNING id
`, testWorkspaceID, agentID, testUserID).Scan(&autopilotID); err != nil {
t.Fatalf("setup: create autopilot: %v", err)
}
defer testPool.Exec(ctx, `DELETE FROM autopilot WHERE id = $1`, autopilotID)
var runID string
if err := testPool.QueryRow(ctx, `
INSERT INTO autopilot_run (autopilot_id, source, status)
VALUES ($1, 'manual', 'running')
RETURNING id
`, autopilotID).Scan(&runID); err != nil {
t.Fatalf("setup: create autopilot_run: %v", err)
}
// issue_id is explicitly NULL — the condition that used to trigger 404.
var taskID string
if err := testPool.QueryRow(ctx, `
INSERT INTO agent_task_queue (
agent_id, runtime_id, issue_id, status, priority, autopilot_run_id
)
VALUES ($1, $2, NULL, 'dispatched', 0, $3)
RETURNING id
`, agentID, runtimeID, runID).Scan(&taskID); err != nil {
t.Fatalf("setup: create autopilot task: %v", err)
}
defer testPool.Exec(ctx, `DELETE FROM agent_task_queue WHERE id = $1`, taskID)
// Cross-workspace daemon token must still 404.
w := httptest.NewRecorder()
req := newDaemonTokenRequest("POST", "/api/daemon/tasks/"+taskID+"/start", nil,
"00000000-0000-0000-0000-000000000000", "attacker-daemon")
rctx := chi.NewRouteContext()
rctx.URLParams.Add("taskId", taskID)
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
testHandler.StartTask(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("StartTask with cross-workspace token: expected 404, got %d: %s", w.Code, w.Body.String())
}
// Same-workspace daemon token must succeed — this is the bug in #1224.
w = httptest.NewRecorder()
req = newDaemonTokenRequest("POST", "/api/daemon/tasks/"+taskID+"/start", nil,
testWorkspaceID, "legit-daemon")
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
testHandler.StartTask(w, req)
if w.Code != http.StatusOK {
t.Fatalf("StartTask for run_only autopilot task: expected 200, got %d: %s", w.Code, w.Body.String())
}
var status string
if err := testPool.QueryRow(ctx, `SELECT status FROM agent_task_queue WHERE id = $1`, taskID).Scan(&status); err != nil {
t.Fatalf("post-check: read task status: %v", err)
}
if status != "running" {
t.Fatalf("expected task status 'running' after StartTask, got %q", status)
}
}

View File

@@ -479,7 +479,7 @@ func (s *TaskService) broadcastTaskDispatch(ctx context.Context, task db.AgentTa
payload["issue_id"] = util.UUIDToString(task.IssueID)
payload["agent_id"] = util.UUIDToString(task.AgentID)
workspaceID := s.resolveTaskWorkspaceID(ctx, task)
workspaceID := s.ResolveTaskWorkspaceID(ctx, task)
if workspaceID == "" {
return
}
@@ -493,7 +493,7 @@ func (s *TaskService) broadcastTaskDispatch(ctx context.Context, task db.AgentTa
}
func (s *TaskService) broadcastTaskEvent(ctx context.Context, eventType string, task db.AgentTaskQueue) {
workspaceID := s.resolveTaskWorkspaceID(ctx, task)
workspaceID := s.ResolveTaskWorkspaceID(ctx, task)
if workspaceID == "" {
return
}
@@ -515,10 +515,11 @@ func (s *TaskService) broadcastTaskEvent(ctx context.Context, eventType string,
})
}
// resolveTaskWorkspaceID determines the workspace ID for a task.
// ResolveTaskWorkspaceID determines the workspace ID for a task.
// For issue tasks, it comes from the issue. For chat tasks, from the chat session.
// For autopilot tasks, from the autopilot via its run.
func (s *TaskService) resolveTaskWorkspaceID(ctx context.Context, task db.AgentTaskQueue) string {
// Returns "" when none of the links resolve — callers treat that as "not found".
func (s *TaskService) ResolveTaskWorkspaceID(ctx context.Context, task db.AgentTaskQueue) string {
if task.IssueID.Valid {
if issue, err := s.Queries.GetIssue(ctx, task.IssueID); err == nil {
return util.UUIDToString(issue.WorkspaceID)
@@ -540,7 +541,7 @@ func (s *TaskService) resolveTaskWorkspaceID(ctx context.Context, task db.AgentT
}
func (s *TaskService) broadcastChatDone(ctx context.Context, task db.AgentTaskQueue) {
workspaceID := s.resolveTaskWorkspaceID(ctx, task)
workspaceID := s.ResolveTaskWorkspaceID(ctx, task)
if workspaceID == "" {
return
}