mirror of
https://github.com/multica-ai/multica.git
synced 2026-06-17 11:48:42 +02:00
feat: add daemon websocket task wakeups (#1772)
* feat: add daemon websocket task wakeups * feat: fan out daemon wakeups across nodes * fix: dedupe daemon wakeup loopback events * fix: lengthen daemon polling fallback interval --------- Co-authored-by: Eve <eve@multica.ai>
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/multica-ai/multica/server/internal/daemonws"
|
||||
"github.com/multica-ai/multica/server/internal/realtime"
|
||||
)
|
||||
|
||||
@@ -47,7 +48,9 @@ func realtimeMetricsHandler(token string) http.HandlerFunc {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
_ = json.NewEncoder(w).Encode(realtime.M.Snapshot())
|
||||
snapshot := realtime.M.Snapshot()
|
||||
snapshot["daemonws"] = daemonws.M.Snapshot()
|
||||
_ = json.NewEncoder(w).Encode(snapshot)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/multica-ai/multica/server/internal/analytics"
|
||||
"github.com/multica-ai/multica/server/internal/daemonws"
|
||||
"github.com/multica-ai/multica/server/internal/events"
|
||||
"github.com/multica-ai/multica/server/internal/logger"
|
||||
obsmetrics "github.com/multica-ai/multica/server/internal/metrics"
|
||||
@@ -161,6 +162,8 @@ func main() {
|
||||
bus := events.New()
|
||||
hub := realtime.NewHub()
|
||||
go hub.Run()
|
||||
daemonHub := daemonws.NewHub()
|
||||
var daemonWakeup service.TaskWakeupNotifier = daemonHub
|
||||
|
||||
// MUL-1138: when REDIS_URL is set, route fanout through a Redis relay so
|
||||
// multiple API nodes can deliver each other's events. Without it the hub
|
||||
@@ -204,15 +207,21 @@ func main() {
|
||||
case "legacy":
|
||||
relayReadRedis = newNamedRedisClient(opts, "realtime-read")
|
||||
relay = realtime.NewRedisRelayWithClients(hub, relayWriteRedis, relayReadRedis)
|
||||
slog.Info("daemon websocket wakeup: Redis fanout disabled in legacy realtime relay mode")
|
||||
case "dual":
|
||||
shardedReadRedis = newNamedRedisClient(opts, "realtime-read-sharded")
|
||||
legacyReadRedis = newNamedRedisClient(opts, "realtime-read-legacy")
|
||||
sharded := realtime.NewShardedStreamRelay(hub, relayWriteRedis, shardedReadRedis, relayConfig)
|
||||
sharded.SetDaemonRuntimeDeliverer(daemonHub)
|
||||
legacy := realtime.NewRedisRelayWithClients(hub, relayWriteRedis, legacyReadRedis)
|
||||
relay = realtime.NewMirroredRelay(sharded, legacy)
|
||||
daemonWakeup = daemonws.NewRelayNotifier(daemonHub, sharded)
|
||||
default:
|
||||
relayReadRedis = newNamedRedisClient(opts, "realtime-read")
|
||||
relay = realtime.NewShardedStreamRelay(hub, relayWriteRedis, relayReadRedis, relayConfig)
|
||||
sharded := realtime.NewShardedStreamRelay(hub, relayWriteRedis, relayReadRedis, relayConfig)
|
||||
sharded.SetDaemonRuntimeDeliverer(daemonHub)
|
||||
relay = sharded
|
||||
daemonWakeup = daemonws.NewRelayNotifier(daemonHub, sharded)
|
||||
}
|
||||
relay.Start(relayCtx)
|
||||
broadcaster = realtime.NewDualWriteBroadcaster(hub, relay)
|
||||
@@ -253,6 +262,7 @@ func main() {
|
||||
metricsRegistry := obsmetrics.NewRegistry(obsmetrics.RegistryOptions{
|
||||
Pool: pool,
|
||||
Realtime: realtime.M,
|
||||
DaemonWS: daemonws.M,
|
||||
Version: version,
|
||||
Commit: commit,
|
||||
})
|
||||
@@ -267,7 +277,9 @@ func main() {
|
||||
}
|
||||
|
||||
r := NewRouterWithOptions(pool, hub, bus, analyticsClient, storeRedis, RouterOptions{
|
||||
HTTPMetrics: httpMetrics,
|
||||
HTTPMetrics: httpMetrics,
|
||||
DaemonHub: daemonHub,
|
||||
DaemonWakeup: daemonWakeup,
|
||||
})
|
||||
|
||||
srv := &http.Server{
|
||||
@@ -278,7 +290,7 @@ func main() {
|
||||
// Start background workers.
|
||||
sweepCtx, sweepCancel := context.WithCancel(context.Background())
|
||||
autopilotCtx, autopilotCancel := context.WithCancel(context.Background())
|
||||
taskSvc := service.NewTaskService(queries, pool, hub, bus)
|
||||
taskSvc := service.NewTaskService(queries, pool, hub, bus, daemonWakeup)
|
||||
autopilotSvc := service.NewAutopilotService(queries, pool, bus, taskSvc)
|
||||
registerAutopilotListeners(bus, autopilotSvc)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/multica-ai/multica/server/internal/analytics"
|
||||
"github.com/multica-ai/multica/server/internal/auth"
|
||||
"github.com/multica-ai/multica/server/internal/daemonws"
|
||||
"github.com/multica-ai/multica/server/internal/events"
|
||||
"github.com/multica-ai/multica/server/internal/handler"
|
||||
obsmetrics "github.com/multica-ai/multica/server/internal/metrics"
|
||||
@@ -67,12 +68,18 @@ func NewRouter(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus, analytics
|
||||
}
|
||||
|
||||
type RouterOptions struct {
|
||||
HTTPMetrics *obsmetrics.HTTPMetrics
|
||||
HTTPMetrics *obsmetrics.HTTPMetrics
|
||||
DaemonHub *daemonws.Hub
|
||||
DaemonWakeup service.TaskWakeupNotifier
|
||||
}
|
||||
|
||||
func NewRouterWithOptions(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus, analyticsClient analytics.Client, rdb *redis.Client, opts RouterOptions) chi.Router {
|
||||
queries := db.New(pool)
|
||||
emailSvc := service.NewEmailService()
|
||||
daemonHub := opts.DaemonHub
|
||||
if daemonHub == nil {
|
||||
daemonHub = daemonws.NewHub()
|
||||
}
|
||||
|
||||
// Initialize storage with S3 as primary, fallback to local
|
||||
var store storage.Storage
|
||||
@@ -93,7 +100,10 @@ func NewRouterWithOptions(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus
|
||||
AllowedEmails: splitAndTrim(os.Getenv("ALLOWED_EMAILS")),
|
||||
AllowedEmailDomains: splitAndTrim(os.Getenv("ALLOWED_EMAIL_DOMAINS")),
|
||||
}
|
||||
h := handler.New(queries, pool, hub, bus, emailSvc, store, cfSigner, analyticsClient, signupConfig)
|
||||
h := handler.New(queries, pool, hub, bus, emailSvc, store, cfSigner, analyticsClient, signupConfig, daemonHub)
|
||||
if opts.DaemonWakeup != nil {
|
||||
h.TaskService.Wakeup = opts.DaemonWakeup
|
||||
}
|
||||
if rdb != nil {
|
||||
h.LocalSkillListStore = handler.NewRedisLocalSkillListStore(rdb)
|
||||
h.LocalSkillImportStore = handler.NewRedisLocalSkillImportStore(rdb)
|
||||
@@ -178,6 +188,7 @@ func NewRouterWithOptions(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus
|
||||
r.Post("/register", h.DaemonRegister)
|
||||
r.Post("/deregister", h.DaemonDeregister)
|
||||
r.Post("/heartbeat", h.DaemonHeartbeat)
|
||||
r.Get("/ws", h.DaemonWebSocket)
|
||||
r.Get("/workspaces/{workspaceId}/repos", h.GetDaemonWorkspaceRepos)
|
||||
|
||||
r.Post("/runtimes/{runtimeId}/tasks/claim", h.ClaimTaskByRuntime)
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
const (
|
||||
DefaultServerURL = "ws://localhost:8080/ws"
|
||||
DefaultPollInterval = 3 * time.Second
|
||||
DefaultPollInterval = 30 * time.Second
|
||||
DefaultHeartbeatInterval = 15 * time.Second
|
||||
DefaultAgentTimeout = 2 * time.Hour
|
||||
DefaultCodexSemanticInactivityTimeout = 10 * time.Minute
|
||||
|
||||
@@ -45,6 +45,7 @@ type Daemon struct {
|
||||
workspaces map[string]*workspaceState
|
||||
runtimeIndex map[string]Runtime // runtimeID -> Runtime for provider lookups
|
||||
reloading sync.Mutex // prevents concurrent workspace syncs
|
||||
runtimeSetCh chan struct{} // notifies the WS wakeup loop to reconnect with a new runtime set
|
||||
|
||||
versionsMu sync.RWMutex // guards agentVersions
|
||||
agentVersions map[string]string // provider -> detected CLI version (set during registration)
|
||||
@@ -69,6 +70,7 @@ func New(cfg Config, logger *slog.Logger) *Daemon {
|
||||
logger: logger,
|
||||
workspaces: make(map[string]*workspaceState),
|
||||
runtimeIndex: make(map[string]Runtime),
|
||||
runtimeSetCh: make(chan struct{}, 1),
|
||||
agentVersions: make(map[string]string),
|
||||
}
|
||||
}
|
||||
@@ -89,6 +91,23 @@ func (d *Daemon) agentVersion(provider string) string {
|
||||
return d.agentVersions[provider]
|
||||
}
|
||||
|
||||
func (d *Daemon) notifyRuntimeSetChanged() {
|
||||
select {
|
||||
case d.runtimeSetCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) drainRuntimeSetChanged() {
|
||||
for {
|
||||
select {
|
||||
case <-d.runtimeSetCh:
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run starts the daemon: resolves auth, registers runtimes, then polls for tasks.
|
||||
func (d *Daemon) Run(ctx context.Context) error {
|
||||
// Wrap context so handleUpdate can cancel the daemon for restart.
|
||||
@@ -132,10 +151,13 @@ func (d *Daemon) Run(ctx context.Context) error {
|
||||
// Start workspace sync loop to discover newly created workspaces.
|
||||
go d.workspaceSyncLoop(ctx)
|
||||
|
||||
taskWakeups := make(chan struct{}, 1)
|
||||
d.drainRuntimeSetChanged()
|
||||
go d.taskWakeupLoop(ctx, taskWakeups)
|
||||
go d.heartbeatLoop(ctx)
|
||||
go d.gcLoop(ctx)
|
||||
go d.serveHealth(ctx, healthLn, time.Now())
|
||||
return d.pollLoop(ctx)
|
||||
return d.pollLoop(ctx, taskWakeups)
|
||||
}
|
||||
|
||||
// RestartBinary returns the path to the new binary if the daemon needs to restart
|
||||
@@ -422,6 +444,7 @@ func (d *Daemon) syncWorkspacesFromAPI(ctx context.Context) error {
|
||||
d.mu.Unlock()
|
||||
|
||||
var registered int
|
||||
var removed int
|
||||
for id, name := range apiIDs {
|
||||
if currentIDs[id] {
|
||||
continue // important: never replace existing workspaceState; ensureRepoReady holds ws.repoRefreshMu from the original pointer
|
||||
@@ -473,8 +496,12 @@ func (d *Daemon) syncWorkspacesFromAPI(ctx context.Context) error {
|
||||
delete(d.workspaces, id)
|
||||
d.mu.Unlock()
|
||||
d.logger.Info("stopped watching workspace", "workspace_id", id)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
if registered > 0 || removed > 0 {
|
||||
d.notifyRuntimeSetChanged()
|
||||
}
|
||||
|
||||
if len(d.allRuntimeIDs()) == 0 && registered == 0 && len(workspaces) > 0 {
|
||||
return fmt.Errorf("failed to register runtimes for any of the %d workspace(s)", len(workspaces))
|
||||
@@ -799,7 +826,7 @@ func (d *Daemon) triggerRestart() {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) pollLoop(ctx context.Context) error {
|
||||
func (d *Daemon) pollLoop(ctx context.Context, taskWakeups <-chan struct{}) error {
|
||||
sem := make(chan struct{}, d.cfg.MaxConcurrentTasks)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
@@ -822,7 +849,7 @@ func (d *Daemon) pollLoop(ctx context.Context) error {
|
||||
|
||||
runtimeIDs := d.allRuntimeIDs()
|
||||
if len(runtimeIDs) == 0 {
|
||||
if err := sleepWithContext(ctx, d.cfg.PollInterval); err != nil {
|
||||
if err := sleepWithContextOrWakeup(ctx, d.cfg.PollInterval, taskWakeups); err != nil {
|
||||
wg.Wait()
|
||||
return err
|
||||
}
|
||||
@@ -878,7 +905,7 @@ func (d *Daemon) pollLoop(ctx context.Context) error {
|
||||
d.logger.Debug("poll: no tasks", "runtimes", runtimeIDs, "cycle", pollCount)
|
||||
}
|
||||
pollOffset = (pollOffset + 1) % n
|
||||
if err := sleepWithContext(ctx, d.cfg.PollInterval); err != nil {
|
||||
if err := sleepWithContextOrWakeup(ctx, d.cfg.PollInterval, taskWakeups); err != nil {
|
||||
wg.Wait()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -78,3 +78,21 @@ func sleepWithContext(ctx context.Context, d time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func sleepWithContextOrWakeup(ctx context.Context, d time.Duration, wakeups <-chan struct{}) error {
|
||||
if wakeups == nil {
|
||||
return sleepWithContext(ctx, d)
|
||||
}
|
||||
|
||||
timer := time.NewTimer(d)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-wakeups:
|
||||
return nil
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
188
server/internal/daemon/wakeup.go
Normal file
188
server/internal/daemon/wakeup.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/multica-ai/multica/server/pkg/protocol"
|
||||
)
|
||||
|
||||
// errRuntimeSetChanged is returned by runTaskWakeupConnection when a signal on
// runtimeSetCh ends the current connection; taskWakeupLoop treats it as a
// request to reconnect immediately rather than as a failure.
var errRuntimeSetChanged = errors.New("runtime set changed")
|
||||
|
||||
func (d *Daemon) taskWakeupLoop(ctx context.Context, taskWakeups chan<- struct{}) {
|
||||
backoff := time.Second
|
||||
|
||||
for {
|
||||
runtimeIDs := d.allRuntimeIDs()
|
||||
if len(runtimeIDs) == 0 {
|
||||
if err := sleepWithContextOrRuntimeChange(ctx, 5*time.Second, d.runtimeSetCh); err != nil {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
err := d.runTaskWakeupConnection(ctx, runtimeIDs, taskWakeups)
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if errors.Is(err, errRuntimeSetChanged) {
|
||||
backoff = time.Second
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
d.logger.Debug("task wakeup websocket unavailable; polling fallback remains active", "error", err, "retry_in", backoff)
|
||||
}
|
||||
|
||||
if err := sleepWithContextOrRuntimeChange(ctx, jitterDuration(backoff), d.runtimeSetCh); err != nil {
|
||||
return
|
||||
}
|
||||
if backoff < 30*time.Second {
|
||||
backoff *= 2
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// jitterDuration returns d shifted by a uniformly random offset in
// [-d/5, +d/5]. Non-positive durations, and durations too small for d/5 to be
// positive, are returned unchanged.
func jitterDuration(d time.Duration) time.Duration {
	spread := d / 5
	if d <= 0 || spread <= 0 {
		return d
	}
	// Draw from [0, 2*spread] and recentre around d.
	offset := time.Duration(rand.Int63n(int64(spread)*2 + 1))
	return d + (offset - spread)
}
|
||||
|
||||
func (d *Daemon) runTaskWakeupConnection(ctx context.Context, runtimeIDs []string, taskWakeups chan<- struct{}) error {
|
||||
wsURL, err := taskWakeupURL(d.cfg.ServerBaseURL, runtimeIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
headers := http.Header{}
|
||||
if token := d.client.Token(); token != "" {
|
||||
headers.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
if d.client.platform != "" {
|
||||
headers.Set("X-Client-Platform", d.client.platform)
|
||||
}
|
||||
if d.client.version != "" {
|
||||
headers.Set("X-Client-Version", d.client.version)
|
||||
}
|
||||
if d.client.os != "" {
|
||||
headers.Set("X-Client-OS", d.client.os)
|
||||
}
|
||||
|
||||
dialer := websocket.Dialer{HandshakeTimeout: 10 * time.Second}
|
||||
conn, _, err := dialer.DialContext(ctx, wsURL, headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
d.logger.Info("task wakeup websocket connected", "runtimes", len(runtimeIDs))
|
||||
signalTaskWakeup(taskWakeups)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- d.readTaskWakeupMessages(conn, taskWakeups)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-d.runtimeSetCh:
|
||||
return errRuntimeSetChanged
|
||||
case err := <-errCh:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) readTaskWakeupMessages(conn *websocket.Conn, taskWakeups chan<- struct{}) error {
|
||||
conn.SetReadLimit(64 * 1024)
|
||||
for {
|
||||
_, raw, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var msg protocol.Message
|
||||
if err := json.Unmarshal(raw, &msg); err != nil {
|
||||
d.logger.Debug("task wakeup websocket invalid message", "error", err)
|
||||
continue
|
||||
}
|
||||
if msg.Type != protocol.EventDaemonTaskAvailable {
|
||||
continue
|
||||
}
|
||||
var payload protocol.TaskAvailablePayload
|
||||
if len(msg.Payload) > 0 {
|
||||
if err := json.Unmarshal(msg.Payload, &payload); err != nil {
|
||||
d.logger.Debug("task wakeup websocket invalid payload", "error", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if payload.RuntimeID != "" {
|
||||
d.logger.Debug("task wakeup received", "runtime_id", payload.RuntimeID, "task_id", payload.TaskID)
|
||||
}
|
||||
signalTaskWakeup(taskWakeups)
|
||||
}
|
||||
}
|
||||
|
||||
// signalTaskWakeup performs a non-blocking send on taskWakeups. When the
// channel cannot accept the value immediately the signal is dropped.
func signalTaskWakeup(taskWakeups chan<- struct{}) {
	var token struct{}
	select {
	case taskWakeups <- token:
		// Wakeup queued.
	default:
		// Channel not ready; skip rather than block.
	}
}
|
||||
|
||||
// taskWakeupURL builds the daemon wakeup WebSocket URL from the server base
// URL and the runtime IDs to subscribe to.
//
// The scheme is mapped http->ws and https->wss (ws/wss pass through), the
// path "/api/daemon/ws" is appended to any base path, and runtime IDs are
// sorted and joined with commas in the "runtime_ids" query parameter. Existing
// query parameters on baseURL are preserved and any fragment is dropped. The
// caller's runtimeIDs slice is not modified.
//
// An error is returned when baseURL does not parse, uses an unsupported
// scheme, or has no host (for example "http://" or "http:opaque"), since such
// a URL can never be dialed.
func taskWakeupURL(baseURL string, runtimeIDs []string) (string, error) {
	u, err := url.Parse(strings.TrimSpace(baseURL))
	if err != nil {
		return "", fmt.Errorf("invalid daemon server URL: %w", err)
	}
	switch u.Scheme {
	case "http":
		u.Scheme = "ws"
	case "https":
		u.Scheme = "wss"
	case "ws", "wss":
	default:
		return "", fmt.Errorf("daemon server URL must use http, https, ws, or wss")
	}
	// Without a host the result would look like "ws:///api/daemon/ws" (or an
	// opaque URL that ignores the path entirely); fail early with a clear error.
	if u.Host == "" {
		return "", fmt.Errorf("daemon server URL must include a host")
	}

	u.Path = strings.TrimRight(u.Path, "/") + "/api/daemon/ws"
	u.RawPath = ""
	q := u.Query()
	// Sort a copy so the URL is deterministic and the caller's slice is untouched.
	ids := append([]string(nil), runtimeIDs...)
	sort.Strings(ids)
	q.Set("runtime_ids", strings.Join(ids, ","))
	u.RawQuery = q.Encode()
	u.Fragment = ""
	return u.String(), nil
}
|
||||
|
||||
// sleepWithContextOrRuntimeChange waits until d elapses, a value is received
// from runtimeSetCh, or ctx is done. Only the context case yields a non-nil
// error (ctx.Err()); the other two return nil.
func sleepWithContextOrRuntimeChange(ctx context.Context, d time.Duration, runtimeSetCh <-chan struct{}) error {
	t := time.NewTimer(d)
	defer t.Stop()

	select {
	case <-t.C:
	case <-runtimeSetCh:
	case <-ctx.Done():
		return ctx.Err()
	}
	return nil
}
|
||||
43
server/internal/daemon/wakeup_test.go
Normal file
43
server/internal/daemon/wakeup_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package daemon
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestTaskWakeupURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
baseURL string
|
||||
runtimeIDs []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "http base",
|
||||
baseURL: "http://localhost:8080",
|
||||
runtimeIDs: []string{"runtime-b", "runtime-a"},
|
||||
want: "ws://localhost:8080/api/daemon/ws?runtime_ids=runtime-a%2Cruntime-b",
|
||||
},
|
||||
{
|
||||
name: "https base",
|
||||
baseURL: "https://api.example.com",
|
||||
runtimeIDs: []string{"runtime-1"},
|
||||
want: "wss://api.example.com/api/daemon/ws?runtime_ids=runtime-1",
|
||||
},
|
||||
{
|
||||
name: "base path",
|
||||
baseURL: "https://api.example.com/multica",
|
||||
runtimeIDs: []string{"runtime-1"},
|
||||
want: "wss://api.example.com/multica/api/daemon/ws?runtime_ids=runtime-1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := taskWakeupURL(tt.baseURL, tt.runtimeIDs)
|
||||
if err != nil {
|
||||
t.Fatalf("taskWakeupURL: %v", err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Fatalf("taskWakeupURL() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
349
server/internal/daemonws/hub.go
Normal file
349
server/internal/daemonws/hub.go
Normal file
@@ -0,0 +1,349 @@
|
||||
package daemonws
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/multica-ai/multica/server/pkg/protocol"
|
||||
)
|
||||
|
||||
// Keepalive timing for daemon WebSocket connections.
const (
	// writeWait is the write deadline set before each frame is written.
	writeWait = 10 * time.Second
	// pongWait is the read deadline; it is re-armed whenever a pong arrives.
	pongWait = 60 * time.Second
	// pingPeriod is the interval between pings, 90% of pongWait.
	pingPeriod = pongWait * 9 / 10
)
|
||||
|
||||
// ClientIdentity describes the scope of a daemon connection that has already
// been authenticated by the caller of Hub.HandleWebSocket.
type ClientIdentity struct {
	// DaemonID, UserID and WorkspaceID are attached to this package's
	// connection log lines.
	DaemonID    string
	UserID      string
	WorkspaceID string

	// RuntimeIDs lists the runtimes the connection is indexed under in the
	// Hub; blank entries are ignored.
	RuntimeIDs []string

	// ClientVersion is logged when the connection is registered.
	ClientVersion string
}
|
||||
|
||||
type client struct {
|
||||
hub *Hub
|
||||
conn *websocket.Conn
|
||||
send chan []byte
|
||||
identity ClientIdentity
|
||||
runtimes map[string]struct{}
|
||||
|
||||
dedupMu sync.Mutex
|
||||
seenIDs map[string]struct{}
|
||||
seenList []string
|
||||
}
|
||||
|
||||
const eventDedupCapacity = 128
|
||||
|
||||
// markSeen records eventID as already delivered to this client. Empty event IDs
|
||||
// disable dedup and are always delivered.
|
||||
func (c *client) markSeen(eventID string) bool {
|
||||
if eventID == "" {
|
||||
return true
|
||||
}
|
||||
c.dedupMu.Lock()
|
||||
defer c.dedupMu.Unlock()
|
||||
if c.seenIDs == nil {
|
||||
c.seenIDs = make(map[string]struct{}, eventDedupCapacity)
|
||||
}
|
||||
if _, ok := c.seenIDs[eventID]; ok {
|
||||
return false
|
||||
}
|
||||
c.seenIDs[eventID] = struct{}{}
|
||||
c.seenList = append(c.seenList, eventID)
|
||||
if len(c.seenList) > eventDedupCapacity {
|
||||
drop := c.seenList[0]
|
||||
c.seenList = c.seenList[1:]
|
||||
delete(c.seenIDs, drop)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Hub keeps daemon WebSocket connections indexed by runtime ID. Messages are
|
||||
// best-effort wakeup hints; the daemon still uses HTTP claim for correctness.
|
||||
type Hub struct {
|
||||
upgrader websocket.Upgrader
|
||||
|
||||
mu sync.RWMutex
|
||||
clients map[*client]bool
|
||||
byRuntime map[string]map[*client]bool
|
||||
}
|
||||
|
||||
func NewHub() *Hub {
|
||||
return &Hub{
|
||||
upgrader: websocket.Upgrader{
|
||||
// Daemon clients authenticate with Authorization headers before the
|
||||
// upgrade. Browsers cannot set those headers through the native WS API,
|
||||
// and DaemonAuth does not accept cookies, so cookie-based CSWSH does
|
||||
// not apply to this endpoint. Re-evaluate this if DaemonAuth ever
|
||||
// grows cookie fallback.
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
},
|
||||
clients: make(map[*client]bool),
|
||||
byRuntime: make(map[string]map[*client]bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) HandleWebSocket(w http.ResponseWriter, r *http.Request, identity ClientIdentity) {
|
||||
if len(identity.RuntimeIDs) == 0 {
|
||||
http.Error(w, `{"error":"runtime_ids required"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := h.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
slog.Error("daemon websocket upgrade failed", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
runtimes := make(map[string]struct{}, len(identity.RuntimeIDs))
|
||||
for _, runtimeID := range identity.RuntimeIDs {
|
||||
if runtimeID != "" {
|
||||
runtimes[runtimeID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(runtimes) == 0 {
|
||||
conn.WriteMessage(websocket.TextMessage, []byte(`{"error":"runtime_ids required"}`))
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
c := &client{
|
||||
hub: h,
|
||||
conn: conn,
|
||||
send: make(chan []byte, 16),
|
||||
identity: identity,
|
||||
runtimes: runtimes,
|
||||
}
|
||||
h.register(c)
|
||||
|
||||
go c.writePump()
|
||||
go c.readPump()
|
||||
}
|
||||
|
||||
// NotifyTaskAvailable sends a best-effort wakeup to daemons watching runtimeID.
|
||||
func (h *Hub) NotifyTaskAvailable(runtimeID, taskID string) {
|
||||
h.notifyTaskAvailable(runtimeID, taskID, "")
|
||||
}
|
||||
|
||||
func (h *Hub) notifyTaskAvailable(runtimeID, taskID, eventID string) {
|
||||
if h == nil || runtimeID == "" {
|
||||
return
|
||||
}
|
||||
data, err := taskAvailableFrame(runtimeID, taskID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
delivered, deduped := h.notifyFrame(runtimeID, data, eventID)
|
||||
if delivered {
|
||||
M.WakeupDeliveredHit.Add(1)
|
||||
} else if !deduped {
|
||||
M.WakeupDeliveredMiss.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) DeliverDaemonRuntime(scopeID string, frame []byte, eventID string) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
M.WakeupReceivedTotal.Add(1)
|
||||
var msg protocol.Message
|
||||
if err := json.Unmarshal(frame, &msg); err != nil {
|
||||
slog.Debug("daemon websocket relay: invalid frame", "error", err, "scope_id", scopeID, "event_id", eventID)
|
||||
M.WakeupDeliveredMiss.Add(1)
|
||||
return
|
||||
}
|
||||
if msg.Type != protocol.EventDaemonTaskAvailable {
|
||||
M.WakeupDeliveredMiss.Add(1)
|
||||
return
|
||||
}
|
||||
var payload protocol.TaskAvailablePayload
|
||||
if err := json.Unmarshal(msg.Payload, &payload); err != nil || payload.RuntimeID == "" {
|
||||
slog.Debug("daemon websocket relay: invalid task_available payload", "error", err, "scope_id", scopeID, "event_id", eventID)
|
||||
M.WakeupDeliveredMiss.Add(1)
|
||||
return
|
||||
}
|
||||
delivered, deduped := h.notifyFrame(payload.RuntimeID, frame, eventID)
|
||||
if delivered {
|
||||
M.WakeupDeliveredHit.Add(1)
|
||||
} else if !deduped {
|
||||
M.WakeupDeliveredMiss.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) notifyFrame(runtimeID string, data []byte, eventID string) (delivered bool, deduped bool) {
|
||||
h.mu.RLock()
|
||||
clients := h.byRuntime[runtimeID]
|
||||
slow := make([]*client, 0)
|
||||
for c := range clients {
|
||||
if !c.markSeen(eventID) {
|
||||
deduped = true
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case c.send <- data:
|
||||
delivered = true
|
||||
default:
|
||||
slow = append(slow, c)
|
||||
}
|
||||
}
|
||||
h.mu.RUnlock()
|
||||
|
||||
for _, c := range slow {
|
||||
h.unregister(c)
|
||||
c.conn.Close()
|
||||
}
|
||||
if len(slow) > 0 {
|
||||
M.SlowEvictionsTotal.Add(int64(len(slow)))
|
||||
}
|
||||
return delivered, deduped
|
||||
}
|
||||
|
||||
func taskAvailableFrame(runtimeID, taskID string) ([]byte, error) {
|
||||
return json.Marshal(protocol.Message{
|
||||
Type: protocol.EventDaemonTaskAvailable,
|
||||
Payload: mustMarshalRaw(protocol.TaskAvailablePayload{
|
||||
RuntimeID: runtimeID,
|
||||
TaskID: taskID,
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
// mustMarshalRaw returns v encoded as JSON, or nil if encoding fails. Despite
// the name it does not panic.
func mustMarshalRaw(v any) json.RawMessage {
	if encoded, err := json.Marshal(v); err == nil {
		return encoded
	}
	return nil
}
|
||||
|
||||
func (h *Hub) RuntimeConnectionCount(runtimeID string) int {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return len(h.byRuntime[runtimeID])
|
||||
}
|
||||
|
||||
func (h *Hub) register(c *client) {
|
||||
h.mu.Lock()
|
||||
h.clients[c] = true
|
||||
for runtimeID := range c.runtimes {
|
||||
conns := h.byRuntime[runtimeID]
|
||||
if conns == nil {
|
||||
conns = make(map[*client]bool)
|
||||
h.byRuntime[runtimeID] = conns
|
||||
}
|
||||
conns[c] = true
|
||||
}
|
||||
total := len(h.clients)
|
||||
h.mu.Unlock()
|
||||
|
||||
M.ConnectsTotal.Add(1)
|
||||
M.ActiveConnections.Add(1)
|
||||
slog.Info("daemon websocket connected",
|
||||
"daemon_id", c.identity.DaemonID,
|
||||
"user_id", c.identity.UserID,
|
||||
"workspace_id", c.identity.WorkspaceID,
|
||||
"runtimes", len(c.runtimes),
|
||||
"client_version", c.identity.ClientVersion,
|
||||
"total_clients", total,
|
||||
)
|
||||
}
|
||||
|
||||
func (h *Hub) unregister(c *client) {
|
||||
h.mu.Lock()
|
||||
if !h.clients[c] {
|
||||
h.mu.Unlock()
|
||||
return
|
||||
}
|
||||
delete(h.clients, c)
|
||||
for runtimeID := range c.runtimes {
|
||||
if conns := h.byRuntime[runtimeID]; conns != nil {
|
||||
delete(conns, c)
|
||||
if len(conns) == 0 {
|
||||
delete(h.byRuntime, runtimeID)
|
||||
}
|
||||
}
|
||||
}
|
||||
close(c.send)
|
||||
total := len(h.clients)
|
||||
h.mu.Unlock()
|
||||
|
||||
M.DisconnectsTotal.Add(1)
|
||||
M.ActiveConnections.Add(-1)
|
||||
slog.Info("daemon websocket disconnected",
|
||||
"daemon_id", c.identity.DaemonID,
|
||||
"user_id", c.identity.UserID,
|
||||
"workspace_id", c.identity.WorkspaceID,
|
||||
"runtimes", len(c.runtimes),
|
||||
"total_clients", total,
|
||||
)
|
||||
}
|
||||
|
||||
func (c *client) readPump() {
|
||||
defer func() {
|
||||
c.hub.unregister(c)
|
||||
c.conn.Close()
|
||||
}()
|
||||
|
||||
c.conn.SetReadLimit(4096)
|
||||
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
c.conn.SetPongHandler(func(string) error {
|
||||
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
_, raw, err := c.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
|
||||
slog.Debug("daemon websocket read error", "error", err, "daemon_id", c.identity.DaemonID)
|
||||
}
|
||||
return
|
||||
}
|
||||
c.handleFrame(raw)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) handleFrame(raw []byte) {
|
||||
var msg protocol.Message
|
||||
if err := json.Unmarshal(raw, &msg); err != nil {
|
||||
slog.Debug("daemon websocket invalid frame", "error", err, "daemon_id", c.identity.DaemonID)
|
||||
return
|
||||
}
|
||||
// The phase-one daemon channel is server-push only. Inbound frames are
|
||||
// drained so control frames and close handling work, but app messages are
|
||||
// intentionally ignored for forward compatibility.
|
||||
}
|
||||
|
||||
func (c *client) writePump() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
c.conn.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-c.send:
|
||||
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if !ok {
|
||||
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
|
||||
slog.Debug("daemon websocket write error", "error", err, "daemon_id", c.identity.DaemonID)
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
200
server/internal/daemonws/hub_test.go
Normal file
200
server/internal/daemonws/hub_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package daemonws
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/multica-ai/multica/server/internal/realtime"
|
||||
"github.com/multica-ai/multica/server/pkg/protocol"
|
||||
)
|
||||
|
||||
func TestNotifyTaskAvailable(t *testing.T) {
|
||||
M.Reset()
|
||||
defer M.Reset()
|
||||
|
||||
hub := NewHub()
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hub.HandleWebSocket(w, r, ClientIdentity{RuntimeIDs: []string{"runtime-1"}})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for hub.RuntimeConnectionCount("runtime-1") == 0 {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatal("runtime connection was not registered")
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
hub.NotifyTaskAvailable("runtime-1", "task-1")
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
|
||||
t.Fatalf("SetReadDeadline: %v", err)
|
||||
}
|
||||
_, raw, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadMessage: %v", err)
|
||||
}
|
||||
|
||||
var msg protocol.Message
|
||||
if err := json.Unmarshal(raw, &msg); err != nil {
|
||||
t.Fatalf("unmarshal message: %v", err)
|
||||
}
|
||||
if msg.Type != protocol.EventDaemonTaskAvailable {
|
||||
t.Fatalf("message type = %q, want %q", msg.Type, protocol.EventDaemonTaskAvailable)
|
||||
}
|
||||
|
||||
var payload protocol.TaskAvailablePayload
|
||||
if err := json.Unmarshal(msg.Payload, &payload); err != nil {
|
||||
t.Fatalf("unmarshal payload: %v", err)
|
||||
}
|
||||
if payload.RuntimeID != "runtime-1" || payload.TaskID != "task-1" {
|
||||
t.Fatalf("payload = %+v, want runtime/task IDs", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelayNotifierPublishesDaemonRuntimeScope(t *testing.T) {
|
||||
M.Reset()
|
||||
defer M.Reset()
|
||||
|
||||
relay := &recordingRelayPublisher{}
|
||||
notifier := NewRelayNotifier(nil, relay)
|
||||
|
||||
notifier.NotifyTaskAvailable("runtime-1", "task-1")
|
||||
|
||||
if relay.scopeType != realtime.ScopeDaemonRuntime {
|
||||
t.Fatalf("scopeType = %q, want %q", relay.scopeType, realtime.ScopeDaemonRuntime)
|
||||
}
|
||||
if relay.scopeID != "task-1" {
|
||||
t.Fatalf("scopeID = %q, want task_id shard key", relay.scopeID)
|
||||
}
|
||||
if relay.eventID == "" {
|
||||
t.Fatal("expected event id")
|
||||
}
|
||||
if M.WakeupPublishedTotal.Load() != 1 {
|
||||
t.Fatalf("published metric = %d, want 1", M.WakeupPublishedTotal.Load())
|
||||
}
|
||||
|
||||
var msg protocol.Message
|
||||
if err := json.Unmarshal(relay.frame, &msg); err != nil {
|
||||
t.Fatalf("unmarshal frame: %v", err)
|
||||
}
|
||||
if msg.Type != protocol.EventDaemonTaskAvailable {
|
||||
t.Fatalf("message type = %q, want %q", msg.Type, protocol.EventDaemonTaskAvailable)
|
||||
}
|
||||
var payload protocol.TaskAvailablePayload
|
||||
if err := json.Unmarshal(msg.Payload, &payload); err != nil {
|
||||
t.Fatalf("unmarshal payload: %v", err)
|
||||
}
|
||||
if payload.RuntimeID != "runtime-1" || payload.TaskID != "task-1" {
|
||||
t.Fatalf("payload = %+v, want runtime/task IDs", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelayNotifierDedupsLocalRedisLoopback(t *testing.T) {
|
||||
M.Reset()
|
||||
defer M.Reset()
|
||||
|
||||
hub := NewHub()
|
||||
client := attachDaemonTestClient(hub, "runtime-1")
|
||||
relay := &localFirstDaemonRelayPublisher{t: t, client: client}
|
||||
notifier := NewRelayNotifier(hub, relay)
|
||||
|
||||
notifier.NotifyTaskAvailable("runtime-1", "task-1")
|
||||
|
||||
if !relay.called {
|
||||
t.Fatal("expected relay publish to be invoked")
|
||||
}
|
||||
if relay.eventID == "" {
|
||||
t.Fatal("expected event id")
|
||||
}
|
||||
if M.WakeupDeliveredHit.Load() != 1 {
|
||||
t.Fatalf("delivered hit metric = %d, want 1", M.WakeupDeliveredHit.Load())
|
||||
}
|
||||
|
||||
hub.DeliverDaemonRuntime(relay.scopeID, relay.frame, relay.eventID)
|
||||
|
||||
select {
|
||||
case duplicate := <-client.send:
|
||||
t.Fatalf("expected redis loopback to be deduped, got duplicate %s", duplicate)
|
||||
case <-time.After(20 * time.Millisecond):
|
||||
}
|
||||
if M.WakeupDeliveredHit.Load() != 1 {
|
||||
t.Fatalf("delivered hit metric after loopback = %d, want 1", M.WakeupDeliveredHit.Load())
|
||||
}
|
||||
if M.WakeupDeliveredMiss.Load() != 0 {
|
||||
t.Fatalf("delivered miss metric after dedup = %d, want 0", M.WakeupDeliveredMiss.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func attachDaemonTestClient(hub *Hub, runtimeID string) *client {
|
||||
c := &client{
|
||||
send: make(chan []byte, 2),
|
||||
runtimes: map[string]struct{}{runtimeID: {}},
|
||||
}
|
||||
|
||||
hub.mu.Lock()
|
||||
hub.clients[c] = true
|
||||
hub.byRuntime[runtimeID] = map[*client]bool{c: true}
|
||||
hub.mu.Unlock()
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// recordingRelayPublisher is a test double that remembers the arguments of the
// most recent PublishWithID call.
type recordingRelayPublisher struct {
	scopeType string
	scopeID   string
	exclude   string
	frame     []byte // private copy of the published frame
	eventID   string
}

// PublishWithID stores every argument, replacing whatever an earlier call
// recorded, and always reports success. The frame is copied so later changes
// to the caller's slice do not affect what was recorded.
func (r *recordingRelayPublisher) PublishWithID(scopeType, scopeID, exclude string, frame []byte, id string) error {
	*r = recordingRelayPublisher{
		scopeType: scopeType,
		scopeID:   scopeID,
		exclude:   exclude,
		frame:     append([]byte(nil), frame...),
		eventID:   id,
	}
	return nil
}
|
||||
|
||||
type localFirstDaemonRelayPublisher struct {
|
||||
t *testing.T
|
||||
client *client
|
||||
|
||||
called bool
|
||||
scopeType string
|
||||
scopeID string
|
||||
exclude string
|
||||
frame []byte
|
||||
eventID string
|
||||
localFrame []byte
|
||||
}
|
||||
|
||||
func (p *localFirstDaemonRelayPublisher) PublishWithID(scopeType, scopeID, exclude string, frame []byte, id string) error {
|
||||
p.called = true
|
||||
p.scopeType = scopeType
|
||||
p.scopeID = scopeID
|
||||
p.exclude = exclude
|
||||
p.frame = append([]byte(nil), frame...)
|
||||
p.eventID = id
|
||||
|
||||
select {
|
||||
case p.localFrame = <-p.client.send:
|
||||
default:
|
||||
p.t.Fatal("expected local fanout to happen before relay publish")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
44
server/internal/daemonws/metrics.go
Normal file
44
server/internal/daemonws/metrics.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package daemonws
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// Metrics groups the atomic counters and gauges kept for the daemon WebSocket
// hub and its wakeup path. Every field may be updated concurrently.
type Metrics struct {
	ConnectsTotal      atomic.Int64 // daemon WebSocket connections opened
	DisconnectsTotal   atomic.Int64 // daemon WebSocket connections closed
	ActiveConnections  atomic.Int64 // current daemon WebSocket connections
	SlowEvictionsTotal atomic.Int64 // clients evicted for slow consumption

	WakeupPublishedTotal atomic.Int64 // wakeups published to the relay
	WakeupPublishErrors  atomic.Int64 // wakeups that could not be published
	WakeupReceivedTotal  atomic.Int64 // wakeups received from the relay
	WakeupDeliveredHit   atomic.Int64 // local delivery attempts recorded as a hit
	WakeupDeliveredMiss  atomic.Int64 // local delivery attempts recorded as a miss
}

// M is the package-level Metrics instance.
var M = &Metrics{}

// Snapshot returns the current value of every metric keyed by its JSON name.
// Each value is an int64. Fields are loaded one at a time, so the map is not
// an atomic view across metrics.
func (m *Metrics) Snapshot() map[string]any {
	snap := make(map[string]any, 9)
	snap["connects_total"] = m.ConnectsTotal.Load()
	snap["disconnects_total"] = m.DisconnectsTotal.Load()
	snap["active_connections"] = m.ActiveConnections.Load()
	snap["slow_evictions_total"] = m.SlowEvictionsTotal.Load()
	snap["wakeup_published_total"] = m.WakeupPublishedTotal.Load()
	snap["wakeup_publish_errors"] = m.WakeupPublishErrors.Load()
	snap["wakeup_received_total"] = m.WakeupReceivedTotal.Load()
	snap["wakeup_delivered_hit_total"] = m.WakeupDeliveredHit.Load()
	snap["wakeup_delivered_miss_total"] = m.WakeupDeliveredMiss.Load()
	return snap
}

// Reset sets every metric back to zero.
func (m *Metrics) Reset() {
	for _, metric := range []*atomic.Int64{
		&m.ConnectsTotal,
		&m.DisconnectsTotal,
		&m.ActiveConnections,
		&m.SlowEvictionsTotal,
		&m.WakeupPublishedTotal,
		&m.WakeupPublishErrors,
		&m.WakeupReceivedTotal,
		&m.WakeupDeliveredHit,
		&m.WakeupDeliveredMiss,
	} {
		metric.Store(0)
	}
}
|
||||
49
server/internal/daemonws/notifier.go
Normal file
49
server/internal/daemonws/notifier.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package daemonws
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/oklog/ulid/v2"
|
||||
|
||||
"github.com/multica-ai/multica/server/internal/realtime"
|
||||
)
|
||||
|
||||
// RelayNotifier sends task wakeups to the local daemon hub and, when Redis is
|
||||
// configured, publishes the same wakeup through the shared realtime relay so
|
||||
// every API node can attempt local delivery.
|
||||
type RelayNotifier struct {
|
||||
local *Hub
|
||||
relay realtime.RelayPublisher
|
||||
}
|
||||
|
||||
func NewRelayNotifier(local *Hub, relay realtime.RelayPublisher) *RelayNotifier {
|
||||
return &RelayNotifier{local: local, relay: relay}
|
||||
}
|
||||
|
||||
func (n *RelayNotifier) NotifyTaskAvailable(runtimeID, taskID string) {
|
||||
if runtimeID == "" {
|
||||
return
|
||||
}
|
||||
eventID := ulid.Make().String()
|
||||
if n.local != nil {
|
||||
n.local.notifyTaskAvailable(runtimeID, taskID, eventID)
|
||||
}
|
||||
if n.relay == nil {
|
||||
return
|
||||
}
|
||||
frame, err := taskAvailableFrame(runtimeID, taskID)
|
||||
if err != nil {
|
||||
M.WakeupPublishErrors.Add(1)
|
||||
return
|
||||
}
|
||||
shardKey := taskID
|
||||
if shardKey == "" {
|
||||
shardKey = eventID
|
||||
}
|
||||
if err := n.relay.PublishWithID(realtime.ScopeDaemonRuntime, shardKey, "", frame, eventID); err != nil {
|
||||
M.WakeupPublishErrors.Add(1)
|
||||
slog.Warn("daemon websocket wakeup publish failed", "error", err, "runtime_id", runtimeID, "task_id", taskID)
|
||||
return
|
||||
}
|
||||
M.WakeupPublishedTotal.Add(1)
|
||||
}
|
||||
66
server/internal/handler/daemon_ws.go
Normal file
66
server/internal/handler/daemon_ws.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/multica-ai/multica/server/internal/daemonws"
|
||||
"github.com/multica-ai/multica/server/internal/middleware"
|
||||
)
|
||||
|
||||
func (h *Handler) DaemonWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
if h.DaemonHub == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "daemon websocket unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
runtimeIDs := parseRuntimeIDs(r)
|
||||
if len(runtimeIDs) == 0 {
|
||||
writeError(w, http.StatusBadRequest, "runtime_ids required")
|
||||
return
|
||||
}
|
||||
|
||||
for _, runtimeID := range runtimeIDs {
|
||||
rt, ok := h.requireDaemonRuntimeAccess(w, r, runtimeID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if daemonID := middleware.DaemonIDFromContext(r.Context()); daemonID != "" && rt.DaemonID.Valid && rt.DaemonID.String != daemonID {
|
||||
writeError(w, http.StatusNotFound, "runtime not found")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
h.DaemonHub.HandleWebSocket(w, r, daemonws.ClientIdentity{
|
||||
DaemonID: middleware.DaemonIDFromContext(r.Context()),
|
||||
UserID: requestUserID(r),
|
||||
WorkspaceID: middleware.DaemonWorkspaceIDFromContext(r.Context()),
|
||||
RuntimeIDs: runtimeIDs,
|
||||
ClientVersion: r.Header.Get("X-Client-Version"),
|
||||
})
|
||||
}
|
||||
|
||||
// parseRuntimeIDs collects runtime IDs from the "runtime_id" and "runtime_ids"
// query parameters, in that order. Each parameter may be repeated and each
// value may hold a comma-separated list. Entries are trimmed of surrounding
// whitespace, blanks are dropped, and only the first occurrence of an ID is
// kept. The result is nil when no ID is present.
func parseRuntimeIDs(r *http.Request) []string {
	query := r.URL.Query()

	var (
		ids  []string
		seen = make(map[string]struct{})
	)
	for _, key := range [...]string{"runtime_id", "runtime_ids"} {
		for _, raw := range query[key] {
			for _, piece := range strings.Split(raw, ",") {
				id := strings.TrimSpace(piece)
				if id == "" {
					continue
				}
				if _, dup := seen[id]; dup {
					continue
				}
				seen[id] = struct{}{}
				ids = append(ids, id)
			}
		}
	}
	return ids
}
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/multica-ai/multica/server/internal/analytics"
|
||||
"github.com/multica-ai/multica/server/internal/auth"
|
||||
"github.com/multica-ai/multica/server/internal/daemonws"
|
||||
"github.com/multica-ai/multica/server/internal/events"
|
||||
"github.com/multica-ai/multica/server/internal/middleware"
|
||||
"github.com/multica-ai/multica/server/internal/realtime"
|
||||
@@ -53,6 +54,7 @@ type Handler struct {
|
||||
DB dbExecutor
|
||||
TxStarter txStarter
|
||||
Hub *realtime.Hub
|
||||
DaemonHub *daemonws.Hub
|
||||
Bus *events.Bus
|
||||
TaskService *service.TaskService
|
||||
AutopilotService *service.AutopilotService
|
||||
@@ -67,7 +69,7 @@ type Handler struct {
|
||||
cfg Config
|
||||
}
|
||||
|
||||
func New(queries *db.Queries, txStarter txStarter, hub *realtime.Hub, bus *events.Bus, emailService *service.EmailService, store storage.Storage, cfSigner *auth.CloudFrontSigner, analyticsClient analytics.Client, cfg Config) *Handler {
|
||||
func New(queries *db.Queries, txStarter txStarter, hub *realtime.Hub, bus *events.Bus, emailService *service.EmailService, store storage.Storage, cfSigner *auth.CloudFrontSigner, analyticsClient analytics.Client, cfg Config, daemonHubs ...*daemonws.Hub) *Handler {
|
||||
var executor dbExecutor
|
||||
if candidate, ok := txStarter.(dbExecutor); ok {
|
||||
executor = candidate
|
||||
@@ -77,12 +79,18 @@ func New(queries *db.Queries, txStarter txStarter, hub *realtime.Hub, bus *event
|
||||
analyticsClient = analytics.NoopClient{}
|
||||
}
|
||||
|
||||
taskSvc := service.NewTaskService(queries, txStarter, hub, bus)
|
||||
var daemonHub *daemonws.Hub
|
||||
if len(daemonHubs) > 0 {
|
||||
daemonHub = daemonHubs[0]
|
||||
}
|
||||
|
||||
taskSvc := service.NewTaskService(queries, txStarter, hub, bus, daemonHub)
|
||||
return &Handler{
|
||||
Queries: queries,
|
||||
DB: executor,
|
||||
TxStarter: txStarter,
|
||||
Hub: hub,
|
||||
DaemonHub: daemonHub,
|
||||
Bus: bus,
|
||||
TaskService: taskSvc,
|
||||
AutopilotService: service.NewAutopilotService(queries, txStarter, bus, taskSvc),
|
||||
|
||||
70
server/internal/metrics/daemonws.go
Normal file
70
server/internal/metrics/daemonws.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
"github.com/multica-ai/multica/server/internal/daemonws"
|
||||
)
|
||||
|
||||
type DaemonWSCollector struct {
|
||||
metrics *daemonws.Metrics
|
||||
|
||||
connectsTotal *prometheus.Desc
|
||||
disconnectsTotal *prometheus.Desc
|
||||
activeConnections *prometheus.Desc
|
||||
slowEvictionsTotal *prometheus.Desc
|
||||
wakeupPublishedTotal *prometheus.Desc
|
||||
wakeupPublishErrors *prometheus.Desc
|
||||
wakeupReceivedTotal *prometheus.Desc
|
||||
wakeupDeliveredTotal *prometheus.Desc
|
||||
}
|
||||
|
||||
func NewDaemonWSCollector(m *daemonws.Metrics) *DaemonWSCollector {
|
||||
return &DaemonWSCollector{
|
||||
metrics: m,
|
||||
|
||||
connectsTotal: newDaemonWSDesc("connects_total", "Total daemon WebSocket connections opened."),
|
||||
disconnectsTotal: newDaemonWSDesc("disconnects_total", "Total daemon WebSocket connections closed."),
|
||||
activeConnections: newDaemonWSDesc("active_connections", "Current daemon WebSocket connections."),
|
||||
slowEvictionsTotal: newDaemonWSDesc("slow_evictions_total", "Total daemon WebSocket clients evicted for slow consumption."),
|
||||
wakeupPublishedTotal: newDaemonWSDesc("wakeup_published_total", "Total daemon wakeups published to the Redis relay."),
|
||||
wakeupPublishErrors: newDaemonWSDesc("wakeup_publish_errors_total", "Total daemon wakeup Redis publish errors."),
|
||||
wakeupReceivedTotal: newDaemonWSDesc("wakeup_received_total", "Total daemon wakeups received from the Redis relay."),
|
||||
wakeupDeliveredTotal: prometheus.NewDesc("multica_daemonws_wakeup_delivered_total", "Total daemon wakeup local delivery attempts.", []string{"result"}, nil),
|
||||
}
|
||||
}
|
||||
|
||||
func newDaemonWSDesc(name, help string) *prometheus.Desc {
|
||||
return prometheus.NewDesc("multica_daemonws_"+name, help, nil, nil)
|
||||
}
|
||||
|
||||
func (c *DaemonWSCollector) Describe(ch chan<- *prometheus.Desc) {
|
||||
for _, desc := range []*prometheus.Desc{
|
||||
c.connectsTotal,
|
||||
c.disconnectsTotal,
|
||||
c.activeConnections,
|
||||
c.slowEvictionsTotal,
|
||||
c.wakeupPublishedTotal,
|
||||
c.wakeupPublishErrors,
|
||||
c.wakeupReceivedTotal,
|
||||
c.wakeupDeliveredTotal,
|
||||
} {
|
||||
ch <- desc
|
||||
}
|
||||
}
|
||||
|
||||
func (c *DaemonWSCollector) Collect(ch chan<- prometheus.Metric) {
|
||||
if c.metrics == nil {
|
||||
return
|
||||
}
|
||||
m := c.metrics
|
||||
ch <- prometheus.MustNewConstMetric(c.connectsTotal, prometheus.CounterValue, float64(m.ConnectsTotal.Load()))
|
||||
ch <- prometheus.MustNewConstMetric(c.disconnectsTotal, prometheus.CounterValue, float64(m.DisconnectsTotal.Load()))
|
||||
ch <- prometheus.MustNewConstMetric(c.activeConnections, prometheus.GaugeValue, float64(m.ActiveConnections.Load()))
|
||||
ch <- prometheus.MustNewConstMetric(c.slowEvictionsTotal, prometheus.CounterValue, float64(m.SlowEvictionsTotal.Load()))
|
||||
ch <- prometheus.MustNewConstMetric(c.wakeupPublishedTotal, prometheus.CounterValue, float64(m.WakeupPublishedTotal.Load()))
|
||||
ch <- prometheus.MustNewConstMetric(c.wakeupPublishErrors, prometheus.CounterValue, float64(m.WakeupPublishErrors.Load()))
|
||||
ch <- prometheus.MustNewConstMetric(c.wakeupReceivedTotal, prometheus.CounterValue, float64(m.WakeupReceivedTotal.Load()))
|
||||
ch <- prometheus.MustNewConstMetric(c.wakeupDeliveredTotal, prometheus.CounterValue, float64(m.WakeupDeliveredHit.Load()), "hit")
|
||||
ch <- prometheus.MustNewConstMetric(c.wakeupDeliveredTotal, prometheus.CounterValue, float64(m.WakeupDeliveredMiss.Load()), "miss")
|
||||
}
|
||||
@@ -7,12 +7,14 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||
|
||||
"github.com/multica-ai/multica/server/internal/daemonws"
|
||||
"github.com/multica-ai/multica/server/internal/realtime"
|
||||
)
|
||||
|
||||
type RegistryOptions struct {
|
||||
Pool *pgxpool.Pool
|
||||
Realtime *realtime.Metrics
|
||||
DaemonWS *daemonws.Metrics
|
||||
Version string
|
||||
Commit string
|
||||
}
|
||||
@@ -43,6 +45,9 @@ func NewRegistry(opts RegistryOptions) *Registry {
|
||||
if opts.Realtime != nil {
|
||||
reg.MustRegister(NewRealtimeCollector(opts.Realtime))
|
||||
}
|
||||
if opts.DaemonWS != nil {
|
||||
reg.MustRegister(NewDaemonWSCollector(opts.DaemonWS))
|
||||
}
|
||||
|
||||
return &Registry{
|
||||
Gatherer: reg,
|
||||
|
||||
@@ -8,6 +8,9 @@ const (
|
||||
ScopeUser = "user"
|
||||
ScopeTask = "task"
|
||||
ScopeChat = "chat"
|
||||
// ScopeDaemonRuntime routes daemon wakeup frames through the Redis relay.
|
||||
// It is consumed by the daemon WebSocket hub, not by browser clients.
|
||||
ScopeDaemonRuntime = "daemon_runtime"
|
||||
)
|
||||
|
||||
// Broadcaster is the abstraction every realtime event producer should depend
|
||||
@@ -38,5 +41,10 @@ type Broadcaster interface {
|
||||
Broadcast(message []byte)
|
||||
}
|
||||
|
||||
// DaemonRuntimeDeliverer is the sink for relay frames published under
// ScopeDaemonRuntime. The relay passes the envelope's scope ID, the frame to
// deliver, and the event ID of the envelope.
type DaemonRuntimeDeliverer interface {
	// DeliverDaemonRuntime hands one relayed daemon wakeup frame to the
	// implementation. scopeID is whatever the publisher used as scope ID,
	// which is not necessarily a runtime ID.
	DeliverDaemonRuntime(scopeID string, frame []byte, eventID string)
}
|
||||
|
||||
// Compile-time assertion that *Hub continues to satisfy Broadcaster.
|
||||
var _ Broadcaster = (*Hub)(nil)
|
||||
|
||||
@@ -106,12 +106,16 @@ func redisString(v any) string {
|
||||
}
|
||||
}
|
||||
|
||||
func deliverEnvelope(hub *Hub, ev envelope) {
|
||||
func deliverEnvelope(hub *Hub, daemonRuntime DaemonRuntimeDeliverer, ev envelope) {
|
||||
if ev.PayloadJSON == "" {
|
||||
return
|
||||
}
|
||||
frame := injectEventID([]byte(ev.PayloadJSON), ev.EventID)
|
||||
switch ev.Scope {
|
||||
case ScopeDaemonRuntime:
|
||||
if daemonRuntime != nil {
|
||||
daemonRuntime.DeliverDaemonRuntime(ev.ScopeID, frame, ev.EventID)
|
||||
}
|
||||
case "global":
|
||||
hub.fanoutAllDedup(frame, "", ev.EventID)
|
||||
case ScopeUser:
|
||||
@@ -134,6 +138,8 @@ type RedisRelay struct {
|
||||
consumers map[scopeKey]*scopeConsumer
|
||||
stopping bool
|
||||
wg sync.WaitGroup
|
||||
|
||||
daemonRuntime DaemonRuntimeDeliverer
|
||||
}
|
||||
|
||||
type scopeConsumer struct {
|
||||
@@ -167,6 +173,10 @@ func NewRedisRelayWithClients(hub *Hub, writeRDB, readRDB *redis.Client) *RedisR
|
||||
// NodeID returns this relay's randomly-assigned node identifier.
|
||||
func (r *RedisRelay) NodeID() string { return r.nodeID }
|
||||
|
||||
func (r *RedisRelay) SetDaemonRuntimeDeliverer(d DaemonRuntimeDeliverer) {
|
||||
r.daemonRuntime = d
|
||||
}
|
||||
|
||||
// Wait blocks until all relay-owned goroutines have exited after the Start
|
||||
// context is canceled.
|
||||
func (r *RedisRelay) Wait() {
|
||||
@@ -394,7 +404,7 @@ func (r *RedisRelay) deliverMessage(scopeType, scopeID string, msg redis.XMessag
|
||||
if ev.ScopeID == "" {
|
||||
ev.ScopeID = scopeID
|
||||
}
|
||||
deliverEnvelope(r.hub, ev)
|
||||
deliverEnvelope(r.hub, r.daemonRuntime, ev)
|
||||
}
|
||||
|
||||
// fanoutUser is implemented in hub.go.
|
||||
|
||||
@@ -36,6 +36,15 @@ func (r *MirroredRelay) NodeID() string {
|
||||
return r.primary.NodeID()
|
||||
}
|
||||
|
||||
func (r *MirroredRelay) SetDaemonRuntimeDeliverer(d DaemonRuntimeDeliverer) {
|
||||
if setter, ok := r.primary.(interface{ SetDaemonRuntimeDeliverer(DaemonRuntimeDeliverer) }); ok {
|
||||
setter.SetDaemonRuntimeDeliverer(d)
|
||||
}
|
||||
if setter, ok := r.mirror.(interface{ SetDaemonRuntimeDeliverer(DaemonRuntimeDeliverer) }); ok {
|
||||
setter.SetDaemonRuntimeDeliverer(d)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MirroredRelay) Start(ctx context.Context) {
|
||||
r.primary.Start(ctx)
|
||||
r.mirror.Start(ctx)
|
||||
@@ -74,6 +83,9 @@ func (r *MirroredRelay) Broadcast(message []byte) {
|
||||
|
||||
func (r *MirroredRelay) PublishWithID(scopeType, scopeID, exclude string, frame []byte, id string) error {
|
||||
primaryErr := r.primary.PublishWithID(scopeType, scopeID, exclude, frame, id)
|
||||
if scopeType == ScopeDaemonRuntime {
|
||||
return primaryErr
|
||||
}
|
||||
mirrorErr := r.mirror.PublishWithID(scopeType, scopeID, exclude, frame, id)
|
||||
|
||||
if primaryErr != nil {
|
||||
|
||||
@@ -50,6 +50,23 @@ func TestMirroredRelayRecordsDivergenceWhenOneBackendFails(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMirroredRelayDoesNotMirrorDaemonRuntimeEvents(t *testing.T) {
|
||||
primary := &recordingManagedRelay{nodeID: "primary"}
|
||||
mirror := &recordingManagedRelay{nodeID: "mirror"}
|
||||
relay := NewMirroredRelay(primary, mirror)
|
||||
|
||||
if err := relay.PublishWithID(ScopeDaemonRuntime, "task-1", "", []byte(`{"type":"daemon:task_available"}`), "event-1"); err != nil {
|
||||
t.Fatalf("PublishWithID: %v", err)
|
||||
}
|
||||
|
||||
if len(primary.calls) != 1 {
|
||||
t.Fatalf("expected primary publish call, got %d", len(primary.calls))
|
||||
}
|
||||
if len(mirror.calls) != 0 {
|
||||
t.Fatalf("expected daemon runtime event not to hit mirror, got %d calls", len(mirror.calls))
|
||||
}
|
||||
}
|
||||
|
||||
type relayPublishCall struct {
|
||||
scopeType string
|
||||
scopeID string
|
||||
|
||||
@@ -76,6 +76,8 @@ type ShardedStreamRelay struct {
|
||||
mu sync.Mutex
|
||||
stopping bool
|
||||
wg sync.WaitGroup
|
||||
|
||||
daemonRuntime DaemonRuntimeDeliverer
|
||||
}
|
||||
|
||||
func NewShardedStreamRelay(hub *Hub, writeRDB, readRDB *redis.Client, config ShardedStreamRelayConfig) *ShardedStreamRelay {
|
||||
@@ -93,6 +95,10 @@ func NewShardedStreamRelay(hub *Hub, writeRDB, readRDB *redis.Client, config Sha
|
||||
|
||||
func (r *ShardedStreamRelay) NodeID() string { return r.nodeID }
|
||||
|
||||
func (r *ShardedStreamRelay) SetDaemonRuntimeDeliverer(d DaemonRuntimeDeliverer) {
|
||||
r.daemonRuntime = d
|
||||
}
|
||||
|
||||
func (r *ShardedStreamRelay) Start(ctx context.Context) {
|
||||
M.NodeID.Store(r.nodeID)
|
||||
if err := r.writeRDB.Ping(ctx).Err(); err != nil {
|
||||
@@ -233,7 +239,7 @@ func (r *ShardedStreamRelay) deliverMessage(msg redis.XMessage) {
|
||||
if !ok || ev.Scope == "" || ev.ScopeID == "" {
|
||||
return
|
||||
}
|
||||
deliverEnvelope(r.hub, ev)
|
||||
deliverEnvelope(r.hub, r.daemonRuntime, ev)
|
||||
}
|
||||
|
||||
func (r *ShardedStreamRelay) heartbeatLoop(ctx context.Context) {
|
||||
|
||||
@@ -26,10 +26,19 @@ type TaskService struct {
|
||||
TxStarter TxStarter
|
||||
Hub *realtime.Hub
|
||||
Bus *events.Bus
|
||||
Wakeup TaskWakeupNotifier
|
||||
}
|
||||
|
||||
func NewTaskService(q *db.Queries, tx TxStarter, hub *realtime.Hub, bus *events.Bus) *TaskService {
|
||||
return &TaskService{Queries: q, TxStarter: tx, Hub: hub, Bus: bus}
|
||||
// TaskWakeupNotifier is told when a task has been queued for a runtime.
// TaskService calls it after enqueueing work; the method returns nothing.
type TaskWakeupNotifier interface {
	// NotifyTaskAvailable reports that taskID is available for runtimeID.
	NotifyTaskAvailable(runtimeID, taskID string)
}
|
||||
|
||||
func NewTaskService(q *db.Queries, tx TxStarter, hub *realtime.Hub, bus *events.Bus, wakeups ...TaskWakeupNotifier) *TaskService {
|
||||
var wakeup TaskWakeupNotifier
|
||||
if len(wakeups) > 0 {
|
||||
wakeup = wakeups[0]
|
||||
}
|
||||
return &TaskService{Queries: q, TxStarter: tx, Hub: hub, Bus: bus, Wakeup: wakeup}
|
||||
}
|
||||
|
||||
// EnqueueTaskForIssue creates a queued task for an agent-assigned issue.
|
||||
@@ -73,6 +82,7 @@ func (s *TaskService) EnqueueTaskForIssue(ctx context.Context, issue db.Issue, t
|
||||
}
|
||||
|
||||
slog.Info("task enqueued", "task_id", util.UUIDToString(task.ID), "issue_id", util.UUIDToString(issue.ID), "agent_id", util.UUIDToString(issue.AssigneeID))
|
||||
s.notifyTaskAvailable(task)
|
||||
return task, nil
|
||||
}
|
||||
|
||||
@@ -107,6 +117,7 @@ func (s *TaskService) EnqueueTaskForMention(ctx context.Context, issue db.Issue,
|
||||
}
|
||||
|
||||
slog.Info("mention task enqueued", "task_id", util.UUIDToString(task.ID), "issue_id", util.UUIDToString(issue.ID), "agent_id", util.UUIDToString(agentID))
|
||||
s.notifyTaskAvailable(task)
|
||||
return task, nil
|
||||
}
|
||||
|
||||
@@ -137,6 +148,7 @@ func (s *TaskService) EnqueueChatTask(ctx context.Context, chatSession db.ChatSe
|
||||
}
|
||||
|
||||
slog.Info("chat task enqueued", "task_id", util.UUIDToString(task.ID), "chat_session_id", util.UUIDToString(chatSession.ID), "agent_id", util.UUIDToString(chatSession.AgentID))
|
||||
s.notifyTaskAvailable(task)
|
||||
return task, nil
|
||||
}
|
||||
|
||||
@@ -645,6 +657,7 @@ func (s *TaskService) MaybeRetryFailedTask(ctx context.Context, parent db.AgentT
|
||||
"attempt", child.Attempt,
|
||||
"max_attempts", child.MaxAttempts,
|
||||
)
|
||||
s.notifyTaskAvailable(child)
|
||||
s.broadcastTaskEvent(ctx, protocol.EventTaskDispatch, child)
|
||||
return &child, nil
|
||||
}
|
||||
@@ -902,6 +915,13 @@ func priorityToInt(p string) int32 {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TaskService) notifyTaskAvailable(task db.AgentTaskQueue) {
|
||||
if s.Wakeup == nil || !task.RuntimeID.Valid {
|
||||
return
|
||||
}
|
||||
s.Wakeup.NotifyTaskAvailable(util.UUIDToString(task.RuntimeID), util.UUIDToString(task.ID))
|
||||
}
|
||||
|
||||
func (s *TaskService) broadcastTaskDispatch(ctx context.Context, task db.AgentTaskQueue) {
|
||||
var payload map[string]any
|
||||
if task.Context != nil {
|
||||
|
||||
@@ -11,10 +11,10 @@ const (
|
||||
EventCommentCreated = "comment:created"
|
||||
EventCommentUpdated = "comment:updated"
|
||||
EventCommentDeleted = "comment:deleted"
|
||||
EventReactionAdded = "reaction:added"
|
||||
EventReactionRemoved = "reaction:removed"
|
||||
EventIssueReactionAdded = "issue_reaction:added"
|
||||
EventIssueReactionRemoved = "issue_reaction:removed"
|
||||
EventReactionAdded = "reaction:added"
|
||||
EventReactionRemoved = "reaction:removed"
|
||||
EventIssueReactionAdded = "issue_reaction:added"
|
||||
EventIssueReactionRemoved = "issue_reaction:removed"
|
||||
|
||||
// Agent events
|
||||
EventAgentStatus = "agent:status"
|
||||
@@ -93,6 +93,7 @@ const (
|
||||
EventAutopilotRunDone = "autopilot:run_done"
|
||||
|
||||
// Daemon events
|
||||
EventDaemonHeartbeat = "daemon:heartbeat"
|
||||
EventDaemonRegister = "daemon:register"
|
||||
EventDaemonHeartbeat = "daemon:heartbeat"
|
||||
EventDaemonRegister = "daemon:register"
|
||||
EventDaemonTaskAvailable = "daemon:task_available"
|
||||
)
|
||||
|
||||
@@ -16,6 +16,13 @@ type TaskDispatchPayload struct {
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// TaskAvailablePayload is the body of a server-to-daemon wakeup. It is only a
// hint that work is waiting: the daemon still claims the task through the
// existing HTTP claim endpoint.
type TaskAvailablePayload struct {
	// RuntimeID is always present in the encoded JSON.
	RuntimeID string `json:"runtime_id"`
	// TaskID is omitted from the encoded JSON when empty.
	TaskID string `json:"task_id,omitempty"`
}
|
||||
|
||||
// TaskProgressPayload is sent from daemon to server during task execution.
|
||||
type TaskProgressPayload struct {
|
||||
TaskID string `json:"task_id"`
|
||||
@@ -37,10 +44,10 @@ type TaskMessagePayload struct {
|
||||
IssueID string `json:"issue_id,omitempty"`
|
||||
Seq int `json:"seq"`
|
||||
Type string `json:"type"` // "text", "tool_use", "tool_result", "error"
|
||||
Tool string `json:"tool,omitempty"` // tool name for tool_use/tool_result
|
||||
Content string `json:"content,omitempty"` // text content
|
||||
Input map[string]any `json:"input,omitempty"` // tool input (tool_use only)
|
||||
Output string `json:"output,omitempty"` // tool output (tool_result only)
|
||||
Tool string `json:"tool,omitempty"` // tool name for tool_use/tool_result
|
||||
Content string `json:"content,omitempty"` // text content
|
||||
Input map[string]any `json:"input,omitempty"` // tool input (tool_use only)
|
||||
Output string `json:"output,omitempty"` // tool output (tool_result only)
|
||||
}
|
||||
|
||||
// DaemonRegisterPayload is sent from daemon to server on connection.
|
||||
|
||||
Reference in New Issue
Block a user