diff --git a/server/cmd/server/health_realtime.go b/server/cmd/server/health_realtime.go index 2a8748ae7..00ae3a084 100644 --- a/server/cmd/server/health_realtime.go +++ b/server/cmd/server/health_realtime.go @@ -7,6 +7,7 @@ import ( "net/http" "strings" + "github.com/multica-ai/multica/server/internal/daemonws" "github.com/multica-ai/multica/server/internal/realtime" ) @@ -47,7 +48,9 @@ func realtimeMetricsHandler(token string) http.HandlerFunc { w.Header().Set("Content-Type", "application/json") w.Header().Set("Cache-Control", "no-store") - _ = json.NewEncoder(w).Encode(realtime.M.Snapshot()) + snapshot := realtime.M.Snapshot() + snapshot["daemonws"] = daemonws.M.Snapshot() + _ = json.NewEncoder(w).Encode(snapshot) } } diff --git a/server/cmd/server/main.go b/server/cmd/server/main.go index acbb39d73..aa990e13a 100644 --- a/server/cmd/server/main.go +++ b/server/cmd/server/main.go @@ -12,6 +12,7 @@ import ( "time" "github.com/multica-ai/multica/server/internal/analytics" + "github.com/multica-ai/multica/server/internal/daemonws" "github.com/multica-ai/multica/server/internal/events" "github.com/multica-ai/multica/server/internal/logger" obsmetrics "github.com/multica-ai/multica/server/internal/metrics" @@ -161,6 +162,8 @@ func main() { bus := events.New() hub := realtime.NewHub() go hub.Run() + daemonHub := daemonws.NewHub() + var daemonWakeup service.TaskWakeupNotifier = daemonHub // MUL-1138: when REDIS_URL is set, route fanout through a Redis relay so // multiple API nodes can deliver each other's events. Without it the hub @@ -204,15 +207,21 @@ func main() { case "legacy": relayReadRedis = newNamedRedisClient(opts, "realtime-read") relay = realtime.NewRedisRelayWithClients(hub, relayWriteRedis, relayReadRedis) + slog.Info("daemon websocket wakeup: Redis fanout disabled in legacy realtime relay mode") case "dual": shardedReadRedis = newNamedRedisClient(opts, "realtime-read-sharded") legacyReadRedis = newNamedRedisClient(opts, "realtime-read-legacy") sharded := realtime.NewShardedStreamRelay(hub, relayWriteRedis, shardedReadRedis, relayConfig) + sharded.SetDaemonRuntimeDeliverer(daemonHub) legacy := realtime.NewRedisRelayWithClients(hub, relayWriteRedis, legacyReadRedis) relay = realtime.NewMirroredRelay(sharded, legacy) + daemonWakeup = daemonws.NewRelayNotifier(daemonHub, sharded) default: relayReadRedis = newNamedRedisClient(opts, "realtime-read") - relay = realtime.NewShardedStreamRelay(hub, relayWriteRedis, relayReadRedis, relayConfig) + sharded := realtime.NewShardedStreamRelay(hub, relayWriteRedis, relayReadRedis, relayConfig) + sharded.SetDaemonRuntimeDeliverer(daemonHub) + relay = sharded + daemonWakeup = daemonws.NewRelayNotifier(daemonHub, sharded) } relay.Start(relayCtx) broadcaster = realtime.NewDualWriteBroadcaster(hub, relay) @@ -253,6 +262,7 @@ func main() { metricsRegistry := obsmetrics.NewRegistry(obsmetrics.RegistryOptions{ Pool: pool, Realtime: realtime.M, + DaemonWS: daemonws.M, Version: version, Commit: commit, }) @@ -267,7 +277,9 @@ func main() { } r := NewRouterWithOptions(pool, hub, bus, analyticsClient, storeRedis, RouterOptions{ - HTTPMetrics: httpMetrics, + HTTPMetrics: httpMetrics, + DaemonHub: daemonHub, + DaemonWakeup: daemonWakeup, }) srv := &http.Server{ @@ -278,7 +290,7 @@ func main() { // Start background workers. sweepCtx, sweepCancel := context.WithCancel(context.Background()) autopilotCtx, autopilotCancel := context.WithCancel(context.Background()) - taskSvc := service.NewTaskService(queries, pool, hub, bus) + taskSvc := service.NewTaskService(queries, pool, hub, bus, daemonWakeup) autopilotSvc := service.NewAutopilotService(queries, pool, bus, taskSvc) registerAutopilotListeners(bus, autopilotSvc) diff --git a/server/cmd/server/router.go b/server/cmd/server/router.go index a948501d9..c60e5acf2 100644 --- a/server/cmd/server/router.go +++ b/server/cmd/server/router.go @@ -15,6 +15,7 @@ import ( "github.com/multica-ai/multica/server/internal/analytics" "github.com/multica-ai/multica/server/internal/auth" + "github.com/multica-ai/multica/server/internal/daemonws" "github.com/multica-ai/multica/server/internal/events" "github.com/multica-ai/multica/server/internal/handler" obsmetrics "github.com/multica-ai/multica/server/internal/metrics" @@ -67,12 +68,18 @@ func NewRouter(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus, analytics } type RouterOptions struct { - HTTPMetrics *obsmetrics.HTTPMetrics + HTTPMetrics *obsmetrics.HTTPMetrics + DaemonHub *daemonws.Hub + DaemonWakeup service.TaskWakeupNotifier } func NewRouterWithOptions(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus, analyticsClient analytics.Client, rdb *redis.Client, opts RouterOptions) chi.Router { queries := db.New(pool) emailSvc := service.NewEmailService() + daemonHub := opts.DaemonHub + if daemonHub == nil { + daemonHub = daemonws.NewHub() + } // Initialize storage with S3 as primary, fallback to local var store storage.Storage @@ -93,7 +100,10 @@ func NewRouterWithOptions(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus AllowedEmails: splitAndTrim(os.Getenv("ALLOWED_EMAILS")), AllowedEmailDomains: splitAndTrim(os.Getenv("ALLOWED_EMAIL_DOMAINS")), } - h := handler.New(queries, pool, hub, bus, emailSvc, store, cfSigner, analyticsClient, signupConfig) + h := handler.New(queries, pool, hub, bus, emailSvc, store, cfSigner, analyticsClient, signupConfig, daemonHub) + if opts.DaemonWakeup != nil { + h.TaskService.Wakeup = opts.DaemonWakeup + } if rdb != nil { h.LocalSkillListStore = handler.NewRedisLocalSkillListStore(rdb) h.LocalSkillImportStore = handler.NewRedisLocalSkillImportStore(rdb) @@ -178,6 +188,7 @@ func NewRouterWithOptions(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus r.Post("/register", h.DaemonRegister) r.Post("/deregister", h.DaemonDeregister) r.Post("/heartbeat", h.DaemonHeartbeat) + r.Get("/ws", h.DaemonWebSocket) r.Get("/workspaces/{workspaceId}/repos", h.GetDaemonWorkspaceRepos) r.Post("/runtimes/{runtimeId}/tasks/claim", h.ClaimTaskByRuntime) diff --git a/server/internal/daemon/config.go b/server/internal/daemon/config.go index e52ebf386..10c306a60 100644 --- a/server/internal/daemon/config.go +++ b/server/internal/daemon/config.go @@ -12,7 +12,7 @@ import ( const ( DefaultServerURL = "ws://localhost:8080/ws" - DefaultPollInterval = 3 * time.Second + DefaultPollInterval = 30 * time.Second DefaultHeartbeatInterval = 15 * time.Second DefaultAgentTimeout = 2 * time.Hour DefaultCodexSemanticInactivityTimeout = 10 * time.Minute diff --git a/server/internal/daemon/daemon.go b/server/internal/daemon/daemon.go index 0d5460a4c..bbdc78ee1 100644 --- a/server/internal/daemon/daemon.go +++ b/server/internal/daemon/daemon.go @@ -45,6 +45,7 @@ type Daemon struct { workspaces map[string]*workspaceState runtimeIndex map[string]Runtime // runtimeID -> Runtime for provider lookups reloading sync.Mutex // prevents concurrent workspace syncs + runtimeSetCh chan struct{} // notifies the WS wakeup loop to reconnect with a new runtime set versionsMu sync.RWMutex // guards agentVersions agentVersions map[string]string // provider -> detected CLI version (set during registration) @@ -69,6 +70,7 @@ func New(cfg Config, logger *slog.Logger) *Daemon { logger: logger, workspaces: make(map[string]*workspaceState), runtimeIndex: make(map[string]Runtime), + runtimeSetCh: make(chan struct{}, 1), agentVersions: make(map[string]string), } } @@ -89,6 +91,23 @@ func (d *Daemon) agentVersion(provider string) string { return d.agentVersions[provider] } +func (d *Daemon) notifyRuntimeSetChanged() { + select { + case d.runtimeSetCh <- struct{}{}: + default: + } +} + +func (d *Daemon) drainRuntimeSetChanged() { + for { + select { + case <-d.runtimeSetCh: + default: + return + } + } +} + // Run starts the daemon: resolves auth, registers runtimes, then polls for tasks. func (d *Daemon) Run(ctx context.Context) error { // Wrap context so handleUpdate can cancel the daemon for restart. @@ -132,10 +151,13 @@ func (d *Daemon) Run(ctx context.Context) error { // Start workspace sync loop to discover newly created workspaces. go d.workspaceSyncLoop(ctx) + taskWakeups := make(chan struct{}, 1) + d.drainRuntimeSetChanged() + go d.taskWakeupLoop(ctx, taskWakeups) go d.heartbeatLoop(ctx) go d.gcLoop(ctx) go d.serveHealth(ctx, healthLn, time.Now()) - return d.pollLoop(ctx) + return d.pollLoop(ctx, taskWakeups) } // RestartBinary returns the path to the new binary if the daemon needs to restart @@ -422,6 +444,7 @@ func (d *Daemon) syncWorkspacesFromAPI(ctx context.Context) error { d.mu.Unlock() var registered int + var removed int for id, name := range apiIDs { if currentIDs[id] { continue // important: never replace existing workspaceState; ensureRepoReady holds ws.repoRefreshMu from the original pointer @@ -473,8 +496,12 @@ func (d *Daemon) syncWorkspacesFromAPI(ctx context.Context) error { delete(d.workspaces, id) d.mu.Unlock() d.logger.Info("stopped watching workspace", "workspace_id", id) + removed++ } } + if registered > 0 || removed > 0 { + d.notifyRuntimeSetChanged() + } if len(d.allRuntimeIDs()) == 0 && registered == 0 && len(workspaces) > 0 { return fmt.Errorf("failed to register runtimes for any of the %d workspace(s)", len(workspaces)) @@ -799,7 +826,7 @@ func (d *Daemon) triggerRestart() { } } -func (d *Daemon) pollLoop(ctx context.Context) error { +func (d *Daemon) pollLoop(ctx context.Context, taskWakeups <-chan struct{}) error { sem := make(chan struct{}, d.cfg.MaxConcurrentTasks) var wg sync.WaitGroup @@ -822,7 +849,7 @@ func (d *Daemon) pollLoop(ctx context.Context) error { runtimeIDs := d.allRuntimeIDs() if len(runtimeIDs) == 0 { - if err := sleepWithContext(ctx, d.cfg.PollInterval); err != nil { + if err := sleepWithContextOrWakeup(ctx, d.cfg.PollInterval, taskWakeups); err != nil { wg.Wait() return err } @@ -878,7 +905,7 @@ func (d *Daemon) pollLoop(ctx context.Context) error { d.logger.Debug("poll: no tasks", "runtimes", runtimeIDs, "cycle", pollCount) } pollOffset = (pollOffset + 1) % n - if err := sleepWithContext(ctx, d.cfg.PollInterval); err != nil { + if err := sleepWithContextOrWakeup(ctx, d.cfg.PollInterval, taskWakeups); err != nil { wg.Wait() return err } diff --git a/server/internal/daemon/helpers.go b/server/internal/daemon/helpers.go index 2e93ed7ed..4aa81781a 100644 --- a/server/internal/daemon/helpers.go +++ b/server/internal/daemon/helpers.go @@ -78,3 +78,21 @@ func sleepWithContext(ctx context.Context, d time.Duration) error { return nil } } + +func sleepWithContextOrWakeup(ctx context.Context, d time.Duration, wakeups <-chan struct{}) error { + if wakeups == nil { + return sleepWithContext(ctx, d) + } + + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-wakeups: + return nil + case <-timer.C: + return nil + } +} diff --git a/server/internal/daemon/wakeup.go b/server/internal/daemon/wakeup.go new file mode 100644 index 000000000..9ab2c58fb --- /dev/null +++ b/server/internal/daemon/wakeup.go @@ -0,0 +1,188 @@ +package daemon + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math/rand" + "net/http" + "net/url" + "sort" + "strings" + "time" + + "github.com/gorilla/websocket" + "github.com/multica-ai/multica/server/pkg/protocol" +) + +var errRuntimeSetChanged = errors.New("runtime set changed") + +func (d *Daemon) taskWakeupLoop(ctx context.Context, taskWakeups chan<- struct{}) { + backoff := time.Second + + for { + runtimeIDs := d.allRuntimeIDs() + if len(runtimeIDs) == 0 { + if err := sleepWithContextOrRuntimeChange(ctx, 5*time.Second, d.runtimeSetCh); err != nil { + return + } + continue + } + + err := d.runTaskWakeupConnection(ctx, runtimeIDs, taskWakeups) + if ctx.Err() != nil { + return + } + if errors.Is(err, errRuntimeSetChanged) { + backoff = time.Second + continue + } + if err != nil { + d.logger.Debug("task wakeup websocket unavailable; polling fallback remains active", "error", err, "retry_in", backoff) + } + + if err := sleepWithContextOrRuntimeChange(ctx, jitterDuration(backoff), d.runtimeSetCh); err != nil { + return + } + if backoff < 30*time.Second { + backoff *= 2 + if backoff > 30*time.Second { + backoff = 30 * time.Second + } + } + } +} + +func jitterDuration(d time.Duration) time.Duration { + if d <= 0 { + return d + } + spread := d / 5 + if spread <= 0 { + return d + } + delta := time.Duration(rand.Int63n(int64(spread)*2+1)) - spread + return d + delta +} + +func (d *Daemon) runTaskWakeupConnection(ctx context.Context, runtimeIDs []string, taskWakeups chan<- struct{}) error { + wsURL, err := taskWakeupURL(d.cfg.ServerBaseURL, runtimeIDs) + if err != nil { + return err + } + + headers := http.Header{} + if token := d.client.Token(); token != "" { + headers.Set("Authorization", "Bearer "+token) + } + if d.client.platform != "" { + headers.Set("X-Client-Platform", d.client.platform) + } + if d.client.version != "" { + headers.Set("X-Client-Version", d.client.version) + } + if d.client.os != "" { + headers.Set("X-Client-OS", d.client.os) + } + + dialer := websocket.Dialer{HandshakeTimeout: 10 * time.Second} + conn, _, err := dialer.DialContext(ctx, wsURL, headers) + if err != nil { + return err + } + defer conn.Close() + + d.logger.Info("task wakeup websocket connected", "runtimes", len(runtimeIDs)) + signalTaskWakeup(taskWakeups) + + errCh := make(chan error, 1) + go func() { + errCh <- d.readTaskWakeupMessages(conn, taskWakeups) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-d.runtimeSetCh: + return errRuntimeSetChanged + case err := <-errCh: + return err + } +} + +func (d *Daemon) readTaskWakeupMessages(conn *websocket.Conn, taskWakeups chan<- struct{}) error { + conn.SetReadLimit(64 * 1024) + for { + _, raw, err := conn.ReadMessage() + if err != nil { + return err + } + var msg protocol.Message + if err := json.Unmarshal(raw, &msg); err != nil { + d.logger.Debug("task wakeup websocket invalid message", "error", err) + continue + } + if msg.Type != protocol.EventDaemonTaskAvailable { + continue + } + var payload protocol.TaskAvailablePayload + if len(msg.Payload) > 0 { + if err := json.Unmarshal(msg.Payload, &payload); err != nil { + d.logger.Debug("task wakeup websocket invalid payload", "error", err) + continue + } + } + if payload.RuntimeID != "" { + d.logger.Debug("task wakeup received", "runtime_id", payload.RuntimeID, "task_id", payload.TaskID) + } + signalTaskWakeup(taskWakeups) + } +} + +func signalTaskWakeup(taskWakeups chan<- struct{}) { + select { + case taskWakeups <- struct{}{}: + default: + } +} + +func taskWakeupURL(baseURL string, runtimeIDs []string) (string, error) { + u, err := url.Parse(strings.TrimSpace(baseURL)) + if err != nil { + return "", fmt.Errorf("invalid daemon server URL: %w", err) + } + switch u.Scheme { + case "http": + u.Scheme = "ws" + case "https": + u.Scheme = "wss" + case "ws", "wss": + default: + return "", fmt.Errorf("daemon server URL must use http, https, ws, or wss") + } + + u.Path = strings.TrimRight(u.Path, "/") + "/api/daemon/ws" + u.RawPath = "" + q := u.Query() + ids := append([]string(nil), runtimeIDs...) + sort.Strings(ids) + q.Set("runtime_ids", strings.Join(ids, ",")) + u.RawQuery = q.Encode() + u.Fragment = "" + return u.String(), nil +} + +func sleepWithContextOrRuntimeChange(ctx context.Context, d time.Duration, runtimeSetCh <-chan struct{}) error { + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-runtimeSetCh: + return nil + case <-timer.C: + return nil + } +} diff --git a/server/internal/daemon/wakeup_test.go b/server/internal/daemon/wakeup_test.go new file mode 100644 index 000000000..ce9960fc0 --- /dev/null +++ b/server/internal/daemon/wakeup_test.go @@ -0,0 +1,43 @@ +package daemon + +import "testing" + +func TestTaskWakeupURL(t *testing.T) { + tests := []struct { + name string + baseURL string + runtimeIDs []string + want string + }{ + { + name: "http base", + baseURL: "http://localhost:8080", + runtimeIDs: []string{"runtime-b", "runtime-a"}, + want: "ws://localhost:8080/api/daemon/ws?runtime_ids=runtime-a%2Cruntime-b", + }, + { + name: "https base", + baseURL: "https://api.example.com", + runtimeIDs: []string{"runtime-1"}, + want: "wss://api.example.com/api/daemon/ws?runtime_ids=runtime-1", + }, + { + name: "base path", + baseURL: "https://api.example.com/multica", + runtimeIDs: []string{"runtime-1"}, + want: "wss://api.example.com/multica/api/daemon/ws?runtime_ids=runtime-1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := taskWakeupURL(tt.baseURL, tt.runtimeIDs) + if err != nil { + t.Fatalf("taskWakeupURL: %v", err) + } + if got != tt.want { + t.Fatalf("taskWakeupURL() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/server/internal/daemonws/hub.go b/server/internal/daemonws/hub.go new file mode 100644 index 000000000..eeecdacc4 --- /dev/null +++ b/server/internal/daemonws/hub.go @@ -0,0 +1,349 @@ +package daemonws + +import ( + "encoding/json" + "log/slog" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/multica-ai/multica/server/pkg/protocol" +) + +const ( + writeWait = 10 * time.Second + pongWait = 60 * time.Second + pingPeriod = (pongWait * 9) / 10 +) + +// ClientIdentity captures the already-authenticated daemon connection scope. +type ClientIdentity struct { + DaemonID string + UserID string + WorkspaceID string + RuntimeIDs []string + ClientVersion string +} + +type client struct { + hub *Hub + conn *websocket.Conn + send chan []byte + identity ClientIdentity + runtimes map[string]struct{} + + dedupMu sync.Mutex + seenIDs map[string]struct{} + seenList []string +} + +const eventDedupCapacity = 128 + +// markSeen records eventID as already delivered to this client. Empty event IDs +// disable dedup and are always delivered. +func (c *client) markSeen(eventID string) bool { + if eventID == "" { + return true + } + c.dedupMu.Lock() + defer c.dedupMu.Unlock() + if c.seenIDs == nil { + c.seenIDs = make(map[string]struct{}, eventDedupCapacity) + } + if _, ok := c.seenIDs[eventID]; ok { + return false + } + c.seenIDs[eventID] = struct{}{} + c.seenList = append(c.seenList, eventID) + if len(c.seenList) > eventDedupCapacity { + drop := c.seenList[0] + c.seenList = c.seenList[1:] + delete(c.seenIDs, drop) + } + return true +} + +// Hub keeps daemon WebSocket connections indexed by runtime ID. Messages are +// best-effort wakeup hints; the daemon still uses HTTP claim for correctness. +type Hub struct { + upgrader websocket.Upgrader + + mu sync.RWMutex + clients map[*client]bool + byRuntime map[string]map[*client]bool +} + +func NewHub() *Hub { + return &Hub{ + upgrader: websocket.Upgrader{ + // Daemon clients authenticate with Authorization headers before the + // upgrade. Browsers cannot set those headers through the native WS API, + // and DaemonAuth does not accept cookies, so cookie-based CSWSH does + // not apply to this endpoint. Re-evaluate this if DaemonAuth ever + // grows cookie fallback. + CheckOrigin: func(r *http.Request) bool { return true }, + }, + clients: make(map[*client]bool), + byRuntime: make(map[string]map[*client]bool), + } +} + +func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request, identity ClientIdentity) { + if len(identity.RuntimeIDs) == 0 { + http.Error(w, `{"error":"runtime_ids required"}`, http.StatusBadRequest) + return + } + + conn, err := h.upgrader.Upgrade(w, r, nil) + if err != nil { + slog.Error("daemon websocket upgrade failed", "error", err) + return + } + + runtimes := make(map[string]struct{}, len(identity.RuntimeIDs)) + for _, runtimeID := range identity.RuntimeIDs { + if runtimeID != "" { + runtimes[runtimeID] = struct{}{} + } + } + if len(runtimes) == 0 { + conn.WriteMessage(websocket.TextMessage, []byte(`{"error":"runtime_ids required"}`)) + conn.Close() + return + } + + c := &client{ + hub: h, + conn: conn, + send: make(chan []byte, 16), + identity: identity, + runtimes: runtimes, + } + h.register(c) + + go c.writePump() + go c.readPump() +} + +// NotifyTaskAvailable sends a best-effort wakeup to daemons watching runtimeID. +func (h *Hub) NotifyTaskAvailable(runtimeID, taskID string) { + h.notifyTaskAvailable(runtimeID, taskID, "") +} + +func (h *Hub) notifyTaskAvailable(runtimeID, taskID, eventID string) { + if h == nil || runtimeID == "" { + return + } + data, err := taskAvailableFrame(runtimeID, taskID) + if err != nil { + return + } + delivered, deduped := h.notifyFrame(runtimeID, data, eventID) + if delivered { + M.WakeupDeliveredHit.Add(1) + } else if !deduped { + M.WakeupDeliveredMiss.Add(1) + } +} + +func (h *Hub) DeliverDaemonRuntime(scopeID string, frame []byte, eventID string) { + if h == nil { + return + } + M.WakeupReceivedTotal.Add(1) + var msg protocol.Message + if err := json.Unmarshal(frame, &msg); err != nil { + slog.Debug("daemon websocket relay: invalid frame", "error", err, "scope_id", scopeID, "event_id", eventID) + M.WakeupDeliveredMiss.Add(1) + return + } + if msg.Type != protocol.EventDaemonTaskAvailable { + M.WakeupDeliveredMiss.Add(1) + return + } + var payload protocol.TaskAvailablePayload + if err := json.Unmarshal(msg.Payload, &payload); err != nil || payload.RuntimeID == "" { + slog.Debug("daemon websocket relay: invalid task_available payload", "error", err, "scope_id", scopeID, "event_id", eventID) + M.WakeupDeliveredMiss.Add(1) + return + } + delivered, deduped := h.notifyFrame(payload.RuntimeID, frame, eventID) + if delivered { + M.WakeupDeliveredHit.Add(1) + } else if !deduped { + M.WakeupDeliveredMiss.Add(1) + } +} + +func (h *Hub) notifyFrame(runtimeID string, data []byte, eventID string) (delivered bool, deduped bool) { + h.mu.RLock() + clients := h.byRuntime[runtimeID] + slow := make([]*client, 0) + for c := range clients { + if !c.markSeen(eventID) { + deduped = true + continue + } + select { + case c.send <- data: + delivered = true + default: + slow = append(slow, c) + } + } + h.mu.RUnlock() + + for _, c := range slow { + h.unregister(c) + c.conn.Close() + } + if len(slow) > 0 { + M.SlowEvictionsTotal.Add(int64(len(slow))) + } + return delivered, deduped +} + +func taskAvailableFrame(runtimeID, taskID string) ([]byte, error) { + return json.Marshal(protocol.Message{ + Type: protocol.EventDaemonTaskAvailable, + Payload: mustMarshalRaw(protocol.TaskAvailablePayload{ + RuntimeID: runtimeID, + TaskID: taskID, + }), + }) +} + +func mustMarshalRaw(v any) json.RawMessage { + data, err := json.Marshal(v) + if err != nil { + return nil + } + return data +} + +func (h *Hub) RuntimeConnectionCount(runtimeID string) int { + h.mu.RLock() + defer h.mu.RUnlock() + return len(h.byRuntime[runtimeID]) +} + +func (h *Hub) register(c *client) { + h.mu.Lock() + h.clients[c] = true + for runtimeID := range c.runtimes { + conns := h.byRuntime[runtimeID] + if conns == nil { + conns = make(map[*client]bool) + h.byRuntime[runtimeID] = conns + } + conns[c] = true + } + total := len(h.clients) + h.mu.Unlock() + + M.ConnectsTotal.Add(1) + M.ActiveConnections.Add(1) + slog.Info("daemon websocket connected", + "daemon_id", c.identity.DaemonID, + "user_id", c.identity.UserID, + "workspace_id", c.identity.WorkspaceID, + "runtimes", len(c.runtimes), + "client_version", c.identity.ClientVersion, + "total_clients", total, + ) +} + +func (h *Hub) unregister(c *client) { + h.mu.Lock() + if !h.clients[c] { + h.mu.Unlock() + return + } + delete(h.clients, c) + for runtimeID := range c.runtimes { + if conns := h.byRuntime[runtimeID]; conns != nil { + delete(conns, c) + if len(conns) == 0 { + delete(h.byRuntime, runtimeID) + } + } + } + close(c.send) + total := len(h.clients) + h.mu.Unlock() + + M.DisconnectsTotal.Add(1) + M.ActiveConnections.Add(-1) + slog.Info("daemon websocket disconnected", + "daemon_id", c.identity.DaemonID, + "user_id", c.identity.UserID, + "workspace_id", c.identity.WorkspaceID, + "runtimes", len(c.runtimes), + "total_clients", total, + ) +} + +func (c *client) readPump() { + defer func() { + c.hub.unregister(c) + c.conn.Close() + }() + + c.conn.SetReadLimit(4096) + c.conn.SetReadDeadline(time.Now().Add(pongWait)) + c.conn.SetPongHandler(func(string) error { + c.conn.SetReadDeadline(time.Now().Add(pongWait)) + return nil + }) + + for { + _, raw, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { + slog.Debug("daemon websocket read error", "error", err, "daemon_id", c.identity.DaemonID) + } + return + } + c.handleFrame(raw) + } +} + +func (c *client) handleFrame(raw []byte) { + var msg protocol.Message + if err := json.Unmarshal(raw, &msg); err != nil { + slog.Debug("daemon websocket invalid frame", "error", err, "daemon_id", c.identity.DaemonID) + return + } + // The phase-one daemon channel is server-push only. Inbound frames are + // drained so control frames and close handling work, but app messages are + // intentionally ignored for forward compatibility. +} + +func (c *client) writePump() { + ticker := time.NewTicker(pingPeriod) + defer func() { + ticker.Stop() + c.conn.Close() + }() + + for { + select { + case message, ok := <-c.send: + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if !ok { + c.conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil { + slog.Debug("daemon websocket write error", "error", err, "daemon_id", c.identity.DaemonID) + return + } + case <-ticker.C: + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } +} diff --git a/server/internal/daemonws/hub_test.go b/server/internal/daemonws/hub_test.go new file mode 100644 index 000000000..c522c8723 --- /dev/null +++ b/server/internal/daemonws/hub_test.go @@ -0,0 +1,200 @@ +package daemonws + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/multica-ai/multica/server/internal/realtime" + "github.com/multica-ai/multica/server/pkg/protocol" +) + +func TestNotifyTaskAvailable(t *testing.T) { + M.Reset() + defer M.Reset() + + hub := NewHub() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hub.HandleWebSocket(w, r, ClientIdentity{RuntimeIDs: []string{"runtime-1"}}) + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + deadline := time.Now().Add(time.Second) + for hub.RuntimeConnectionCount("runtime-1") == 0 { + if time.Now().After(deadline) { + t.Fatal("runtime connection was not registered") + } + time.Sleep(10 * time.Millisecond) + } + + hub.NotifyTaskAvailable("runtime-1", "task-1") + + if err := conn.SetReadDeadline(time.Now().Add(time.Second)); err != nil { + t.Fatalf("SetReadDeadline: %v", err) + } + _, raw, err := conn.ReadMessage() + if err != nil { + t.Fatalf("ReadMessage: %v", err) + } + + var msg protocol.Message + if err := json.Unmarshal(raw, &msg); err != nil { + t.Fatalf("unmarshal message: %v", err) + } + if msg.Type != protocol.EventDaemonTaskAvailable { + t.Fatalf("message type = %q, want %q", msg.Type, protocol.EventDaemonTaskAvailable) + } + + var payload protocol.TaskAvailablePayload + if err := json.Unmarshal(msg.Payload, &payload); err != nil { + t.Fatalf("unmarshal payload: %v", err) + } + if payload.RuntimeID != "runtime-1" || payload.TaskID != "task-1" { + t.Fatalf("payload = %+v, want runtime/task IDs", payload) + } +} + +func TestRelayNotifierPublishesDaemonRuntimeScope(t *testing.T) { + M.Reset() + defer M.Reset() + + relay := &recordingRelayPublisher{} + notifier := NewRelayNotifier(nil, relay) + + notifier.NotifyTaskAvailable("runtime-1", "task-1") + + if relay.scopeType != realtime.ScopeDaemonRuntime { + t.Fatalf("scopeType = %q, want %q", relay.scopeType, realtime.ScopeDaemonRuntime) + } + if relay.scopeID != "task-1" { + t.Fatalf("scopeID = %q, want task_id shard key", relay.scopeID) + } + if relay.eventID == "" { + t.Fatal("expected event id") + } + if M.WakeupPublishedTotal.Load() != 1 { + t.Fatalf("published metric = %d, want 1", M.WakeupPublishedTotal.Load()) + } + + var msg protocol.Message + if err := json.Unmarshal(relay.frame, &msg); err != nil { + t.Fatalf("unmarshal frame: %v", err) + } + if msg.Type != protocol.EventDaemonTaskAvailable { + t.Fatalf("message type = %q, want %q", msg.Type, protocol.EventDaemonTaskAvailable) + } + var payload protocol.TaskAvailablePayload + if err := json.Unmarshal(msg.Payload, &payload); err != nil { + t.Fatalf("unmarshal payload: %v", err) + } + if payload.RuntimeID != "runtime-1" || payload.TaskID != "task-1" { + t.Fatalf("payload = %+v, want runtime/task IDs", payload) + } +} + +func TestRelayNotifierDedupsLocalRedisLoopback(t *testing.T) { + M.Reset() + defer M.Reset() + + hub := NewHub() + client := attachDaemonTestClient(hub, "runtime-1") + relay := &localFirstDaemonRelayPublisher{t: t, client: client} + notifier := NewRelayNotifier(hub, relay) + + notifier.NotifyTaskAvailable("runtime-1", "task-1") + + if !relay.called { + t.Fatal("expected relay publish to be invoked") + } + if relay.eventID == "" { + t.Fatal("expected event id") + } + if M.WakeupDeliveredHit.Load() != 1 { + t.Fatalf("delivered hit metric = %d, want 1", M.WakeupDeliveredHit.Load()) + } + + hub.DeliverDaemonRuntime(relay.scopeID, relay.frame, relay.eventID) + + select { + case duplicate := <-client.send: + t.Fatalf("expected redis loopback to be deduped, got duplicate %s", duplicate) + case <-time.After(20 * time.Millisecond): + } + if M.WakeupDeliveredHit.Load() != 1 { + t.Fatalf("delivered hit metric after loopback = %d, want 1", M.WakeupDeliveredHit.Load()) + } + if M.WakeupDeliveredMiss.Load() != 0 { + t.Fatalf("delivered miss metric after dedup = %d, want 0", M.WakeupDeliveredMiss.Load()) + } +} + +func attachDaemonTestClient(hub *Hub, runtimeID string) *client { + c := &client{ + send: make(chan []byte, 2), + runtimes: map[string]struct{}{runtimeID: {}}, + } + + hub.mu.Lock() + hub.clients[c] = true + hub.byRuntime[runtimeID] = map[*client]bool{c: true} + hub.mu.Unlock() + + return c +} + +type recordingRelayPublisher struct { + scopeType string + scopeID string + exclude string + frame []byte + eventID string +} + +func (r *recordingRelayPublisher) PublishWithID(scopeType, scopeID, exclude string, frame []byte, id string) error { + r.scopeType = scopeType + r.scopeID = scopeID + r.exclude = exclude + r.frame = append([]byte(nil), frame...) + r.eventID = id + return nil +} + +type localFirstDaemonRelayPublisher struct { + t *testing.T + client *client + + called bool + scopeType string + scopeID string + exclude string + frame []byte + eventID string + localFrame []byte +} + +func (p *localFirstDaemonRelayPublisher) PublishWithID(scopeType, scopeID, exclude string, frame []byte, id string) error { + p.called = true + p.scopeType = scopeType + p.scopeID = scopeID + p.exclude = exclude + p.frame = append([]byte(nil), frame...) + p.eventID = id + + select { + case p.localFrame = <-p.client.send: + default: + p.t.Fatal("expected local fanout to happen before relay publish") + } + return nil +} diff --git a/server/internal/daemonws/metrics.go b/server/internal/daemonws/metrics.go new file mode 100644 index 000000000..63bd2ce77 --- /dev/null +++ b/server/internal/daemonws/metrics.go @@ -0,0 +1,44 @@ +package daemonws + +import "sync/atomic" + +type Metrics struct { + ConnectsTotal atomic.Int64 + DisconnectsTotal atomic.Int64 + ActiveConnections atomic.Int64 + SlowEvictionsTotal atomic.Int64 + + WakeupPublishedTotal atomic.Int64 + WakeupPublishErrors atomic.Int64 + WakeupReceivedTotal atomic.Int64 + WakeupDeliveredHit atomic.Int64 + WakeupDeliveredMiss atomic.Int64 +} + +var M = &Metrics{} + +func (m *Metrics) Snapshot() map[string]any { + return map[string]any{ + "connects_total": m.ConnectsTotal.Load(), + "disconnects_total": m.DisconnectsTotal.Load(), + "active_connections": m.ActiveConnections.Load(), + "slow_evictions_total": m.SlowEvictionsTotal.Load(), + "wakeup_published_total": m.WakeupPublishedTotal.Load(), + "wakeup_publish_errors": m.WakeupPublishErrors.Load(), + "wakeup_received_total": m.WakeupReceivedTotal.Load(), + "wakeup_delivered_hit_total": m.WakeupDeliveredHit.Load(), + "wakeup_delivered_miss_total": m.WakeupDeliveredMiss.Load(), + } +} + +func (m *Metrics) Reset() { + m.ConnectsTotal.Store(0) + m.DisconnectsTotal.Store(0) + m.ActiveConnections.Store(0) + m.SlowEvictionsTotal.Store(0) + m.WakeupPublishedTotal.Store(0) + m.WakeupPublishErrors.Store(0) + m.WakeupReceivedTotal.Store(0) + m.WakeupDeliveredHit.Store(0) + m.WakeupDeliveredMiss.Store(0) +} diff --git a/server/internal/daemonws/notifier.go b/server/internal/daemonws/notifier.go new file mode 100644 index 000000000..c1d1f51cd --- /dev/null +++ b/server/internal/daemonws/notifier.go @@ -0,0 +1,49 @@ +package daemonws + +import ( + "log/slog" + + "github.com/oklog/ulid/v2" + + "github.com/multica-ai/multica/server/internal/realtime" +) + +// RelayNotifier sends task wakeups to the local daemon hub and, when Redis is +// configured, publishes the same wakeup through the shared realtime relay so +// every API node can attempt local delivery. +type RelayNotifier struct { + local *Hub + relay realtime.RelayPublisher +} + +func NewRelayNotifier(local *Hub, relay realtime.RelayPublisher) *RelayNotifier { + return &RelayNotifier{local: local, relay: relay} +} + +func (n *RelayNotifier) NotifyTaskAvailable(runtimeID, taskID string) { + if runtimeID == "" { + return + } + eventID := ulid.Make().String() + if n.local != nil { + n.local.notifyTaskAvailable(runtimeID, taskID, eventID) + } + if n.relay == nil { + return + } + frame, err := taskAvailableFrame(runtimeID, taskID) + if err != nil { + M.WakeupPublishErrors.Add(1) + return + } + shardKey := taskID + if shardKey == "" { + shardKey = eventID + } + if err := n.relay.PublishWithID(realtime.ScopeDaemonRuntime, shardKey, "", frame, eventID); err != nil { + M.WakeupPublishErrors.Add(1) + slog.Warn("daemon websocket wakeup publish failed", "error", err, "runtime_id", runtimeID, "task_id", taskID) + return + } + M.WakeupPublishedTotal.Add(1) +} diff --git a/server/internal/handler/daemon_ws.go b/server/internal/handler/daemon_ws.go new file mode 100644 index 000000000..869e1a13d --- /dev/null +++ b/server/internal/handler/daemon_ws.go @@ -0,0 +1,66 @@ +package handler + +import ( + "net/http" + "strings" + + "github.com/multica-ai/multica/server/internal/daemonws" + "github.com/multica-ai/multica/server/internal/middleware" +) + +func (h *Handler) DaemonWebSocket(w http.ResponseWriter, r *http.Request) { + if h.DaemonHub == nil { + writeError(w, http.StatusServiceUnavailable, "daemon websocket unavailable") + return + } + + runtimeIDs := parseRuntimeIDs(r) + if len(runtimeIDs) == 0 { + writeError(w, http.StatusBadRequest, "runtime_ids required") + return + } + + for _, runtimeID := range runtimeIDs { + rt, ok := h.requireDaemonRuntimeAccess(w, r, runtimeID) + if !ok { + return + } + if daemonID := middleware.DaemonIDFromContext(r.Context()); daemonID != "" && rt.DaemonID.Valid && rt.DaemonID.String != daemonID { + writeError(w, http.StatusNotFound, "runtime not found") + return + } + } + + h.DaemonHub.HandleWebSocket(w, r, daemonws.ClientIdentity{ + DaemonID: middleware.DaemonIDFromContext(r.Context()), + UserID: requestUserID(r), + WorkspaceID: middleware.DaemonWorkspaceIDFromContext(r.Context()), + RuntimeIDs: runtimeIDs, + ClientVersion: r.Header.Get("X-Client-Version"), + }) +} + +func parseRuntimeIDs(r *http.Request) []string { + seen := map[string]struct{}{} + var out []string + add := func(raw string) { + for _, part := range strings.Split(raw, ",") { + id := strings.TrimSpace(part) + if id == "" { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + out = append(out, id) + } + } + for _, raw := range r.URL.Query()["runtime_id"] { + add(raw) + } + for _, raw := range r.URL.Query()["runtime_ids"] { + add(raw) + } + return out +} diff --git a/server/internal/handler/handler.go b/server/internal/handler/handler.go index 3b89a4c95..039e3cce2 100644 --- a/server/internal/handler/handler.go +++ b/server/internal/handler/handler.go @@ -15,6 +15,7 @@ import ( "github.com/jackc/pgx/v5/pgtype" "github.com/multica-ai/multica/server/internal/analytics" "github.com/multica-ai/multica/server/internal/auth" + "github.com/multica-ai/multica/server/internal/daemonws" "github.com/multica-ai/multica/server/internal/events" "github.com/multica-ai/multica/server/internal/middleware" "github.com/multica-ai/multica/server/internal/realtime" @@ -53,6 +54,7 @@ type Handler struct { DB dbExecutor TxStarter txStarter Hub *realtime.Hub + DaemonHub *daemonws.Hub Bus *events.Bus TaskService *service.TaskService AutopilotService *service.AutopilotService @@ -67,7 +69,7 @@ type Handler struct { cfg Config } -func New(queries *db.Queries, txStarter txStarter, hub *realtime.Hub, bus *events.Bus, emailService *service.EmailService, store storage.Storage, cfSigner *auth.CloudFrontSigner, analyticsClient analytics.Client, cfg Config) *Handler { +func New(queries *db.Queries, txStarter txStarter, hub *realtime.Hub, bus *events.Bus, emailService *service.EmailService, store storage.Storage, cfSigner *auth.CloudFrontSigner, analyticsClient analytics.Client, cfg Config, daemonHubs ...*daemonws.Hub) *Handler { var executor dbExecutor if candidate, ok := txStarter.(dbExecutor); ok { executor = candidate @@ -77,12 +79,18 @@ func New(queries *db.Queries, txStarter txStarter, hub *realtime.Hub, bus *event analyticsClient = analytics.NoopClient{} } - taskSvc := service.NewTaskService(queries, txStarter, hub, bus) + var daemonHub *daemonws.Hub + if len(daemonHubs) > 0 { + daemonHub = daemonHubs[0] + } + + taskSvc := service.NewTaskService(queries, txStarter, hub, bus, daemonHub) return &Handler{ Queries: queries, DB: executor, TxStarter: txStarter, Hub: hub, + DaemonHub: daemonHub, Bus: bus, TaskService: taskSvc, AutopilotService: service.NewAutopilotService(queries, txStarter, bus, taskSvc), diff --git a/server/internal/metrics/daemonws.go b/server/internal/metrics/daemonws.go new file mode 100644 index 000000000..d935ddc1a --- /dev/null +++ b/server/internal/metrics/daemonws.go @@ -0,0 +1,70 @@ +package metrics + +import ( + "github.com/prometheus/client_golang/prometheus" + + "github.com/multica-ai/multica/server/internal/daemonws" +) + +type DaemonWSCollector struct { + metrics *daemonws.Metrics + + connectsTotal *prometheus.Desc + disconnectsTotal *prometheus.Desc + activeConnections *prometheus.Desc + slowEvictionsTotal *prometheus.Desc + wakeupPublishedTotal *prometheus.Desc + wakeupPublishErrors *prometheus.Desc + wakeupReceivedTotal *prometheus.Desc + wakeupDeliveredTotal *prometheus.Desc +} + +func NewDaemonWSCollector(m *daemonws.Metrics) *DaemonWSCollector { + return &DaemonWSCollector{ + metrics: m, + + connectsTotal: newDaemonWSDesc("connects_total", "Total daemon WebSocket connections opened."), + disconnectsTotal: newDaemonWSDesc("disconnects_total", "Total daemon WebSocket connections closed."), + activeConnections: newDaemonWSDesc("active_connections", "Current daemon WebSocket connections."), + slowEvictionsTotal: newDaemonWSDesc("slow_evictions_total", "Total daemon WebSocket clients evicted for slow consumption."), + wakeupPublishedTotal: newDaemonWSDesc("wakeup_published_total", "Total daemon wakeups published to the Redis relay."), + wakeupPublishErrors: newDaemonWSDesc("wakeup_publish_errors_total", "Total daemon wakeup Redis publish errors."), + wakeupReceivedTotal: newDaemonWSDesc("wakeup_received_total", "Total daemon wakeups received from the Redis relay."), + wakeupDeliveredTotal: prometheus.NewDesc("multica_daemonws_wakeup_delivered_total", "Total daemon wakeup local delivery attempts.", []string{"result"}, nil), + } +} + +func newDaemonWSDesc(name, help string) *prometheus.Desc { + return prometheus.NewDesc("multica_daemonws_"+name, help, nil, nil) +} + +func (c *DaemonWSCollector) Describe(ch chan<- *prometheus.Desc) { + for _, desc := range []*prometheus.Desc{ + c.connectsTotal, + c.disconnectsTotal, + c.activeConnections, + c.slowEvictionsTotal, + c.wakeupPublishedTotal, + c.wakeupPublishErrors, + c.wakeupReceivedTotal, + c.wakeupDeliveredTotal, + } { + ch <- desc + } +} + +func (c *DaemonWSCollector) Collect(ch chan<- prometheus.Metric) { + if c.metrics == nil { + return + } + m := c.metrics + ch <- prometheus.MustNewConstMetric(c.connectsTotal, prometheus.CounterValue, float64(m.ConnectsTotal.Load())) + ch <- prometheus.MustNewConstMetric(c.disconnectsTotal, prometheus.CounterValue, float64(m.DisconnectsTotal.Load())) + ch <- prometheus.MustNewConstMetric(c.activeConnections, prometheus.GaugeValue, float64(m.ActiveConnections.Load())) + ch <- prometheus.MustNewConstMetric(c.slowEvictionsTotal, prometheus.CounterValue, float64(m.SlowEvictionsTotal.Load())) + ch <- prometheus.MustNewConstMetric(c.wakeupPublishedTotal, prometheus.CounterValue, float64(m.WakeupPublishedTotal.Load())) + ch <- prometheus.MustNewConstMetric(c.wakeupPublishErrors, prometheus.CounterValue, float64(m.WakeupPublishErrors.Load())) + ch <- prometheus.MustNewConstMetric(c.wakeupReceivedTotal, prometheus.CounterValue, float64(m.WakeupReceivedTotal.Load())) + ch <- prometheus.MustNewConstMetric(c.wakeupDeliveredTotal, prometheus.CounterValue, float64(m.WakeupDeliveredHit.Load()), "hit") + ch <- prometheus.MustNewConstMetric(c.wakeupDeliveredTotal, prometheus.CounterValue, float64(m.WakeupDeliveredMiss.Load()), "miss") +} diff --git a/server/internal/metrics/registry.go b/server/internal/metrics/registry.go index 4b03e8311..776dd1ff2 100644 --- a/server/internal/metrics/registry.go +++ b/server/internal/metrics/registry.go @@ -7,12 +7,14 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/collectors" + "github.com/multica-ai/multica/server/internal/daemonws" "github.com/multica-ai/multica/server/internal/realtime" ) type RegistryOptions struct { Pool *pgxpool.Pool Realtime *realtime.Metrics + DaemonWS *daemonws.Metrics Version string Commit string } @@ -43,6 +45,9 @@ func NewRegistry(opts RegistryOptions) *Registry { if opts.Realtime != nil { reg.MustRegister(NewRealtimeCollector(opts.Realtime)) } + if opts.DaemonWS != nil { + reg.MustRegister(NewDaemonWSCollector(opts.DaemonWS)) + } return &Registry{ Gatherer: reg, diff --git a/server/internal/realtime/broadcaster.go b/server/internal/realtime/broadcaster.go index a5aaba7ad..a90c7ee67 100644 --- a/server/internal/realtime/broadcaster.go +++ b/server/internal/realtime/broadcaster.go @@ -8,6 +8,9 @@ const ( ScopeUser = "user" ScopeTask = "task" ScopeChat = "chat" + // ScopeDaemonRuntime routes daemon wakeup frames through the Redis relay. + // It is consumed by the daemon WebSocket hub, not by browser clients. + ScopeDaemonRuntime = "daemon_runtime" ) // Broadcaster is the abstraction every realtime event producer should depend @@ -38,5 +41,10 @@ type Broadcaster interface { Broadcast(message []byte) } +// DaemonRuntimeDeliverer consumes daemon-runtime scoped relay frames. +type DaemonRuntimeDeliverer interface { + DeliverDaemonRuntime(scopeID string, frame []byte, eventID string) +} + // Compile-time assertion that *Hub continues to satisfy Broadcaster. var _ Broadcaster = (*Hub)(nil) diff --git a/server/internal/realtime/redis_relay.go b/server/internal/realtime/redis_relay.go index 8faf728b0..00a8cf8fa 100644 --- a/server/internal/realtime/redis_relay.go +++ b/server/internal/realtime/redis_relay.go @@ -106,12 +106,16 @@ func redisString(v any) string { } } -func deliverEnvelope(hub *Hub, ev envelope) { +func deliverEnvelope(hub *Hub, daemonRuntime DaemonRuntimeDeliverer, ev envelope) { if ev.PayloadJSON == "" { return } frame := injectEventID([]byte(ev.PayloadJSON), ev.EventID) switch ev.Scope { + case ScopeDaemonRuntime: + if daemonRuntime != nil { + daemonRuntime.DeliverDaemonRuntime(ev.ScopeID, frame, ev.EventID) + } case "global": hub.fanoutAllDedup(frame, "", ev.EventID) case ScopeUser: @@ -134,6 +138,8 @@ type RedisRelay struct { consumers map[scopeKey]*scopeConsumer stopping bool wg sync.WaitGroup + + daemonRuntime DaemonRuntimeDeliverer } type scopeConsumer struct { @@ -167,6 +173,10 @@ func NewRedisRelayWithClients(hub *Hub, writeRDB, readRDB *redis.Client) *RedisR // NodeID returns this relay's randomly-assigned node identifier. func (r *RedisRelay) NodeID() string { return r.nodeID } +func (r *RedisRelay) SetDaemonRuntimeDeliverer(d DaemonRuntimeDeliverer) { + r.daemonRuntime = d +} + // Wait blocks until all relay-owned goroutines have exited after the Start // context is canceled. func (r *RedisRelay) Wait() { @@ -394,7 +404,7 @@ func (r *RedisRelay) deliverMessage(scopeType, scopeID string, msg redis.XMessag if ev.ScopeID == "" { ev.ScopeID = scopeID } - deliverEnvelope(r.hub, ev) + deliverEnvelope(r.hub, r.daemonRuntime, ev) } // fanoutUser is implemented in hub.go. diff --git a/server/internal/realtime/relay_lifecycle.go b/server/internal/realtime/relay_lifecycle.go index ddd742d21..eda9bc945 100644 --- a/server/internal/realtime/relay_lifecycle.go +++ b/server/internal/realtime/relay_lifecycle.go @@ -36,6 +36,15 @@ func (r *MirroredRelay) NodeID() string { return r.primary.NodeID() } +func (r *MirroredRelay) SetDaemonRuntimeDeliverer(d DaemonRuntimeDeliverer) { + if setter, ok := r.primary.(interface{ SetDaemonRuntimeDeliverer(DaemonRuntimeDeliverer) }); ok { + setter.SetDaemonRuntimeDeliverer(d) + } + if setter, ok := r.mirror.(interface{ SetDaemonRuntimeDeliverer(DaemonRuntimeDeliverer) }); ok { + setter.SetDaemonRuntimeDeliverer(d) + } +} + func (r *MirroredRelay) Start(ctx context.Context) { r.primary.Start(ctx) r.mirror.Start(ctx) @@ -74,6 +83,9 @@ func (r *MirroredRelay) Broadcast(message []byte) { func (r *MirroredRelay) PublishWithID(scopeType, scopeID, exclude string, frame []byte, id string) error { primaryErr := r.primary.PublishWithID(scopeType, scopeID, exclude, frame, id) + if scopeType == ScopeDaemonRuntime { + return primaryErr + } mirrorErr := r.mirror.PublishWithID(scopeType, scopeID, exclude, frame, id) if primaryErr != nil { diff --git a/server/internal/realtime/relay_lifecycle_test.go b/server/internal/realtime/relay_lifecycle_test.go index 8c035a828..8e38fa008 100644 --- a/server/internal/realtime/relay_lifecycle_test.go +++ b/server/internal/realtime/relay_lifecycle_test.go @@ -50,6 +50,23 @@ func TestMirroredRelayRecordsDivergenceWhenOneBackendFails(t *testing.T) { } } +func TestMirroredRelayDoesNotMirrorDaemonRuntimeEvents(t *testing.T) { + primary := &recordingManagedRelay{nodeID: "primary"} + mirror := &recordingManagedRelay{nodeID: "mirror"} + relay := NewMirroredRelay(primary, mirror) + + if err := relay.PublishWithID(ScopeDaemonRuntime, "task-1", "", []byte(`{"type":"daemon:task_available"}`), "event-1"); err != nil { + t.Fatalf("PublishWithID: %v", err) + } + + if len(primary.calls) != 1 { + t.Fatalf("expected primary publish call, got %d", len(primary.calls)) + } + if len(mirror.calls) != 0 { + t.Fatalf("expected daemon runtime event not to hit mirror, got %d calls", len(mirror.calls)) + } +} + type relayPublishCall struct { scopeType string scopeID string diff --git a/server/internal/realtime/sharded_stream_relay.go b/server/internal/realtime/sharded_stream_relay.go index 41d1c4e28..04c3da722 100644 --- a/server/internal/realtime/sharded_stream_relay.go +++ b/server/internal/realtime/sharded_stream_relay.go @@ -76,6 +76,8 @@ type ShardedStreamRelay struct { mu sync.Mutex stopping bool wg sync.WaitGroup + + daemonRuntime DaemonRuntimeDeliverer } func NewShardedStreamRelay(hub *Hub, writeRDB, readRDB *redis.Client, config ShardedStreamRelayConfig) *ShardedStreamRelay { @@ -93,6 +95,10 @@ func NewShardedStreamRelay(hub *Hub, writeRDB, readRDB *redis.Client, config Sha func (r *ShardedStreamRelay) NodeID() string { return r.nodeID } +func (r *ShardedStreamRelay) SetDaemonRuntimeDeliverer(d DaemonRuntimeDeliverer) { + r.daemonRuntime = d +} + func (r *ShardedStreamRelay) Start(ctx context.Context) { M.NodeID.Store(r.nodeID) if err := r.writeRDB.Ping(ctx).Err(); err != nil { @@ -233,7 +239,7 @@ func (r *ShardedStreamRelay) deliverMessage(msg redis.XMessage) { if !ok || ev.Scope == "" || ev.ScopeID == "" { return } - deliverEnvelope(r.hub, ev) + deliverEnvelope(r.hub, r.daemonRuntime, ev) } func (r *ShardedStreamRelay) heartbeatLoop(ctx context.Context) { diff --git a/server/internal/service/task.go b/server/internal/service/task.go index 52d1ed7b1..09cfb3efd 100644 --- a/server/internal/service/task.go +++ b/server/internal/service/task.go @@ -26,10 +26,19 @@ type TaskService struct { TxStarter TxStarter Hub *realtime.Hub Bus *events.Bus + Wakeup TaskWakeupNotifier } -func NewTaskService(q *db.Queries, tx TxStarter, hub *realtime.Hub, bus *events.Bus) *TaskService { - return &TaskService{Queries: q, TxStarter: tx, Hub: hub, Bus: bus} +type TaskWakeupNotifier interface { + NotifyTaskAvailable(runtimeID, taskID string) +} + +func NewTaskService(q *db.Queries, tx TxStarter, hub *realtime.Hub, bus *events.Bus, wakeups ...TaskWakeupNotifier) *TaskService { + var wakeup TaskWakeupNotifier + if len(wakeups) > 0 { + wakeup = wakeups[0] + } + return &TaskService{Queries: q, TxStarter: tx, Hub: hub, Bus: bus, Wakeup: wakeup} } // EnqueueTaskForIssue creates a queued task for an agent-assigned issue. @@ -73,6 +82,7 @@ func (s *TaskService) EnqueueTaskForIssue(ctx context.Context, issue db.Issue, t } slog.Info("task enqueued", "task_id", util.UUIDToString(task.ID), "issue_id", util.UUIDToString(issue.ID), "agent_id", util.UUIDToString(issue.AssigneeID)) + s.notifyTaskAvailable(task) return task, nil } @@ -107,6 +117,7 @@ func (s *TaskService) EnqueueTaskForMention(ctx context.Context, issue db.Issue, } slog.Info("mention task enqueued", "task_id", util.UUIDToString(task.ID), "issue_id", util.UUIDToString(issue.ID), "agent_id", util.UUIDToString(agentID)) + s.notifyTaskAvailable(task) return task, nil } @@ -137,6 +148,7 @@ func (s *TaskService) EnqueueChatTask(ctx context.Context, chatSession db.ChatSe } slog.Info("chat task enqueued", "task_id", util.UUIDToString(task.ID), "chat_session_id", util.UUIDToString(chatSession.ID), "agent_id", util.UUIDToString(chatSession.AgentID)) + s.notifyTaskAvailable(task) return task, nil } @@ -645,6 +657,7 @@ func (s *TaskService) MaybeRetryFailedTask(ctx context.Context, parent db.AgentT "attempt", child.Attempt, "max_attempts", child.MaxAttempts, ) + s.notifyTaskAvailable(child) s.broadcastTaskEvent(ctx, protocol.EventTaskDispatch, child) return &child, nil } @@ -902,6 +915,13 @@ func priorityToInt(p string) int32 { } } +func (s *TaskService) notifyTaskAvailable(task db.AgentTaskQueue) { + if s.Wakeup == nil || !task.RuntimeID.Valid { + return + } + s.Wakeup.NotifyTaskAvailable(util.UUIDToString(task.RuntimeID), util.UUIDToString(task.ID)) +} + func (s *TaskService) broadcastTaskDispatch(ctx context.Context, task db.AgentTaskQueue) { var payload map[string]any if task.Context != nil { diff --git a/server/pkg/protocol/events.go b/server/pkg/protocol/events.go index b5ba64e83..1ca5d9ee7 100644 --- a/server/pkg/protocol/events.go +++ b/server/pkg/protocol/events.go @@ -11,10 +11,10 @@ const ( EventCommentCreated = "comment:created" EventCommentUpdated = "comment:updated" EventCommentDeleted = "comment:deleted" - EventReactionAdded = "reaction:added" - EventReactionRemoved = "reaction:removed" - EventIssueReactionAdded = "issue_reaction:added" - EventIssueReactionRemoved = "issue_reaction:removed" + EventReactionAdded = "reaction:added" + EventReactionRemoved = "reaction:removed" + EventIssueReactionAdded = "issue_reaction:added" + EventIssueReactionRemoved = "issue_reaction:removed" // Agent events EventAgentStatus = "agent:status" @@ -93,6 +93,7 @@ const ( EventAutopilotRunDone = "autopilot:run_done" // Daemon events - EventDaemonHeartbeat = "daemon:heartbeat" - EventDaemonRegister = "daemon:register" + EventDaemonHeartbeat = "daemon:heartbeat" + EventDaemonRegister = "daemon:register" + EventDaemonTaskAvailable = "daemon:task_available" ) diff --git a/server/pkg/protocol/messages.go b/server/pkg/protocol/messages.go index 3e11e8eb8..d98d01456 100644 --- a/server/pkg/protocol/messages.go +++ b/server/pkg/protocol/messages.go @@ -16,6 +16,13 @@ type TaskDispatchPayload struct { Description string `json:"description"` } +// TaskAvailablePayload is sent from server to daemon as a wakeup hint. The +// daemon still claims work through the existing HTTP claim endpoint. +type TaskAvailablePayload struct { + RuntimeID string `json:"runtime_id"` + TaskID string `json:"task_id,omitempty"` +} + // TaskProgressPayload is sent from daemon to server during task execution. type TaskProgressPayload struct { TaskID string `json:"task_id"` @@ -37,10 +44,10 @@ type TaskMessagePayload struct { IssueID string `json:"issue_id,omitempty"` Seq int `json:"seq"` Type string `json:"type"` // "text", "tool_use", "tool_result", "error" - Tool string `json:"tool,omitempty"` // tool name for tool_use/tool_result - Content string `json:"content,omitempty"` // text content - Input map[string]any `json:"input,omitempty"` // tool input (tool_use only) - Output string `json:"output,omitempty"` // tool output (tool_result only) + Tool string `json:"tool,omitempty"` // tool name for tool_use/tool_result + Content string `json:"content,omitempty"` // text content + Input map[string]any `json:"input,omitempty"` // tool input (tool_use only) + Output string `json:"output,omitempty"` // tool output (tool_result only) } // DaemonRegisterPayload is sent from daemon to server on connection.