feat: add daemon websocket task wakeups (#1772)

* feat: add daemon websocket task wakeups

* feat: fan out daemon wakeups across nodes

* fix: dedupe daemon wakeup loopback events

* fix: lengthen daemon polling fallback interval

---------

Co-authored-by: Eve <eve@multica.ai>
This commit is contained in:
devv-eve
2026-04-28 16:07:24 +08:00
committed by GitHub
parent 541aaa974d
commit 9db91e89f5
24 changed files with 1202 additions and 28 deletions

View File

@@ -7,6 +7,7 @@ import (
"net/http"
"strings"
"github.com/multica-ai/multica/server/internal/daemonws"
"github.com/multica-ai/multica/server/internal/realtime"
)
@@ -47,7 +48,9 @@ func realtimeMetricsHandler(token string) http.HandlerFunc {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
_ = json.NewEncoder(w).Encode(realtime.M.Snapshot())
snapshot := realtime.M.Snapshot()
snapshot["daemonws"] = daemonws.M.Snapshot()
_ = json.NewEncoder(w).Encode(snapshot)
}
}

View File

@@ -12,6 +12,7 @@ import (
"time"
"github.com/multica-ai/multica/server/internal/analytics"
"github.com/multica-ai/multica/server/internal/daemonws"
"github.com/multica-ai/multica/server/internal/events"
"github.com/multica-ai/multica/server/internal/logger"
obsmetrics "github.com/multica-ai/multica/server/internal/metrics"
@@ -161,6 +162,8 @@ func main() {
bus := events.New()
hub := realtime.NewHub()
go hub.Run()
daemonHub := daemonws.NewHub()
var daemonWakeup service.TaskWakeupNotifier = daemonHub
// MUL-1138: when REDIS_URL is set, route fanout through a Redis relay so
// multiple API nodes can deliver each other's events. Without it the hub
@@ -204,15 +207,21 @@ func main() {
case "legacy":
relayReadRedis = newNamedRedisClient(opts, "realtime-read")
relay = realtime.NewRedisRelayWithClients(hub, relayWriteRedis, relayReadRedis)
slog.Info("daemon websocket wakeup: Redis fanout disabled in legacy realtime relay mode")
case "dual":
shardedReadRedis = newNamedRedisClient(opts, "realtime-read-sharded")
legacyReadRedis = newNamedRedisClient(opts, "realtime-read-legacy")
sharded := realtime.NewShardedStreamRelay(hub, relayWriteRedis, shardedReadRedis, relayConfig)
sharded.SetDaemonRuntimeDeliverer(daemonHub)
legacy := realtime.NewRedisRelayWithClients(hub, relayWriteRedis, legacyReadRedis)
relay = realtime.NewMirroredRelay(sharded, legacy)
daemonWakeup = daemonws.NewRelayNotifier(daemonHub, sharded)
default:
relayReadRedis = newNamedRedisClient(opts, "realtime-read")
relay = realtime.NewShardedStreamRelay(hub, relayWriteRedis, relayReadRedis, relayConfig)
sharded := realtime.NewShardedStreamRelay(hub, relayWriteRedis, relayReadRedis, relayConfig)
sharded.SetDaemonRuntimeDeliverer(daemonHub)
relay = sharded
daemonWakeup = daemonws.NewRelayNotifier(daemonHub, sharded)
}
relay.Start(relayCtx)
broadcaster = realtime.NewDualWriteBroadcaster(hub, relay)
@@ -253,6 +262,7 @@ func main() {
metricsRegistry := obsmetrics.NewRegistry(obsmetrics.RegistryOptions{
Pool: pool,
Realtime: realtime.M,
DaemonWS: daemonws.M,
Version: version,
Commit: commit,
})
@@ -267,7 +277,9 @@ func main() {
}
r := NewRouterWithOptions(pool, hub, bus, analyticsClient, storeRedis, RouterOptions{
HTTPMetrics: httpMetrics,
HTTPMetrics: httpMetrics,
DaemonHub: daemonHub,
DaemonWakeup: daemonWakeup,
})
srv := &http.Server{
@@ -278,7 +290,7 @@ func main() {
// Start background workers.
sweepCtx, sweepCancel := context.WithCancel(context.Background())
autopilotCtx, autopilotCancel := context.WithCancel(context.Background())
taskSvc := service.NewTaskService(queries, pool, hub, bus)
taskSvc := service.NewTaskService(queries, pool, hub, bus, daemonWakeup)
autopilotSvc := service.NewAutopilotService(queries, pool, bus, taskSvc)
registerAutopilotListeners(bus, autopilotSvc)

View File

@@ -15,6 +15,7 @@ import (
"github.com/multica-ai/multica/server/internal/analytics"
"github.com/multica-ai/multica/server/internal/auth"
"github.com/multica-ai/multica/server/internal/daemonws"
"github.com/multica-ai/multica/server/internal/events"
"github.com/multica-ai/multica/server/internal/handler"
obsmetrics "github.com/multica-ai/multica/server/internal/metrics"
@@ -67,12 +68,18 @@ func NewRouter(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus, analytics
}
type RouterOptions struct {
HTTPMetrics *obsmetrics.HTTPMetrics
HTTPMetrics *obsmetrics.HTTPMetrics
DaemonHub *daemonws.Hub
DaemonWakeup service.TaskWakeupNotifier
}
func NewRouterWithOptions(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus, analyticsClient analytics.Client, rdb *redis.Client, opts RouterOptions) chi.Router {
queries := db.New(pool)
emailSvc := service.NewEmailService()
daemonHub := opts.DaemonHub
if daemonHub == nil {
daemonHub = daemonws.NewHub()
}
// Initialize storage with S3 as primary, fallback to local
var store storage.Storage
@@ -93,7 +100,10 @@ func NewRouterWithOptions(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus
AllowedEmails: splitAndTrim(os.Getenv("ALLOWED_EMAILS")),
AllowedEmailDomains: splitAndTrim(os.Getenv("ALLOWED_EMAIL_DOMAINS")),
}
h := handler.New(queries, pool, hub, bus, emailSvc, store, cfSigner, analyticsClient, signupConfig)
h := handler.New(queries, pool, hub, bus, emailSvc, store, cfSigner, analyticsClient, signupConfig, daemonHub)
if opts.DaemonWakeup != nil {
h.TaskService.Wakeup = opts.DaemonWakeup
}
if rdb != nil {
h.LocalSkillListStore = handler.NewRedisLocalSkillListStore(rdb)
h.LocalSkillImportStore = handler.NewRedisLocalSkillImportStore(rdb)
@@ -178,6 +188,7 @@ func NewRouterWithOptions(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus
r.Post("/register", h.DaemonRegister)
r.Post("/deregister", h.DaemonDeregister)
r.Post("/heartbeat", h.DaemonHeartbeat)
r.Get("/ws", h.DaemonWebSocket)
r.Get("/workspaces/{workspaceId}/repos", h.GetDaemonWorkspaceRepos)
r.Post("/runtimes/{runtimeId}/tasks/claim", h.ClaimTaskByRuntime)

View File

@@ -12,7 +12,7 @@ import (
const (
DefaultServerURL = "ws://localhost:8080/ws"
DefaultPollInterval = 3 * time.Second
DefaultPollInterval = 30 * time.Second
DefaultHeartbeatInterval = 15 * time.Second
DefaultAgentTimeout = 2 * time.Hour
DefaultCodexSemanticInactivityTimeout = 10 * time.Minute

View File

@@ -45,6 +45,7 @@ type Daemon struct {
workspaces map[string]*workspaceState
runtimeIndex map[string]Runtime // runtimeID -> Runtime for provider lookups
reloading sync.Mutex // prevents concurrent workspace syncs
runtimeSetCh chan struct{} // notifies the WS wakeup loop to reconnect with a new runtime set
versionsMu sync.RWMutex // guards agentVersions
agentVersions map[string]string // provider -> detected CLI version (set during registration)
@@ -69,6 +70,7 @@ func New(cfg Config, logger *slog.Logger) *Daemon {
logger: logger,
workspaces: make(map[string]*workspaceState),
runtimeIndex: make(map[string]Runtime),
runtimeSetCh: make(chan struct{}, 1),
agentVersions: make(map[string]string),
}
}
@@ -89,6 +91,23 @@ func (d *Daemon) agentVersion(provider string) string {
return d.agentVersions[provider]
}
func (d *Daemon) notifyRuntimeSetChanged() {
select {
case d.runtimeSetCh <- struct{}{}:
default:
}
}
func (d *Daemon) drainRuntimeSetChanged() {
for {
select {
case <-d.runtimeSetCh:
default:
return
}
}
}
// Run starts the daemon: resolves auth, registers runtimes, then polls for tasks.
func (d *Daemon) Run(ctx context.Context) error {
// Wrap context so handleUpdate can cancel the daemon for restart.
@@ -132,10 +151,13 @@ func (d *Daemon) Run(ctx context.Context) error {
// Start workspace sync loop to discover newly created workspaces.
go d.workspaceSyncLoop(ctx)
taskWakeups := make(chan struct{}, 1)
d.drainRuntimeSetChanged()
go d.taskWakeupLoop(ctx, taskWakeups)
go d.heartbeatLoop(ctx)
go d.gcLoop(ctx)
go d.serveHealth(ctx, healthLn, time.Now())
return d.pollLoop(ctx)
return d.pollLoop(ctx, taskWakeups)
}
// RestartBinary returns the path to the new binary if the daemon needs to restart
@@ -422,6 +444,7 @@ func (d *Daemon) syncWorkspacesFromAPI(ctx context.Context) error {
d.mu.Unlock()
var registered int
var removed int
for id, name := range apiIDs {
if currentIDs[id] {
continue // important: never replace existing workspaceState; ensureRepoReady holds ws.repoRefreshMu from the original pointer
@@ -473,8 +496,12 @@ func (d *Daemon) syncWorkspacesFromAPI(ctx context.Context) error {
delete(d.workspaces, id)
d.mu.Unlock()
d.logger.Info("stopped watching workspace", "workspace_id", id)
removed++
}
}
if registered > 0 || removed > 0 {
d.notifyRuntimeSetChanged()
}
if len(d.allRuntimeIDs()) == 0 && registered == 0 && len(workspaces) > 0 {
return fmt.Errorf("failed to register runtimes for any of the %d workspace(s)", len(workspaces))
@@ -799,7 +826,7 @@ func (d *Daemon) triggerRestart() {
}
}
func (d *Daemon) pollLoop(ctx context.Context) error {
func (d *Daemon) pollLoop(ctx context.Context, taskWakeups <-chan struct{}) error {
sem := make(chan struct{}, d.cfg.MaxConcurrentTasks)
var wg sync.WaitGroup
@@ -822,7 +849,7 @@ func (d *Daemon) pollLoop(ctx context.Context) error {
runtimeIDs := d.allRuntimeIDs()
if len(runtimeIDs) == 0 {
if err := sleepWithContext(ctx, d.cfg.PollInterval); err != nil {
if err := sleepWithContextOrWakeup(ctx, d.cfg.PollInterval, taskWakeups); err != nil {
wg.Wait()
return err
}
@@ -878,7 +905,7 @@ func (d *Daemon) pollLoop(ctx context.Context) error {
d.logger.Debug("poll: no tasks", "runtimes", runtimeIDs, "cycle", pollCount)
}
pollOffset = (pollOffset + 1) % n
if err := sleepWithContext(ctx, d.cfg.PollInterval); err != nil {
if err := sleepWithContextOrWakeup(ctx, d.cfg.PollInterval, taskWakeups); err != nil {
wg.Wait()
return err
}

View File

@@ -78,3 +78,21 @@ func sleepWithContext(ctx context.Context, d time.Duration) error {
return nil
}
}
func sleepWithContextOrWakeup(ctx context.Context, d time.Duration, wakeups <-chan struct{}) error {
if wakeups == nil {
return sleepWithContext(ctx, d)
}
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-wakeups:
return nil
case <-timer.C:
return nil
}
}

View File

@@ -0,0 +1,188 @@
package daemon
import (
"context"
"encoding/json"
"errors"
"fmt"
"math/rand"
"net/http"
"net/url"
"sort"
"strings"
"time"
"github.com/gorilla/websocket"
"github.com/multica-ai/multica/server/pkg/protocol"
)
var errRuntimeSetChanged = errors.New("runtime set changed")
func (d *Daemon) taskWakeupLoop(ctx context.Context, taskWakeups chan<- struct{}) {
backoff := time.Second
for {
runtimeIDs := d.allRuntimeIDs()
if len(runtimeIDs) == 0 {
if err := sleepWithContextOrRuntimeChange(ctx, 5*time.Second, d.runtimeSetCh); err != nil {
return
}
continue
}
err := d.runTaskWakeupConnection(ctx, runtimeIDs, taskWakeups)
if ctx.Err() != nil {
return
}
if errors.Is(err, errRuntimeSetChanged) {
backoff = time.Second
continue
}
if err != nil {
d.logger.Debug("task wakeup websocket unavailable; polling fallback remains active", "error", err, "retry_in", backoff)
}
if err := sleepWithContextOrRuntimeChange(ctx, jitterDuration(backoff), d.runtimeSetCh); err != nil {
return
}
if backoff < 30*time.Second {
backoff *= 2
if backoff > 30*time.Second {
backoff = 30 * time.Second
}
}
}
}
func jitterDuration(d time.Duration) time.Duration {
if d <= 0 {
return d
}
spread := d / 5
if spread <= 0 {
return d
}
delta := time.Duration(rand.Int63n(int64(spread)*2+1)) - spread
return d + delta
}
func (d *Daemon) runTaskWakeupConnection(ctx context.Context, runtimeIDs []string, taskWakeups chan<- struct{}) error {
wsURL, err := taskWakeupURL(d.cfg.ServerBaseURL, runtimeIDs)
if err != nil {
return err
}
headers := http.Header{}
if token := d.client.Token(); token != "" {
headers.Set("Authorization", "Bearer "+token)
}
if d.client.platform != "" {
headers.Set("X-Client-Platform", d.client.platform)
}
if d.client.version != "" {
headers.Set("X-Client-Version", d.client.version)
}
if d.client.os != "" {
headers.Set("X-Client-OS", d.client.os)
}
dialer := websocket.Dialer{HandshakeTimeout: 10 * time.Second}
conn, _, err := dialer.DialContext(ctx, wsURL, headers)
if err != nil {
return err
}
defer conn.Close()
d.logger.Info("task wakeup websocket connected", "runtimes", len(runtimeIDs))
signalTaskWakeup(taskWakeups)
errCh := make(chan error, 1)
go func() {
errCh <- d.readTaskWakeupMessages(conn, taskWakeups)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-d.runtimeSetCh:
return errRuntimeSetChanged
case err := <-errCh:
return err
}
}
func (d *Daemon) readTaskWakeupMessages(conn *websocket.Conn, taskWakeups chan<- struct{}) error {
conn.SetReadLimit(64 * 1024)
for {
_, raw, err := conn.ReadMessage()
if err != nil {
return err
}
var msg protocol.Message
if err := json.Unmarshal(raw, &msg); err != nil {
d.logger.Debug("task wakeup websocket invalid message", "error", err)
continue
}
if msg.Type != protocol.EventDaemonTaskAvailable {
continue
}
var payload protocol.TaskAvailablePayload
if len(msg.Payload) > 0 {
if err := json.Unmarshal(msg.Payload, &payload); err != nil {
d.logger.Debug("task wakeup websocket invalid payload", "error", err)
continue
}
}
if payload.RuntimeID != "" {
d.logger.Debug("task wakeup received", "runtime_id", payload.RuntimeID, "task_id", payload.TaskID)
}
signalTaskWakeup(taskWakeups)
}
}
func signalTaskWakeup(taskWakeups chan<- struct{}) {
select {
case taskWakeups <- struct{}{}:
default:
}
}
func taskWakeupURL(baseURL string, runtimeIDs []string) (string, error) {
u, err := url.Parse(strings.TrimSpace(baseURL))
if err != nil {
return "", fmt.Errorf("invalid daemon server URL: %w", err)
}
switch u.Scheme {
case "http":
u.Scheme = "ws"
case "https":
u.Scheme = "wss"
case "ws", "wss":
default:
return "", fmt.Errorf("daemon server URL must use http, https, ws, or wss")
}
u.Path = strings.TrimRight(u.Path, "/") + "/api/daemon/ws"
u.RawPath = ""
q := u.Query()
ids := append([]string(nil), runtimeIDs...)
sort.Strings(ids)
q.Set("runtime_ids", strings.Join(ids, ","))
u.RawQuery = q.Encode()
u.Fragment = ""
return u.String(), nil
}
func sleepWithContextOrRuntimeChange(ctx context.Context, d time.Duration, runtimeSetCh <-chan struct{}) error {
timer := time.NewTimer(d)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-runtimeSetCh:
return nil
case <-timer.C:
return nil
}
}

View File

@@ -0,0 +1,43 @@
package daemon
import "testing"
func TestTaskWakeupURL(t *testing.T) {
tests := []struct {
name string
baseURL string
runtimeIDs []string
want string
}{
{
name: "http base",
baseURL: "http://localhost:8080",
runtimeIDs: []string{"runtime-b", "runtime-a"},
want: "ws://localhost:8080/api/daemon/ws?runtime_ids=runtime-a%2Cruntime-b",
},
{
name: "https base",
baseURL: "https://api.example.com",
runtimeIDs: []string{"runtime-1"},
want: "wss://api.example.com/api/daemon/ws?runtime_ids=runtime-1",
},
{
name: "base path",
baseURL: "https://api.example.com/multica",
runtimeIDs: []string{"runtime-1"},
want: "wss://api.example.com/multica/api/daemon/ws?runtime_ids=runtime-1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := taskWakeupURL(tt.baseURL, tt.runtimeIDs)
if err != nil {
t.Fatalf("taskWakeupURL: %v", err)
}
if got != tt.want {
t.Fatalf("taskWakeupURL() = %q, want %q", got, tt.want)
}
})
}
}

View File

@@ -0,0 +1,349 @@
package daemonws
import (
"encoding/json"
"log/slog"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/multica-ai/multica/server/pkg/protocol"
)
const (
writeWait = 10 * time.Second
pongWait = 60 * time.Second
pingPeriod = (pongWait * 9) / 10
)
// ClientIdentity captures the already-authenticated daemon connection scope.
type ClientIdentity struct {
DaemonID string
UserID string
WorkspaceID string
RuntimeIDs []string
ClientVersion string
}
type client struct {
hub *Hub
conn *websocket.Conn
send chan []byte
identity ClientIdentity
runtimes map[string]struct{}
dedupMu sync.Mutex
seenIDs map[string]struct{}
seenList []string
}
const eventDedupCapacity = 128
// markSeen records eventID as already delivered to this client. Empty event IDs
// disable dedup and are always delivered.
func (c *client) markSeen(eventID string) bool {
if eventID == "" {
return true
}
c.dedupMu.Lock()
defer c.dedupMu.Unlock()
if c.seenIDs == nil {
c.seenIDs = make(map[string]struct{}, eventDedupCapacity)
}
if _, ok := c.seenIDs[eventID]; ok {
return false
}
c.seenIDs[eventID] = struct{}{}
c.seenList = append(c.seenList, eventID)
if len(c.seenList) > eventDedupCapacity {
drop := c.seenList[0]
c.seenList = c.seenList[1:]
delete(c.seenIDs, drop)
}
return true
}
// Hub keeps daemon WebSocket connections indexed by runtime ID. Messages are
// best-effort wakeup hints; the daemon still uses HTTP claim for correctness.
type Hub struct {
upgrader websocket.Upgrader
mu sync.RWMutex
clients map[*client]bool
byRuntime map[string]map[*client]bool
}
func NewHub() *Hub {
return &Hub{
upgrader: websocket.Upgrader{
// Daemon clients authenticate with Authorization headers before the
// upgrade. Browsers cannot set those headers through the native WS API,
// and DaemonAuth does not accept cookies, so cookie-based CSWSH does
// not apply to this endpoint. Re-evaluate this if DaemonAuth ever
// grows cookie fallback.
CheckOrigin: func(r *http.Request) bool { return true },
},
clients: make(map[*client]bool),
byRuntime: make(map[string]map[*client]bool),
}
}
func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request, identity ClientIdentity) {
if len(identity.RuntimeIDs) == 0 {
http.Error(w, `{"error":"runtime_ids required"}`, http.StatusBadRequest)
return
}
conn, err := h.upgrader.Upgrade(w, r, nil)
if err != nil {
slog.Error("daemon websocket upgrade failed", "error", err)
return
}
runtimes := make(map[string]struct{}, len(identity.RuntimeIDs))
for _, runtimeID := range identity.RuntimeIDs {
if runtimeID != "" {
runtimes[runtimeID] = struct{}{}
}
}
if len(runtimes) == 0 {
conn.WriteMessage(websocket.TextMessage, []byte(`{"error":"runtime_ids required"}`))
conn.Close()
return
}
c := &client{
hub: h,
conn: conn,
send: make(chan []byte, 16),
identity: identity,
runtimes: runtimes,
}
h.register(c)
go c.writePump()
go c.readPump()
}
// NotifyTaskAvailable sends a best-effort wakeup to daemons watching runtimeID.
func (h *Hub) NotifyTaskAvailable(runtimeID, taskID string) {
h.notifyTaskAvailable(runtimeID, taskID, "")
}
func (h *Hub) notifyTaskAvailable(runtimeID, taskID, eventID string) {
if h == nil || runtimeID == "" {
return
}
data, err := taskAvailableFrame(runtimeID, taskID)
if err != nil {
return
}
delivered, deduped := h.notifyFrame(runtimeID, data, eventID)
if delivered {
M.WakeupDeliveredHit.Add(1)
} else if !deduped {
M.WakeupDeliveredMiss.Add(1)
}
}
func (h *Hub) DeliverDaemonRuntime(scopeID string, frame []byte, eventID string) {
if h == nil {
return
}
M.WakeupReceivedTotal.Add(1)
var msg protocol.Message
if err := json.Unmarshal(frame, &msg); err != nil {
slog.Debug("daemon websocket relay: invalid frame", "error", err, "scope_id", scopeID, "event_id", eventID)
M.WakeupDeliveredMiss.Add(1)
return
}
if msg.Type != protocol.EventDaemonTaskAvailable {
M.WakeupDeliveredMiss.Add(1)
return
}
var payload protocol.TaskAvailablePayload
if err := json.Unmarshal(msg.Payload, &payload); err != nil || payload.RuntimeID == "" {
slog.Debug("daemon websocket relay: invalid task_available payload", "error", err, "scope_id", scopeID, "event_id", eventID)
M.WakeupDeliveredMiss.Add(1)
return
}
delivered, deduped := h.notifyFrame(payload.RuntimeID, frame, eventID)
if delivered {
M.WakeupDeliveredHit.Add(1)
} else if !deduped {
M.WakeupDeliveredMiss.Add(1)
}
}
func (h *Hub) notifyFrame(runtimeID string, data []byte, eventID string) (delivered bool, deduped bool) {
h.mu.RLock()
clients := h.byRuntime[runtimeID]
slow := make([]*client, 0)
for c := range clients {
if !c.markSeen(eventID) {
deduped = true
continue
}
select {
case c.send <- data:
delivered = true
default:
slow = append(slow, c)
}
}
h.mu.RUnlock()
for _, c := range slow {
h.unregister(c)
c.conn.Close()
}
if len(slow) > 0 {
M.SlowEvictionsTotal.Add(int64(len(slow)))
}
return delivered, deduped
}
func taskAvailableFrame(runtimeID, taskID string) ([]byte, error) {
return json.Marshal(protocol.Message{
Type: protocol.EventDaemonTaskAvailable,
Payload: mustMarshalRaw(protocol.TaskAvailablePayload{
RuntimeID: runtimeID,
TaskID: taskID,
}),
})
}
func mustMarshalRaw(v any) json.RawMessage {
data, err := json.Marshal(v)
if err != nil {
return nil
}
return data
}
func (h *Hub) RuntimeConnectionCount(runtimeID string) int {
h.mu.RLock()
defer h.mu.RUnlock()
return len(h.byRuntime[runtimeID])
}
func (h *Hub) register(c *client) {
h.mu.Lock()
h.clients[c] = true
for runtimeID := range c.runtimes {
conns := h.byRuntime[runtimeID]
if conns == nil {
conns = make(map[*client]bool)
h.byRuntime[runtimeID] = conns
}
conns[c] = true
}
total := len(h.clients)
h.mu.Unlock()
M.ConnectsTotal.Add(1)
M.ActiveConnections.Add(1)
slog.Info("daemon websocket connected",
"daemon_id", c.identity.DaemonID,
"user_id", c.identity.UserID,
"workspace_id", c.identity.WorkspaceID,
"runtimes", len(c.runtimes),
"client_version", c.identity.ClientVersion,
"total_clients", total,
)
}
func (h *Hub) unregister(c *client) {
h.mu.Lock()
if !h.clients[c] {
h.mu.Unlock()
return
}
delete(h.clients, c)
for runtimeID := range c.runtimes {
if conns := h.byRuntime[runtimeID]; conns != nil {
delete(conns, c)
if len(conns) == 0 {
delete(h.byRuntime, runtimeID)
}
}
}
close(c.send)
total := len(h.clients)
h.mu.Unlock()
M.DisconnectsTotal.Add(1)
M.ActiveConnections.Add(-1)
slog.Info("daemon websocket disconnected",
"daemon_id", c.identity.DaemonID,
"user_id", c.identity.UserID,
"workspace_id", c.identity.WorkspaceID,
"runtimes", len(c.runtimes),
"total_clients", total,
)
}
func (c *client) readPump() {
defer func() {
c.hub.unregister(c)
c.conn.Close()
}()
c.conn.SetReadLimit(4096)
c.conn.SetReadDeadline(time.Now().Add(pongWait))
c.conn.SetPongHandler(func(string) error {
c.conn.SetReadDeadline(time.Now().Add(pongWait))
return nil
})
for {
_, raw, err := c.conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
slog.Debug("daemon websocket read error", "error", err, "daemon_id", c.identity.DaemonID)
}
return
}
c.handleFrame(raw)
}
}
func (c *client) handleFrame(raw []byte) {
var msg protocol.Message
if err := json.Unmarshal(raw, &msg); err != nil {
slog.Debug("daemon websocket invalid frame", "error", err, "daemon_id", c.identity.DaemonID)
return
}
// The phase-one daemon channel is server-push only. Inbound frames are
// drained so control frames and close handling work, but app messages are
// intentionally ignored for forward compatibility.
}
func (c *client) writePump() {
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
c.conn.Close()
}()
for {
select {
case message, ok := <-c.send:
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if !ok {
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
slog.Debug("daemon websocket write error", "error", err, "daemon_id", c.identity.DaemonID)
return
}
case <-ticker.C:
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}
}

View File

@@ -0,0 +1,200 @@
package daemonws
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/multica-ai/multica/server/internal/realtime"
"github.com/multica-ai/multica/server/pkg/protocol"
)
func TestNotifyTaskAvailable(t *testing.T) {
M.Reset()
defer M.Reset()
hub := NewHub()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hub.HandleWebSocket(w, r, ClientIdentity{RuntimeIDs: []string{"runtime-1"}})
}))
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
deadline := time.Now().Add(time.Second)
for hub.RuntimeConnectionCount("runtime-1") == 0 {
if time.Now().After(deadline) {
t.Fatal("runtime connection was not registered")
}
time.Sleep(10 * time.Millisecond)
}
hub.NotifyTaskAvailable("runtime-1", "task-1")
if err := conn.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
t.Fatalf("SetReadDeadline: %v", err)
}
_, raw, err := conn.ReadMessage()
if err != nil {
t.Fatalf("ReadMessage: %v", err)
}
var msg protocol.Message
if err := json.Unmarshal(raw, &msg); err != nil {
t.Fatalf("unmarshal message: %v", err)
}
if msg.Type != protocol.EventDaemonTaskAvailable {
t.Fatalf("message type = %q, want %q", msg.Type, protocol.EventDaemonTaskAvailable)
}
var payload protocol.TaskAvailablePayload
if err := json.Unmarshal(msg.Payload, &payload); err != nil {
t.Fatalf("unmarshal payload: %v", err)
}
if payload.RuntimeID != "runtime-1" || payload.TaskID != "task-1" {
t.Fatalf("payload = %+v, want runtime/task IDs", payload)
}
}
func TestRelayNotifierPublishesDaemonRuntimeScope(t *testing.T) {
M.Reset()
defer M.Reset()
relay := &recordingRelayPublisher{}
notifier := NewRelayNotifier(nil, relay)
notifier.NotifyTaskAvailable("runtime-1", "task-1")
if relay.scopeType != realtime.ScopeDaemonRuntime {
t.Fatalf("scopeType = %q, want %q", relay.scopeType, realtime.ScopeDaemonRuntime)
}
if relay.scopeID != "task-1" {
t.Fatalf("scopeID = %q, want task_id shard key", relay.scopeID)
}
if relay.eventID == "" {
t.Fatal("expected event id")
}
if M.WakeupPublishedTotal.Load() != 1 {
t.Fatalf("published metric = %d, want 1", M.WakeupPublishedTotal.Load())
}
var msg protocol.Message
if err := json.Unmarshal(relay.frame, &msg); err != nil {
t.Fatalf("unmarshal frame: %v", err)
}
if msg.Type != protocol.EventDaemonTaskAvailable {
t.Fatalf("message type = %q, want %q", msg.Type, protocol.EventDaemonTaskAvailable)
}
var payload protocol.TaskAvailablePayload
if err := json.Unmarshal(msg.Payload, &payload); err != nil {
t.Fatalf("unmarshal payload: %v", err)
}
if payload.RuntimeID != "runtime-1" || payload.TaskID != "task-1" {
t.Fatalf("payload = %+v, want runtime/task IDs", payload)
}
}
func TestRelayNotifierDedupsLocalRedisLoopback(t *testing.T) {
M.Reset()
defer M.Reset()
hub := NewHub()
client := attachDaemonTestClient(hub, "runtime-1")
relay := &localFirstDaemonRelayPublisher{t: t, client: client}
notifier := NewRelayNotifier(hub, relay)
notifier.NotifyTaskAvailable("runtime-1", "task-1")
if !relay.called {
t.Fatal("expected relay publish to be invoked")
}
if relay.eventID == "" {
t.Fatal("expected event id")
}
if M.WakeupDeliveredHit.Load() != 1 {
t.Fatalf("delivered hit metric = %d, want 1", M.WakeupDeliveredHit.Load())
}
hub.DeliverDaemonRuntime(relay.scopeID, relay.frame, relay.eventID)
select {
case duplicate := <-client.send:
t.Fatalf("expected redis loopback to be deduped, got duplicate %s", duplicate)
case <-time.After(20 * time.Millisecond):
}
if M.WakeupDeliveredHit.Load() != 1 {
t.Fatalf("delivered hit metric after loopback = %d, want 1", M.WakeupDeliveredHit.Load())
}
if M.WakeupDeliveredMiss.Load() != 0 {
t.Fatalf("delivered miss metric after dedup = %d, want 0", M.WakeupDeliveredMiss.Load())
}
}
func attachDaemonTestClient(hub *Hub, runtimeID string) *client {
c := &client{
send: make(chan []byte, 2),
runtimes: map[string]struct{}{runtimeID: {}},
}
hub.mu.Lock()
hub.clients[c] = true
hub.byRuntime[runtimeID] = map[*client]bool{c: true}
hub.mu.Unlock()
return c
}
type recordingRelayPublisher struct {
scopeType string
scopeID string
exclude string
frame []byte
eventID string
}
func (r *recordingRelayPublisher) PublishWithID(scopeType, scopeID, exclude string, frame []byte, id string) error {
r.scopeType = scopeType
r.scopeID = scopeID
r.exclude = exclude
r.frame = append([]byte(nil), frame...)
r.eventID = id
return nil
}
type localFirstDaemonRelayPublisher struct {
t *testing.T
client *client
called bool
scopeType string
scopeID string
exclude string
frame []byte
eventID string
localFrame []byte
}
func (p *localFirstDaemonRelayPublisher) PublishWithID(scopeType, scopeID, exclude string, frame []byte, id string) error {
p.called = true
p.scopeType = scopeType
p.scopeID = scopeID
p.exclude = exclude
p.frame = append([]byte(nil), frame...)
p.eventID = id
select {
case p.localFrame = <-p.client.send:
default:
p.t.Fatal("expected local fanout to happen before relay publish")
}
return nil
}

View File

@@ -0,0 +1,44 @@
package daemonws
import "sync/atomic"
type Metrics struct {
ConnectsTotal atomic.Int64
DisconnectsTotal atomic.Int64
ActiveConnections atomic.Int64
SlowEvictionsTotal atomic.Int64
WakeupPublishedTotal atomic.Int64
WakeupPublishErrors atomic.Int64
WakeupReceivedTotal atomic.Int64
WakeupDeliveredHit atomic.Int64
WakeupDeliveredMiss atomic.Int64
}
var M = &Metrics{}
func (m *Metrics) Snapshot() map[string]any {
return map[string]any{
"connects_total": m.ConnectsTotal.Load(),
"disconnects_total": m.DisconnectsTotal.Load(),
"active_connections": m.ActiveConnections.Load(),
"slow_evictions_total": m.SlowEvictionsTotal.Load(),
"wakeup_published_total": m.WakeupPublishedTotal.Load(),
"wakeup_publish_errors": m.WakeupPublishErrors.Load(),
"wakeup_received_total": m.WakeupReceivedTotal.Load(),
"wakeup_delivered_hit_total": m.WakeupDeliveredHit.Load(),
"wakeup_delivered_miss_total": m.WakeupDeliveredMiss.Load(),
}
}
func (m *Metrics) Reset() {
m.ConnectsTotal.Store(0)
m.DisconnectsTotal.Store(0)
m.ActiveConnections.Store(0)
m.SlowEvictionsTotal.Store(0)
m.WakeupPublishedTotal.Store(0)
m.WakeupPublishErrors.Store(0)
m.WakeupReceivedTotal.Store(0)
m.WakeupDeliveredHit.Store(0)
m.WakeupDeliveredMiss.Store(0)
}

View File

@@ -0,0 +1,49 @@
package daemonws
import (
"log/slog"
"github.com/oklog/ulid/v2"
"github.com/multica-ai/multica/server/internal/realtime"
)
// RelayNotifier sends task wakeups to the local daemon hub and, when Redis is
// configured, publishes the same wakeup through the shared realtime relay so
// every API node can attempt local delivery.
type RelayNotifier struct {
local *Hub
relay realtime.RelayPublisher
}
func NewRelayNotifier(local *Hub, relay realtime.RelayPublisher) *RelayNotifier {
return &RelayNotifier{local: local, relay: relay}
}
func (n *RelayNotifier) NotifyTaskAvailable(runtimeID, taskID string) {
if runtimeID == "" {
return
}
eventID := ulid.Make().String()
if n.local != nil {
n.local.notifyTaskAvailable(runtimeID, taskID, eventID)
}
if n.relay == nil {
return
}
frame, err := taskAvailableFrame(runtimeID, taskID)
if err != nil {
M.WakeupPublishErrors.Add(1)
return
}
shardKey := taskID
if shardKey == "" {
shardKey = eventID
}
if err := n.relay.PublishWithID(realtime.ScopeDaemonRuntime, shardKey, "", frame, eventID); err != nil {
M.WakeupPublishErrors.Add(1)
slog.Warn("daemon websocket wakeup publish failed", "error", err, "runtime_id", runtimeID, "task_id", taskID)
return
}
M.WakeupPublishedTotal.Add(1)
}

View File

@@ -0,0 +1,66 @@
package handler
import (
"net/http"
"strings"
"github.com/multica-ai/multica/server/internal/daemonws"
"github.com/multica-ai/multica/server/internal/middleware"
)
func (h *Handler) DaemonWebSocket(w http.ResponseWriter, r *http.Request) {
if h.DaemonHub == nil {
writeError(w, http.StatusServiceUnavailable, "daemon websocket unavailable")
return
}
runtimeIDs := parseRuntimeIDs(r)
if len(runtimeIDs) == 0 {
writeError(w, http.StatusBadRequest, "runtime_ids required")
return
}
for _, runtimeID := range runtimeIDs {
rt, ok := h.requireDaemonRuntimeAccess(w, r, runtimeID)
if !ok {
return
}
if daemonID := middleware.DaemonIDFromContext(r.Context()); daemonID != "" && rt.DaemonID.Valid && rt.DaemonID.String != daemonID {
writeError(w, http.StatusNotFound, "runtime not found")
return
}
}
h.DaemonHub.HandleWebSocket(w, r, daemonws.ClientIdentity{
DaemonID: middleware.DaemonIDFromContext(r.Context()),
UserID: requestUserID(r),
WorkspaceID: middleware.DaemonWorkspaceIDFromContext(r.Context()),
RuntimeIDs: runtimeIDs,
ClientVersion: r.Header.Get("X-Client-Version"),
})
}
func parseRuntimeIDs(r *http.Request) []string {
seen := map[string]struct{}{}
var out []string
add := func(raw string) {
for _, part := range strings.Split(raw, ",") {
id := strings.TrimSpace(part)
if id == "" {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
}
for _, raw := range r.URL.Query()["runtime_id"] {
add(raw)
}
for _, raw := range r.URL.Query()["runtime_ids"] {
add(raw)
}
return out
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/jackc/pgx/v5/pgtype"
"github.com/multica-ai/multica/server/internal/analytics"
"github.com/multica-ai/multica/server/internal/auth"
"github.com/multica-ai/multica/server/internal/daemonws"
"github.com/multica-ai/multica/server/internal/events"
"github.com/multica-ai/multica/server/internal/middleware"
"github.com/multica-ai/multica/server/internal/realtime"
@@ -53,6 +54,7 @@ type Handler struct {
DB dbExecutor
TxStarter txStarter
Hub *realtime.Hub
DaemonHub *daemonws.Hub
Bus *events.Bus
TaskService *service.TaskService
AutopilotService *service.AutopilotService
@@ -67,7 +69,7 @@ type Handler struct {
cfg Config
}
func New(queries *db.Queries, txStarter txStarter, hub *realtime.Hub, bus *events.Bus, emailService *service.EmailService, store storage.Storage, cfSigner *auth.CloudFrontSigner, analyticsClient analytics.Client, cfg Config) *Handler {
func New(queries *db.Queries, txStarter txStarter, hub *realtime.Hub, bus *events.Bus, emailService *service.EmailService, store storage.Storage, cfSigner *auth.CloudFrontSigner, analyticsClient analytics.Client, cfg Config, daemonHubs ...*daemonws.Hub) *Handler {
var executor dbExecutor
if candidate, ok := txStarter.(dbExecutor); ok {
executor = candidate
@@ -77,12 +79,18 @@ func New(queries *db.Queries, txStarter txStarter, hub *realtime.Hub, bus *event
analyticsClient = analytics.NoopClient{}
}
taskSvc := service.NewTaskService(queries, txStarter, hub, bus)
var daemonHub *daemonws.Hub
if len(daemonHubs) > 0 {
daemonHub = daemonHubs[0]
}
taskSvc := service.NewTaskService(queries, txStarter, hub, bus, daemonHub)
return &Handler{
Queries: queries,
DB: executor,
TxStarter: txStarter,
Hub: hub,
DaemonHub: daemonHub,
Bus: bus,
TaskService: taskSvc,
AutopilotService: service.NewAutopilotService(queries, txStarter, bus, taskSvc),

View File

@@ -0,0 +1,70 @@
package metrics
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/multica-ai/multica/server/internal/daemonws"
)
type DaemonWSCollector struct {
metrics *daemonws.Metrics
connectsTotal *prometheus.Desc
disconnectsTotal *prometheus.Desc
activeConnections *prometheus.Desc
slowEvictionsTotal *prometheus.Desc
wakeupPublishedTotal *prometheus.Desc
wakeupPublishErrors *prometheus.Desc
wakeupReceivedTotal *prometheus.Desc
wakeupDeliveredTotal *prometheus.Desc
}
func NewDaemonWSCollector(m *daemonws.Metrics) *DaemonWSCollector {
return &DaemonWSCollector{
metrics: m,
connectsTotal: newDaemonWSDesc("connects_total", "Total daemon WebSocket connections opened."),
disconnectsTotal: newDaemonWSDesc("disconnects_total", "Total daemon WebSocket connections closed."),
activeConnections: newDaemonWSDesc("active_connections", "Current daemon WebSocket connections."),
slowEvictionsTotal: newDaemonWSDesc("slow_evictions_total", "Total daemon WebSocket clients evicted for slow consumption."),
wakeupPublishedTotal: newDaemonWSDesc("wakeup_published_total", "Total daemon wakeups published to the Redis relay."),
wakeupPublishErrors: newDaemonWSDesc("wakeup_publish_errors_total", "Total daemon wakeup Redis publish errors."),
wakeupReceivedTotal: newDaemonWSDesc("wakeup_received_total", "Total daemon wakeups received from the Redis relay."),
wakeupDeliveredTotal: prometheus.NewDesc("multica_daemonws_wakeup_delivered_total", "Total daemon wakeup local delivery attempts.", []string{"result"}, nil),
}
}
func newDaemonWSDesc(name, help string) *prometheus.Desc {
return prometheus.NewDesc("multica_daemonws_"+name, help, nil, nil)
}
func (c *DaemonWSCollector) Describe(ch chan<- *prometheus.Desc) {
for _, desc := range []*prometheus.Desc{
c.connectsTotal,
c.disconnectsTotal,
c.activeConnections,
c.slowEvictionsTotal,
c.wakeupPublishedTotal,
c.wakeupPublishErrors,
c.wakeupReceivedTotal,
c.wakeupDeliveredTotal,
} {
ch <- desc
}
}
func (c *DaemonWSCollector) Collect(ch chan<- prometheus.Metric) {
if c.metrics == nil {
return
}
m := c.metrics
ch <- prometheus.MustNewConstMetric(c.connectsTotal, prometheus.CounterValue, float64(m.ConnectsTotal.Load()))
ch <- prometheus.MustNewConstMetric(c.disconnectsTotal, prometheus.CounterValue, float64(m.DisconnectsTotal.Load()))
ch <- prometheus.MustNewConstMetric(c.activeConnections, prometheus.GaugeValue, float64(m.ActiveConnections.Load()))
ch <- prometheus.MustNewConstMetric(c.slowEvictionsTotal, prometheus.CounterValue, float64(m.SlowEvictionsTotal.Load()))
ch <- prometheus.MustNewConstMetric(c.wakeupPublishedTotal, prometheus.CounterValue, float64(m.WakeupPublishedTotal.Load()))
ch <- prometheus.MustNewConstMetric(c.wakeupPublishErrors, prometheus.CounterValue, float64(m.WakeupPublishErrors.Load()))
ch <- prometheus.MustNewConstMetric(c.wakeupReceivedTotal, prometheus.CounterValue, float64(m.WakeupReceivedTotal.Load()))
ch <- prometheus.MustNewConstMetric(c.wakeupDeliveredTotal, prometheus.CounterValue, float64(m.WakeupDeliveredHit.Load()), "hit")
ch <- prometheus.MustNewConstMetric(c.wakeupDeliveredTotal, prometheus.CounterValue, float64(m.WakeupDeliveredMiss.Load()), "miss")
}

View File

@@ -7,12 +7,14 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/multica-ai/multica/server/internal/daemonws"
"github.com/multica-ai/multica/server/internal/realtime"
)
type RegistryOptions struct {
Pool *pgxpool.Pool
Realtime *realtime.Metrics
DaemonWS *daemonws.Metrics
Version string
Commit string
}
@@ -43,6 +45,9 @@ func NewRegistry(opts RegistryOptions) *Registry {
if opts.Realtime != nil {
reg.MustRegister(NewRealtimeCollector(opts.Realtime))
}
if opts.DaemonWS != nil {
reg.MustRegister(NewDaemonWSCollector(opts.DaemonWS))
}
return &Registry{
Gatherer: reg,

View File

@@ -8,6 +8,9 @@ const (
ScopeUser = "user"
ScopeTask = "task"
ScopeChat = "chat"
// ScopeDaemonRuntime routes daemon wakeup frames through the Redis relay.
// It is consumed by the daemon WebSocket hub, not by browser clients.
ScopeDaemonRuntime = "daemon_runtime"
)
// Broadcaster is the abstraction every realtime event producer should depend
@@ -38,5 +41,10 @@ type Broadcaster interface {
Broadcast(message []byte)
}
// DaemonRuntimeDeliverer consumes daemon-runtime scoped relay frames.
type DaemonRuntimeDeliverer interface {
DeliverDaemonRuntime(scopeID string, frame []byte, eventID string)
}
// Compile-time assertion that *Hub continues to satisfy Broadcaster.
var _ Broadcaster = (*Hub)(nil)

View File

@@ -106,12 +106,16 @@ func redisString(v any) string {
}
}
func deliverEnvelope(hub *Hub, ev envelope) {
func deliverEnvelope(hub *Hub, daemonRuntime DaemonRuntimeDeliverer, ev envelope) {
if ev.PayloadJSON == "" {
return
}
frame := injectEventID([]byte(ev.PayloadJSON), ev.EventID)
switch ev.Scope {
case ScopeDaemonRuntime:
if daemonRuntime != nil {
daemonRuntime.DeliverDaemonRuntime(ev.ScopeID, frame, ev.EventID)
}
case "global":
hub.fanoutAllDedup(frame, "", ev.EventID)
case ScopeUser:
@@ -134,6 +138,8 @@ type RedisRelay struct {
consumers map[scopeKey]*scopeConsumer
stopping bool
wg sync.WaitGroup
daemonRuntime DaemonRuntimeDeliverer
}
type scopeConsumer struct {
@@ -167,6 +173,10 @@ func NewRedisRelayWithClients(hub *Hub, writeRDB, readRDB *redis.Client) *RedisR
// NodeID returns this relay's randomly-assigned node identifier.
func (r *RedisRelay) NodeID() string { return r.nodeID }
func (r *RedisRelay) SetDaemonRuntimeDeliverer(d DaemonRuntimeDeliverer) {
r.daemonRuntime = d
}
// Wait blocks until all relay-owned goroutines have exited after the Start
// context is canceled.
func (r *RedisRelay) Wait() {
@@ -394,7 +404,7 @@ func (r *RedisRelay) deliverMessage(scopeType, scopeID string, msg redis.XMessag
if ev.ScopeID == "" {
ev.ScopeID = scopeID
}
deliverEnvelope(r.hub, ev)
deliverEnvelope(r.hub, r.daemonRuntime, ev)
}
// fanoutUser is implemented in hub.go.

View File

@@ -36,6 +36,15 @@ func (r *MirroredRelay) NodeID() string {
return r.primary.NodeID()
}
func (r *MirroredRelay) SetDaemonRuntimeDeliverer(d DaemonRuntimeDeliverer) {
if setter, ok := r.primary.(interface{ SetDaemonRuntimeDeliverer(DaemonRuntimeDeliverer) }); ok {
setter.SetDaemonRuntimeDeliverer(d)
}
if setter, ok := r.mirror.(interface{ SetDaemonRuntimeDeliverer(DaemonRuntimeDeliverer) }); ok {
setter.SetDaemonRuntimeDeliverer(d)
}
}
func (r *MirroredRelay) Start(ctx context.Context) {
r.primary.Start(ctx)
r.mirror.Start(ctx)
@@ -74,6 +83,9 @@ func (r *MirroredRelay) Broadcast(message []byte) {
func (r *MirroredRelay) PublishWithID(scopeType, scopeID, exclude string, frame []byte, id string) error {
primaryErr := r.primary.PublishWithID(scopeType, scopeID, exclude, frame, id)
if scopeType == ScopeDaemonRuntime {
return primaryErr
}
mirrorErr := r.mirror.PublishWithID(scopeType, scopeID, exclude, frame, id)
if primaryErr != nil {

View File

@@ -50,6 +50,23 @@ func TestMirroredRelayRecordsDivergenceWhenOneBackendFails(t *testing.T) {
}
}
func TestMirroredRelayDoesNotMirrorDaemonRuntimeEvents(t *testing.T) {
primary := &recordingManagedRelay{nodeID: "primary"}
mirror := &recordingManagedRelay{nodeID: "mirror"}
relay := NewMirroredRelay(primary, mirror)
if err := relay.PublishWithID(ScopeDaemonRuntime, "task-1", "", []byte(`{"type":"daemon:task_available"}`), "event-1"); err != nil {
t.Fatalf("PublishWithID: %v", err)
}
if len(primary.calls) != 1 {
t.Fatalf("expected primary publish call, got %d", len(primary.calls))
}
if len(mirror.calls) != 0 {
t.Fatalf("expected daemon runtime event not to hit mirror, got %d calls", len(mirror.calls))
}
}
type relayPublishCall struct {
scopeType string
scopeID string

View File

@@ -76,6 +76,8 @@ type ShardedStreamRelay struct {
mu sync.Mutex
stopping bool
wg sync.WaitGroup
daemonRuntime DaemonRuntimeDeliverer
}
func NewShardedStreamRelay(hub *Hub, writeRDB, readRDB *redis.Client, config ShardedStreamRelayConfig) *ShardedStreamRelay {
@@ -93,6 +95,10 @@ func NewShardedStreamRelay(hub *Hub, writeRDB, readRDB *redis.Client, config Sha
func (r *ShardedStreamRelay) NodeID() string { return r.nodeID }
func (r *ShardedStreamRelay) SetDaemonRuntimeDeliverer(d DaemonRuntimeDeliverer) {
r.daemonRuntime = d
}
func (r *ShardedStreamRelay) Start(ctx context.Context) {
M.NodeID.Store(r.nodeID)
if err := r.writeRDB.Ping(ctx).Err(); err != nil {
@@ -233,7 +239,7 @@ func (r *ShardedStreamRelay) deliverMessage(msg redis.XMessage) {
if !ok || ev.Scope == "" || ev.ScopeID == "" {
return
}
deliverEnvelope(r.hub, ev)
deliverEnvelope(r.hub, r.daemonRuntime, ev)
}
func (r *ShardedStreamRelay) heartbeatLoop(ctx context.Context) {

View File

@@ -26,10 +26,19 @@ type TaskService struct {
TxStarter TxStarter
Hub *realtime.Hub
Bus *events.Bus
Wakeup TaskWakeupNotifier
}
func NewTaskService(q *db.Queries, tx TxStarter, hub *realtime.Hub, bus *events.Bus) *TaskService {
return &TaskService{Queries: q, TxStarter: tx, Hub: hub, Bus: bus}
type TaskWakeupNotifier interface {
NotifyTaskAvailable(runtimeID, taskID string)
}
func NewTaskService(q *db.Queries, tx TxStarter, hub *realtime.Hub, bus *events.Bus, wakeups ...TaskWakeupNotifier) *TaskService {
var wakeup TaskWakeupNotifier
if len(wakeups) > 0 {
wakeup = wakeups[0]
}
return &TaskService{Queries: q, TxStarter: tx, Hub: hub, Bus: bus, Wakeup: wakeup}
}
// EnqueueTaskForIssue creates a queued task for an agent-assigned issue.
@@ -73,6 +82,7 @@ func (s *TaskService) EnqueueTaskForIssue(ctx context.Context, issue db.Issue, t
}
slog.Info("task enqueued", "task_id", util.UUIDToString(task.ID), "issue_id", util.UUIDToString(issue.ID), "agent_id", util.UUIDToString(issue.AssigneeID))
s.notifyTaskAvailable(task)
return task, nil
}
@@ -107,6 +117,7 @@ func (s *TaskService) EnqueueTaskForMention(ctx context.Context, issue db.Issue,
}
slog.Info("mention task enqueued", "task_id", util.UUIDToString(task.ID), "issue_id", util.UUIDToString(issue.ID), "agent_id", util.UUIDToString(agentID))
s.notifyTaskAvailable(task)
return task, nil
}
@@ -137,6 +148,7 @@ func (s *TaskService) EnqueueChatTask(ctx context.Context, chatSession db.ChatSe
}
slog.Info("chat task enqueued", "task_id", util.UUIDToString(task.ID), "chat_session_id", util.UUIDToString(chatSession.ID), "agent_id", util.UUIDToString(chatSession.AgentID))
s.notifyTaskAvailable(task)
return task, nil
}
@@ -645,6 +657,7 @@ func (s *TaskService) MaybeRetryFailedTask(ctx context.Context, parent db.AgentT
"attempt", child.Attempt,
"max_attempts", child.MaxAttempts,
)
s.notifyTaskAvailable(child)
s.broadcastTaskEvent(ctx, protocol.EventTaskDispatch, child)
return &child, nil
}
@@ -902,6 +915,13 @@ func priorityToInt(p string) int32 {
}
}
func (s *TaskService) notifyTaskAvailable(task db.AgentTaskQueue) {
if s.Wakeup == nil || !task.RuntimeID.Valid {
return
}
s.Wakeup.NotifyTaskAvailable(util.UUIDToString(task.RuntimeID), util.UUIDToString(task.ID))
}
func (s *TaskService) broadcastTaskDispatch(ctx context.Context, task db.AgentTaskQueue) {
var payload map[string]any
if task.Context != nil {

View File

@@ -11,10 +11,10 @@ const (
EventCommentCreated = "comment:created"
EventCommentUpdated = "comment:updated"
EventCommentDeleted = "comment:deleted"
EventReactionAdded = "reaction:added"
EventReactionRemoved = "reaction:removed"
EventIssueReactionAdded = "issue_reaction:added"
EventIssueReactionRemoved = "issue_reaction:removed"
EventReactionAdded = "reaction:added"
EventReactionRemoved = "reaction:removed"
EventIssueReactionAdded = "issue_reaction:added"
EventIssueReactionRemoved = "issue_reaction:removed"
// Agent events
EventAgentStatus = "agent:status"
@@ -93,6 +93,7 @@ const (
EventAutopilotRunDone = "autopilot:run_done"
// Daemon events
EventDaemonHeartbeat = "daemon:heartbeat"
EventDaemonRegister = "daemon:register"
EventDaemonHeartbeat = "daemon:heartbeat"
EventDaemonRegister = "daemon:register"
EventDaemonTaskAvailable = "daemon:task_available"
)

View File

@@ -16,6 +16,13 @@ type TaskDispatchPayload struct {
Description string `json:"description"`
}
// TaskAvailablePayload is sent from server to daemon as a wakeup hint. The
// daemon still claims work through the existing HTTP claim endpoint.
type TaskAvailablePayload struct {
RuntimeID string `json:"runtime_id"`
TaskID string `json:"task_id,omitempty"`
}
// TaskProgressPayload is sent from daemon to server during task execution.
type TaskProgressPayload struct {
TaskID string `json:"task_id"`
@@ -37,10 +44,10 @@ type TaskMessagePayload struct {
IssueID string `json:"issue_id,omitempty"`
Seq int `json:"seq"`
Type string `json:"type"` // "text", "tool_use", "tool_result", "error"
Tool string `json:"tool,omitempty"` // tool name for tool_use/tool_result
Content string `json:"content,omitempty"` // text content
Input map[string]any `json:"input,omitempty"` // tool input (tool_use only)
Output string `json:"output,omitempty"` // tool output (tool_result only)
Tool string `json:"tool,omitempty"` // tool name for tool_use/tool_result
Content string `json:"content,omitempty"` // text content
Input map[string]any `json:"input,omitempty"` // tool input (tool_use only)
Output string `json:"output,omitempty"` // tool output (tool_result only)
}
// DaemonRegisterPayload is sent from daemon to server on connection.