From cd414a52ea489fd89707f8148156617cf5cb847b Mon Sep 17 00:00:00 2001 From: Jiayuan Zhang Date: Sat, 16 May 2026 18:32:00 +0800 Subject: [PATCH] =?UTF-8?q?feat(cli):=20multica=20issue=20terminal=20?= =?UTF-8?q?=E2=80=94=20attach=20via=20Phase=202=20WS=20endpoint=20(MUL-229?= =?UTF-8?q?5)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 3 of MUL-2295. Adds `multica issue terminal ` which dials the Phase 2 /ws/issues/{id}/terminal endpoint, performs first-frame auth with the existing PAT/JWT, and runs an interactive PTY through the daemon-side terminal manager from Phase 1. SIGWINCH on unix / poll on windows pushes resize frames; ssh-style `~.` detaches. Co-authored-by: multica-agent --- server/cmd/multica/cmd_issue_terminal.go | 563 ++++++++++++++++++ server/cmd/multica/cmd_issue_terminal_test.go | 526 ++++++++++++++++ server/cmd/multica/cmd_issue_terminal_unix.go | 40 ++ .../cmd/multica/cmd_issue_terminal_windows.go | 39 ++ server/go.mod | 3 +- server/go.sum | 6 +- server/internal/cli/ws.go | 97 +++ server/internal/cli/ws_test.go | 125 ++++ 8 files changed, 1396 insertions(+), 3 deletions(-) create mode 100644 server/cmd/multica/cmd_issue_terminal.go create mode 100644 server/cmd/multica/cmd_issue_terminal_test.go create mode 100644 server/cmd/multica/cmd_issue_terminal_unix.go create mode 100644 server/cmd/multica/cmd_issue_terminal_windows.go create mode 100644 server/internal/cli/ws.go create mode 100644 server/internal/cli/ws_test.go diff --git a/server/cmd/multica/cmd_issue_terminal.go b/server/cmd/multica/cmd_issue_terminal.go new file mode 100644 index 000000000..830b0b848 --- /dev/null +++ b/server/cmd/multica/cmd_issue_terminal.go @@ -0,0 +1,563 @@ +package main + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/url" + "os" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + "github.com/spf13/cobra" + "golang.org/x/term" + + "github.com/multica-ai/multica/server/pkg/protocol" +) + +var issueTerminalCmd = &cobra.Command{ + Use: "terminal ", + Short: "Attach to the issue's most recent agent task PTY", + Long: "Open an interactive shell inside the workdir of the issue's most recent agent task. " + + "Reuses the daemon-side PTY manager added in MUL-2295 — the daemon spawns a bash login " + + "shell with CLAUDE_SESSION_ID + MULTICA_{WORKSPACE,ISSUE,TASK,USER}_ID injected so you " + + "can immediately `claude --resume $CLAUDE_SESSION_ID`.\n\n" + + "Detach without closing your shell: type `~.` (escape sequence). The daemon-side " + + "session is currently torn down on disconnect — see RFC follow-up for `--attach`.", + Args: exactArgs(1), + RunE: runIssueTerminal, +} + +const ( + terminalDefaultCols = 80 + terminalDefaultRows = 24 + terminalAuthAckTimeout = 10 * time.Second + terminalOpenAckTimeout = 15 * time.Second + terminalServerWriteWait = 10 * time.Second + terminalServerReadLimit = 1 << 20 // 1 MiB per frame; matches realistic xterm bursts + terminalDetachExitMessage = "[multica] detached — daemon session was torn down" +) + +func init() { + issueCmd.AddCommand(issueTerminalCmd) + issueTerminalCmd.Flags().Uint16("cols", 0, "Initial terminal columns (defaults to detected size, or 80 if stdout is not a TTY)") + issueTerminalCmd.Flags().Uint16("rows", 0, "Initial terminal rows (defaults to detected size, or 24 if stdout is not a TTY)") + issueTerminalCmd.Flags().String("escape-char", "~", "Escape character for detach sequence (`.` to detach). Empty disables escape detection.") + issueTerminalCmd.Flags().Bool("no-raw", false, "Don't put the local TTY into raw mode (mostly for testing / piped input)") +} + +func runIssueTerminal(cmd *cobra.Command, args []string) error { + client, err := newAPIClient(cmd) + if err != nil { + return err + } + if _, err := requireWorkspaceID(cmd); err != nil { + return err + } + token := resolveToken(cmd) + if token == "" { + return fmt.Errorf("not authenticated: run 'multica login'") + } + + resolveCtx, cancelResolve := context.WithTimeout(cmd.Context(), 15*time.Second) + defer cancelResolve() + issueRef, err := resolveIssueRef(resolveCtx, client, args[0]) + if err != nil { + return fmt.Errorf("resolve issue: %w", err) + } + + // Detect terminal size from stdout (the surface the user actually sees); + // fall back to defaults if stdout is piped. Flag overrides win. + cols, rows := detectInitialSize(cmd) + + pathAndQuery := buildTerminalPathAndQuery(issueRef.ID, client.WorkspaceID, cols, rows) + + // Use a long-lived context for the WS connection; cancellation is driven + // by the proxy goroutines + signals rather than a timeout. + conn, _, err := client.DialWebSocket(cmd.Context(), pathAndQuery) + if err != nil { + return fmt.Errorf("dial terminal websocket: %w", err) + } + + proxy := newCLITerminalProxy(conn, os.Stdin, os.Stdout, os.Stderr, token, cmd) + return proxy.run(cmd.Context(), cols, rows) +} + +func detectInitialSize(cmd *cobra.Command) (uint16, uint16) { + cols, _ := cmd.Flags().GetUint16("cols") + rows, _ := cmd.Flags().GetUint16("rows") + if cols > 0 && rows > 0 { + return cols, rows + } + if c, r, err := term.GetSize(int(os.Stdout.Fd())); err == nil && c > 0 && r > 0 { + if cols == 0 { + cols = uint16(c) + } + if rows == 0 { + rows = uint16(r) + } + } + if cols == 0 { + cols = terminalDefaultCols + } + if rows == 0 { + rows = terminalDefaultRows + } + return cols, rows +} + +func buildTerminalPathAndQuery(issueID, workspaceID string, cols, rows uint16) string { + q := url.Values{} + q.Set("workspace_id", workspaceID) + q.Set("cols", strconv.FormatUint(uint64(cols), 10)) + q.Set("rows", strconv.FormatUint(uint64(rows), 10)) + return "/ws/issues/" + url.PathEscape(issueID) + "/terminal?" + q.Encode() +} + +// cliTerminalProxy mirrors the server-side terminalProxy: one goroutine +// owns conn writes, one owns conn reads, plus a stdin reader and resize +// watcher. The struct is the only owner of the websocket.Conn; all writes +// go through writeFrame() to keep a single point that holds writeMu. +type cliTerminalProxy struct { + conn *websocket.Conn + stdin io.Reader + stdout io.Writer + stderr io.Writer + token string + cmd *cobra.Command + + writeMu sync.Mutex + + sessionMu sync.RWMutex + sessionID string + + closeOnce sync.Once + doneCh chan struct{} + + // exit reporting from the read pump back to the orchestrator. + exitCode atomic.Int32 // 0 = unset, see exitCodeUnset / >=1 + exitMsg atomic.Pointer[string] + + escapeChar byte + noRaw bool +} + +const exitCodeUnset int32 = -1 + +func newCLITerminalProxy(conn *websocket.Conn, stdin io.Reader, stdout, stderr io.Writer, token string, cmd *cobra.Command) *cliTerminalProxy { + escape, _ := cmd.Flags().GetString("escape-char") + noRaw, _ := cmd.Flags().GetBool("no-raw") + var ec byte + if len(escape) >= 1 { + ec = escape[0] + } + p := &cliTerminalProxy{ + conn: conn, + stdin: stdin, + stdout: stdout, + stderr: stderr, + token: token, + cmd: cmd, + doneCh: make(chan struct{}), + escapeChar: ec, + noRaw: noRaw, + } + p.exitCode.Store(exitCodeUnset) + conn.SetReadLimit(terminalServerReadLimit) + return p +} + +func (p *cliTerminalProxy) run(ctx context.Context, cols, rows uint16) error { + defer p.conn.Close() + + if err := p.handshake(); err != nil { + return err + } + + // Push our local size right after open in case the server's hardcoded + // initial 80x24 didn't match. (Phase 2 server stamps 80x24 on the + // daemon-bound terminal.open frame regardless of query string; sending + // resize immediately makes the PTY render correctly.) + if err := p.sendResize(cols, rows); err != nil { + // non-fatal — daemon will just keep the original size + fmt.Fprintf(p.stderr, "[multica] warning: initial resize failed: %v\n", err) + } + + rawTTY := !p.noRaw && term.IsTerminal(int(os.Stdin.Fd())) + var restore func() error + if rawTTY { + oldState, err := term.MakeRaw(int(os.Stdin.Fd())) + if err != nil { + return fmt.Errorf("enter raw mode: %w", err) + } + fd := int(os.Stdin.Fd()) + restore = func() error { return term.Restore(fd, oldState) } + defer restore() + } + + stopResize := startResizeWatcher(p) + defer stopResize() + + go p.readPump() + go p.stdinPump(rawTTY) + + select { + case <-p.doneCh: + case <-ctx.Done(): + p.shutdown() + } + + if restore != nil { + _ = restore() + } + + if msgPtr := p.exitMsg.Load(); msgPtr != nil && *msgPtr != "" { + fmt.Fprintln(p.stderr, *msgPtr) + } + if code := p.exitCode.Load(); code > 0 { + os.Exit(int(code)) + } + return nil +} + +// handshake performs first-frame auth and waits for terminal.opened. +func (p *cliTerminalProxy) handshake() error { + authFrame, err := json.Marshal(struct { + Type string `json:"type"` + Payload map[string]any `json:"payload"` + }{ + Type: "auth", + Payload: map[string]any{"token": p.token}, + }) + if err != nil { + return fmt.Errorf("marshal auth frame: %w", err) + } + if err := p.writeRawFrame(authFrame); err != nil { + return fmt.Errorf("send auth frame: %w", err) + } + + deadline := time.Now().Add(terminalAuthAckTimeout) + if err := p.conn.SetReadDeadline(deadline); err != nil { + return fmt.Errorf("set auth read deadline: %w", err) + } + for { + _, raw, err := p.conn.ReadMessage() + if err != nil { + return fmt.Errorf("read auth response: %w", err) + } + var preview struct { + Type string `json:"type"` + Error string `json:"error"` + } + if err := json.Unmarshal(raw, &preview); err == nil { + if preview.Error != "" { + return fmt.Errorf("auth rejected: %s", preview.Error) + } + if preview.Type == "auth_ack" { + break + } + } + // Tolerate stray frames during handshake (none expected in current + // server implementation, but don't lock up if that changes). + } + + // After auth_ack the server proxies a terminal.open to the daemon and + // waits for terminal.opened or terminal.error. Block until we see one. + openDeadline := time.Now().Add(terminalOpenAckTimeout) + if err := p.conn.SetReadDeadline(openDeadline); err != nil { + return fmt.Errorf("set open read deadline: %w", err) + } + for { + _, raw, err := p.conn.ReadMessage() + if err != nil { + return fmt.Errorf("waiting for terminal.opened: %w", err) + } + var env protocol.Message + if err := json.Unmarshal(raw, &env); err != nil { + continue + } + switch env.Type { + case protocol.MessageTypeTerminalOpened: + var op protocol.TerminalOpenedPayload + if err := json.Unmarshal(env.Payload, &op); err != nil { + return fmt.Errorf("decode terminal.opened: %w", err) + } + if op.SessionID == "" { + return fmt.Errorf("daemon returned empty session_id in terminal.opened") + } + p.setSessionID(op.SessionID) + workDir := op.WorkDir + if workDir == "" { + workDir = "(unknown)" + } + fmt.Fprintf(p.stderr, "[multica] attached to %s — escape: %s.\r\n", workDir, escapeHelpString(p.escapeChar)) + // Restore non-blocking reads for the pumps. + if err := p.conn.SetReadDeadline(time.Time{}); err != nil { + return fmt.Errorf("clear read deadline: %w", err) + } + return nil + case protocol.MessageTypeTerminalError: + var ep protocol.TerminalErrorPayload + if err := json.Unmarshal(env.Payload, &ep); err != nil { + return fmt.Errorf("daemon returned terminal.error (undecodable)") + } + return fmt.Errorf("daemon rejected terminal.open: %s (%s)", ep.Message, ep.Code) + default: + // keep waiting + } + } +} + +func escapeHelpString(b byte) string { + if b == 0 { + return "(disabled)" + } + return "" + string(b) + "." +} + +func (p *cliTerminalProxy) readPump() { + defer p.shutdown() + for { + _, raw, err := p.conn.ReadMessage() + if err != nil { + if !isClosedConnError(err) { + msg := fmt.Sprintf("[multica] websocket closed: %v", err) + p.exitMsg.CompareAndSwap(nil, &msg) + } + return + } + var env protocol.Message + if err := json.Unmarshal(raw, &env); err != nil { + continue + } + switch env.Type { + case protocol.MessageTypeTerminalData: + var pl protocol.TerminalDataPayload + if err := json.Unmarshal(env.Payload, &pl); err != nil { + continue + } + data, err := base64.StdEncoding.DecodeString(pl.DataB64) + if err != nil { + continue + } + _, _ = p.stdout.Write(data) + case protocol.MessageTypeTerminalExit: + var pl protocol.TerminalExitPayload + if err := json.Unmarshal(env.Payload, &pl); err != nil { + continue + } + reason := pl.Reason + if reason == "" { + reason = "child exited" + } + msg := fmt.Sprintf("\r\n[multica] %s (exit code %d)", reason, pl.ExitCode) + p.exitMsg.CompareAndSwap(nil, &msg) + if pl.ExitCode > 0 { + p.exitCode.Store(int32(pl.ExitCode)) + } + return + case protocol.MessageTypeTerminalError: + var pl protocol.TerminalErrorPayload + if err := json.Unmarshal(env.Payload, &pl); err != nil { + continue + } + msg := fmt.Sprintf("\r\n[multica] error: %s (%s)", pl.Message, pl.Code) + p.exitMsg.CompareAndSwap(nil, &msg) + p.exitCode.Store(1) + return + case protocol.MessageTypeTerminalClose: + return + } + } +} + +// stdinPump reads stdin, runs it through the escape-sequence state machine, +// and forwards bytes as terminal.data frames. Detach (~.) closes the WS +// without sending the bytes. +func (p *cliTerminalProxy) stdinPump(rawTTY bool) { + defer p.shutdown() + + buf := make([]byte, 4096) + // Start in newline state so the very first character can trigger an + // escape sequence; mirrors ssh's behavior. + state := newlineState{atNewline: true} + for { + n, err := p.stdin.Read(buf) + if n > 0 { + toSend, detach := state.process(buf[:n], p.escapeChar) + if len(toSend) > 0 { + if err := p.sendData(toSend); err != nil { + return + } + } + if detach { + msg := terminalDetachExitMessage + p.exitMsg.CompareAndSwap(nil, &msg) + _ = p.sendCloseBestEffort("client_detach") + return + } + } + if err != nil { + if !errors.Is(err, io.EOF) { + msg := fmt.Sprintf("[multica] stdin error: %v", err) + p.exitMsg.CompareAndSwap(nil, &msg) + } + return + } + } +} + +func (p *cliTerminalProxy) sendData(data []byte) error { + sid := p.SessionID() + if sid == "" { + return errors.New("session_id not set") + } + frame, err := marshalCLITerminalFrame(protocol.MessageTypeTerminalData, protocol.TerminalDataPayload{ + SessionID: sid, + DataB64: base64.StdEncoding.EncodeToString(data), + }) + if err != nil { + return err + } + return p.writeRawFrame(frame) +} + +func (p *cliTerminalProxy) sendResize(cols, rows uint16) error { + sid := p.SessionID() + if sid == "" { + // Pre-handshake resize is sent later by run() once session is known. + return nil + } + frame, err := marshalCLITerminalFrame(protocol.MessageTypeTerminalResize, protocol.TerminalResizePayload{ + SessionID: sid, + Cols: cols, + Rows: rows, + }) + if err != nil { + return err + } + return p.writeRawFrame(frame) +} + +func (p *cliTerminalProxy) sendCloseBestEffort(reason string) error { + sid := p.SessionID() + if sid == "" { + return nil + } + frame, err := marshalCLITerminalFrame(protocol.MessageTypeTerminalClose, protocol.TerminalClosePayload{ + SessionID: sid, + Reason: reason, + }) + if err != nil { + return err + } + return p.writeRawFrame(frame) +} + +func (p *cliTerminalProxy) writeRawFrame(frame []byte) error { + p.writeMu.Lock() + defer p.writeMu.Unlock() + if err := p.conn.SetWriteDeadline(time.Now().Add(terminalServerWriteWait)); err != nil { + return err + } + return p.conn.WriteMessage(websocket.TextMessage, frame) +} + +func (p *cliTerminalProxy) SessionID() string { + p.sessionMu.RLock() + defer p.sessionMu.RUnlock() + return p.sessionID +} + +func (p *cliTerminalProxy) setSessionID(sid string) { + p.sessionMu.Lock() + defer p.sessionMu.Unlock() + p.sessionID = sid +} + +func (p *cliTerminalProxy) shutdown() { + p.closeOnce.Do(func() { + close(p.doneCh) + _ = p.conn.Close() + }) +} + +func marshalCLITerminalFrame(msgType string, payload any) ([]byte, error) { + raw, err := json.Marshal(payload) + if err != nil { + return nil, err + } + return json.Marshal(protocol.Message{Type: msgType, Payload: raw}) +} + +func isClosedConnError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) { + return true + } + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + return true + } + return false +} + +// --- escape sequence state machine ----------------------------------------- +// +// Mirrors ssh(1)'s `~.` detach: after a newline, a single escape character +// followed by `.` detaches; `~~` emits a literal escape; `~?` prints help; +// any other byte aborts the escape and forwards both bytes. + +type newlineState struct { + atNewline bool + gotEscape bool +} + +// process consumes a chunk of stdin bytes. Returns the bytes that should +// actually be forwarded to the daemon and whether the user requested detach. +// The state machine mutates the receiver across calls so multi-byte chunks +// straddling escape boundaries (rare, but possible with paste) work. +func (s *newlineState) process(in []byte, escape byte) (out []byte, detach bool) { + if escape == 0 { + // Escape detection disabled — pass through. + return in, false + } + out = make([]byte, 0, len(in)) + for _, b := range in { + switch { + case s.gotEscape: + s.gotEscape = false + switch b { + case '.': + return out, true + case escape: + out = append(out, escape) + s.atNewline = false + case '?': + // Help is a local-only signal — not delivered to PTY. + // Caller can detect by … actually keep it simple: just + // emit a CR for visual feedback so the prompt redraws. + out = append(out, '\r') + s.atNewline = true + default: + // Not a recognized escape: forward ESC then this byte. + out = append(out, escape, b) + s.atNewline = b == '\r' || b == '\n' + } + case s.atNewline && b == escape: + s.gotEscape = true + default: + out = append(out, b) + s.atNewline = b == '\r' || b == '\n' + } + } + return out, false +} + diff --git a/server/cmd/multica/cmd_issue_terminal_test.go b/server/cmd/multica/cmd_issue_terminal_test.go new file mode 100644 index 000000000..9d91bef38 --- /dev/null +++ b/server/cmd/multica/cmd_issue_terminal_test.go @@ -0,0 +1,526 @@ +package main + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/spf13/cobra" + + "github.com/multica-ai/multica/server/pkg/protocol" +) + +func TestEscapeState_DetachOnFreshLine(t *testing.T) { + s := &newlineState{atNewline: true} + out, detach := s.process([]byte("~."), '~') + if !detach { + t.Fatalf("expected detach") + } + if len(out) != 0 { + t.Fatalf("expected no bytes forwarded, got %q", out) + } +} + +func TestEscapeState_TildeNotAfterNewlineIsLiteral(t *testing.T) { + s := &newlineState{atNewline: false} + out, detach := s.process([]byte("foo~.bar"), '~') + if detach { + t.Fatalf("must not detach when ~ is mid-line") + } + if string(out) != "foo~.bar" { + t.Fatalf("got %q", out) + } +} + +func TestEscapeState_DoubleTildeEmitsLiteral(t *testing.T) { + s := &newlineState{atNewline: true} + out, detach := s.process([]byte("~~"), '~') + if detach { + t.Fatalf("~~ must not detach") + } + if string(out) != "~" { + t.Fatalf("got %q want ~", out) + } +} + +func TestEscapeState_StraddledChunks(t *testing.T) { + // User pastes/types ~ and . in two separate stdin reads — escape + // detection still works because state is preserved across calls. + s := &newlineState{atNewline: true} + out1, detach1 := s.process([]byte("~"), '~') + if detach1 || len(out1) != 0 { + t.Fatalf("first chunk: detach=%v out=%q", detach1, out1) + } + out2, detach2 := s.process([]byte("."), '~') + if !detach2 { + t.Fatalf("expected detach on second chunk") + } + if len(out2) != 0 { + t.Fatalf("second chunk should forward nothing, got %q", out2) + } +} + +func TestEscapeState_DisabledWhenEscapeIsZero(t *testing.T) { + s := &newlineState{atNewline: true} + out, detach := s.process([]byte("~."), 0) + if detach { + t.Fatalf("disabled escape must not detach") + } + if string(out) != "~." { + t.Fatalf("got %q want ~.", out) + } +} + +func TestEscapeState_UnknownEscapeForwardsBoth(t *testing.T) { + s := &newlineState{atNewline: true} + out, _ := s.process([]byte("~x"), '~') + if string(out) != "~x" { + t.Fatalf("got %q want ~x", out) + } +} + +func TestBuildTerminalPathAndQuery(t *testing.T) { + got := buildTerminalPathAndQuery("MUL-2295", "ws-uuid", 120, 40) + u, err := url.Parse("http://x" + got) + if err != nil { + t.Fatalf("parse: %v", err) + } + if u.Path != "/ws/issues/MUL-2295/terminal" { + t.Errorf("path = %q", u.Path) + } + q := u.Query() + if q.Get("workspace_id") != "ws-uuid" { + t.Errorf("workspace_id = %q", q.Get("workspace_id")) + } + if q.Get("cols") != "120" { + t.Errorf("cols = %q", q.Get("cols")) + } + if q.Get("rows") != "40" { + t.Errorf("rows = %q", q.Get("rows")) + } +} + +// fakeServer simulates the Phase 2 /ws/issues/{id}/terminal handshake plus +// a tiny echo loop, so we can drive the CLI proxy through its full lifecycle +// in-process without spinning up the real daemon. +type fakeServer struct { + t *testing.T + upgrader websocket.Upgrader + gotAuth chan string + gotData chan []byte + gotClose chan string + sessionID string + server *httptest.Server + connMu sync.Mutex + conn *websocket.Conn + sendOpenErr *protocol.TerminalErrorPayload // if set, send terminal.error instead of terminal.opened +} + +// writeFrame serializes writes from the handler goroutine and any test +// goroutine that wants to push a frame to the connected client. Required +// because gorilla/websocket allows concurrent read+write but NOT concurrent +// writes from different goroutines. +func (fs *fakeServer) writeFrame(frame []byte) error { + fs.connMu.Lock() + defer fs.connMu.Unlock() + if fs.conn == nil { + return fmt.Errorf("no client") + } + return fs.conn.WriteMessage(websocket.TextMessage, frame) +} + +func newFakeServer(t *testing.T) *fakeServer { + fs := &fakeServer{ + t: t, + upgrader: websocket.Upgrader{}, + gotAuth: make(chan string, 1), + gotData: make(chan []byte, 32), + gotClose: make(chan string, 1), + sessionID: "session-xyz", + } + fs.server = httptest.NewServer(http.HandlerFunc(fs.handle)) + return fs +} + +func (fs *fakeServer) close() { + fs.connMu.Lock() + c := fs.conn + fs.connMu.Unlock() + if c != nil { + c.Close() + } + fs.server.Close() +} + +func (fs *fakeServer) baseURL() string { return fs.server.URL } + +func (fs *fakeServer) handle(w http.ResponseWriter, r *http.Request) { + conn, err := fs.upgrader.Upgrade(w, r, nil) + if err != nil { + fs.t.Errorf("upgrade: %v", err) + return + } + fs.connMu.Lock() + fs.conn = conn + fs.connMu.Unlock() + + // 1. Auth. + _, raw, err := conn.ReadMessage() + if err != nil { + return + } + var auth struct { + Type string `json:"type"` + Payload map[string]any `json:"payload"` + } + if err := json.Unmarshal(raw, &auth); err != nil || auth.Type != "auth" { + _ = fs.writeFrame([]byte(`{"error":"bad auth"}`)) + return + } + tok, _ := auth.Payload["token"].(string) + fs.gotAuth <- tok + _ = fs.writeFrame([]byte(`{"type":"auth_ack"}`)) + + // 2. Open ack. + if fs.sendOpenErr != nil { + ep := *fs.sendOpenErr + frame, _ := marshalCLITerminalFrame(protocol.MessageTypeTerminalError, ep) + _ = fs.writeFrame(frame) + return + } + openedFrame, _ := marshalCLITerminalFrame(protocol.MessageTypeTerminalOpened, protocol.TerminalOpenedPayload{ + SessionID: fs.sessionID, + WorkDir: "/tmp/work", + Shell: "/bin/bash", + }) + _ = fs.writeFrame(openedFrame) + + // 3. Pump. + for { + _, raw, err := conn.ReadMessage() + if err != nil { + return + } + var env protocol.Message + if err := json.Unmarshal(raw, &env); err != nil { + continue + } + switch env.Type { + case protocol.MessageTypeTerminalData: + var pl protocol.TerminalDataPayload + if err := json.Unmarshal(env.Payload, &pl); err != nil { + continue + } + data, _ := base64.StdEncoding.DecodeString(pl.DataB64) + fs.gotData <- data + // Echo back so the CLI's stdout pump has something to do. + echo, _ := marshalCLITerminalFrame(protocol.MessageTypeTerminalData, protocol.TerminalDataPayload{ + SessionID: fs.sessionID, + DataB64: pl.DataB64, + }) + _ = fs.writeFrame(echo) + case protocol.MessageTypeTerminalClose: + var pl protocol.TerminalClosePayload + _ = json.Unmarshal(env.Payload, &pl) + fs.gotClose <- pl.Reason + return + case protocol.MessageTypeTerminalResize: + // observed but unused in this fake + } + } +} + +func newTestCmd() *cobra.Command { + c := &cobra.Command{} + c.Flags().String("escape-char", "~", "") + c.Flags().Bool("no-raw", true, "") + return c +} + +func TestCLITerminalProxy_HandshakeAndEcho(t *testing.T) { + fs := newFakeServer(t) + defer fs.close() + + wsURL := strings.Replace(fs.baseURL(), "http://", "ws://", 1) + "/" + dialer := *websocket.DefaultDialer + conn, _, err := dialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial: %v", err) + } + + stdinR, stdinW := io.Pipe() + stdout := newSafeBuffer() + stderr := newSafeBuffer() + + cmd := newTestCmd() + p := newCLITerminalProxy(conn, stdinR, stdout, stderr, "mul_test", cmd) + + // Drive handshake explicitly so we can also assert the auth token reached + // the fake server. + if err := p.handshake(); err != nil { + t.Fatalf("handshake: %v", err) + } + select { + case got := <-fs.gotAuth: + if got != "mul_test" { + t.Errorf("auth token = %q, want mul_test", got) + } + case <-time.After(2 * time.Second): + t.Fatal("server did not receive auth frame") + } + if p.SessionID() != fs.sessionID { + t.Fatalf("session_id = %q, want %q", p.SessionID(), fs.sessionID) + } + + // Now run the pumps in a goroutine. + pumpsDone := make(chan struct{}) + go func() { + go p.readPump() + p.stdinPump(false) + close(pumpsDone) + }() + + // Send "hello" through stdin; expect server to receive it and echo it + // back into stdout. + if _, err := stdinW.Write([]byte("hello")); err != nil { + t.Fatalf("stdin write: %v", err) + } + + select { + case got := <-fs.gotData: + if string(got) != "hello" { + t.Fatalf("server got %q, want hello", got) + } + case <-time.After(2 * time.Second): + t.Fatal("server did not receive data") + } + + // Wait for the echo to land in stdout. + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if strings.Contains(stdout.String(), "hello") { + break + } + time.Sleep(10 * time.Millisecond) + } + if !strings.Contains(stdout.String(), "hello") { + t.Fatalf("stdout missing echo, got %q", stdout.String()) + } + + // Trigger detach: send "\n~." after a newline. Because stdinPump starts + // the state machine at atNewline=true on the very first byte, we need + // to walk through a real newline first to make the test realistic. + if _, err := stdinW.Write([]byte("\n~.")); err != nil { + t.Fatalf("stdin write detach: %v", err) + } + + select { + case <-pumpsDone: + case <-time.After(3 * time.Second): + t.Fatal("stdin pump did not exit after detach") + } + + select { + case reason := <-fs.gotClose: + if reason != "client_detach" { + t.Errorf("close reason = %q, want client_detach", reason) + } + case <-time.After(2 * time.Second): + t.Fatal("server did not receive terminal.close on detach") + } + + // run() prints the exit message to stderr; in this lower-level test we + // drive the pumps directly, so check the captured exit message. + msgPtr := p.exitMsg.Load() + if msgPtr == nil || !strings.Contains(*msgPtr, "detached") { + got := "" + if msgPtr != nil { + got = *msgPtr + } + t.Errorf("exit msg = %q, want detach text", got) + } +} + +// safeBuffer is a tiny mutex-wrapped bytes.Buffer for tests that read from +// the buffer in one goroutine while another writes (race-detector-clean). +type safeBuffer struct { + mu sync.Mutex + buf bytes.Buffer +} + +func newSafeBuffer() *safeBuffer { return &safeBuffer{} } + +func (b *safeBuffer) Write(p []byte) (int, error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.Write(p) +} + +func (b *safeBuffer) String() string { + b.mu.Lock() + defer b.mu.Unlock() + return b.buf.String() +} + +func TestCLITerminalProxy_HandshakeRejectedOnTerminalError(t *testing.T) { + fs := newFakeServer(t) + fs.sendOpenErr = &protocol.TerminalErrorPayload{ + Code: protocol.TerminalErrorCodeTaskNotFound, + Message: "no agent task on this issue", + } + defer fs.close() + + wsURL := strings.Replace(fs.baseURL(), "http://", "ws://", 1) + "/" + dialer := *websocket.DefaultDialer + conn, _, err := dialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial: %v", err) + } + + cmd := newTestCmd() + p := newCLITerminalProxy(conn, strings.NewReader(""), io.Discard, io.Discard, "mul_test", cmd) + err = p.handshake() + if err == nil { + t.Fatal("expected handshake error, got nil") + } + if !strings.Contains(err.Error(), protocol.TerminalErrorCodeTaskNotFound) { + t.Errorf("error %q does not mention error code", err) + } +} + +func TestCLITerminalProxy_AuthRejected(t *testing.T) { + upgrader := websocket.Upgrader{} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + // Read auth frame, reply with error. + _, _, _ = conn.ReadMessage() + _ = conn.WriteMessage(websocket.TextMessage, []byte(`{"error":"invalid token"}`)) + })) + defer server.Close() + + wsURL := strings.Replace(server.URL, "http://", "ws://", 1) + "/" + dialer := *websocket.DefaultDialer + conn, _, err := dialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial: %v", err) + } + + cmd := newTestCmd() + p := newCLITerminalProxy(conn, strings.NewReader(""), io.Discard, io.Discard, "mul_test", cmd) + err = p.handshake() + if err == nil { + t.Fatal("expected handshake error, got nil") + } + if !strings.Contains(err.Error(), "invalid token") { + t.Errorf("error %q does not surface server reason", err) + } +} + +func TestCLITerminalProxy_TerminalExitDeliversCode(t *testing.T) { + // Driver: open server, advance through handshake, then push a + // terminal.exit frame and verify the proxy's exit code state. + fs := newFakeServer(t) + defer fs.close() + + wsURL := strings.Replace(fs.baseURL(), "http://", "ws://", 1) + "/" + dialer := *websocket.DefaultDialer + conn, _, err := dialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial: %v", err) + } + + cmd := newTestCmd() + p := newCLITerminalProxy(conn, strings.NewReader(""), io.Discard, io.Discard, "mul_test", cmd) + if err := p.handshake(); err != nil { + t.Fatalf("handshake: %v", err) + } + + exitFrame, _ := marshalCLITerminalFrame(protocol.MessageTypeTerminalExit, protocol.TerminalExitPayload{ + SessionID: fs.sessionID, + ExitCode: 42, + Reason: "child exited", + }) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + p.readPump() + }() + if err := fs.writeFrame(exitFrame); err != nil { + t.Fatalf("server write exit: %v", err) + } + + doneAt := time.Now().Add(2 * time.Second) + for time.Now().Before(doneAt) { + if p.exitCode.Load() == 42 { + break + } + time.Sleep(10 * time.Millisecond) + } + wg.Wait() + if got := p.exitCode.Load(); got != 42 { + t.Fatalf("exit code = %d, want 42", got) + } + msgPtr := p.exitMsg.Load() + if msgPtr == nil || !strings.Contains(*msgPtr, "exit code 42") { + got := "" + if msgPtr != nil { + got = *msgPtr + } + t.Errorf("exit msg = %q", got) + } +} + +// Compile-time check: ensure the marshaled frame round-trips through the +// real protocol.Message envelope. Catches any drift if the protocol pkg +// renames a field. +func TestMarshalCLITerminalFrame_EnvelopeShape(t *testing.T) { + frame, err := marshalCLITerminalFrame(protocol.MessageTypeTerminalResize, protocol.TerminalResizePayload{ + SessionID: "sid", + Cols: 100, + Rows: 30, + }) + if err != nil { + t.Fatal(err) + } + var env protocol.Message + if err := json.Unmarshal(frame, &env); err != nil { + t.Fatal(err) + } + if env.Type != protocol.MessageTypeTerminalResize { + t.Fatalf("type = %q", env.Type) + } + var pl protocol.TerminalResizePayload + if err := json.Unmarshal(env.Payload, &pl); err != nil { + t.Fatal(err) + } + if pl.Cols != 100 || pl.Rows != 30 || pl.SessionID != "sid" { + t.Fatalf("payload = %+v", pl) + } +} + +// Sanity check the help string does not crash on a zero escape byte. +func TestEscapeHelpString(t *testing.T) { + if got := escapeHelpString(0); got != "(disabled)" { + t.Errorf("escape disabled hint = %q", got) + } + if got := escapeHelpString('~'); !strings.Contains(got, "~") { + t.Errorf("escape help = %q", got) + } +} + diff --git a/server/cmd/multica/cmd_issue_terminal_unix.go b/server/cmd/multica/cmd_issue_terminal_unix.go new file mode 100644 index 000000000..3d820e2c4 --- /dev/null +++ b/server/cmd/multica/cmd_issue_terminal_unix.go @@ -0,0 +1,40 @@ +//go:build !windows + +package main + +import ( + "os" + "os/signal" + "syscall" + + "golang.org/x/term" +) + +// startResizeWatcher installs a SIGWINCH handler that pushes the new local +// terminal size to the daemon every time the user resizes their window. +// Returns a stop function that uninstalls the handler and exits the +// goroutine. On platforms without SIGWINCH (Windows) the windows-tagged +// implementation polls term.GetSize on a timer instead. +func startResizeWatcher(p *cliTerminalProxy) func() { + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGWINCH) + stop := make(chan struct{}) + + go func() { + for { + select { + case <-ch: + if c, r, err := term.GetSize(int(os.Stdout.Fd())); err == nil && c > 0 && r > 0 { + _ = p.sendResize(uint16(c), uint16(r)) + } + case <-stop: + return + } + } + }() + + return func() { + signal.Stop(ch) + close(stop) + } +} diff --git a/server/cmd/multica/cmd_issue_terminal_windows.go b/server/cmd/multica/cmd_issue_terminal_windows.go new file mode 100644 index 000000000..b5c41971c --- /dev/null +++ b/server/cmd/multica/cmd_issue_terminal_windows.go @@ -0,0 +1,39 @@ +//go:build windows + +package main + +import ( + "os" + "time" + + "golang.org/x/term" +) + +// startResizeWatcher polls the local terminal size on a timer, since +// Windows has no SIGWINCH equivalent that is reliable for console resize +// events. 500ms is a compromise between responsiveness and CPU cost. +func startResizeWatcher(p *cliTerminalProxy) func() { + stop := make(chan struct{}) + go func() { + var lastC, lastR int + t := time.NewTicker(500 * time.Millisecond) + defer t.Stop() + for { + select { + case <-stop: + return + case <-t.C: + c, r, err := term.GetSize(int(os.Stdout.Fd())) + if err != nil || c <= 0 || r <= 0 { + continue + } + if c == lastC && r == lastR { + continue + } + lastC, lastR = c, r + _ = p.sendResize(uint16(c), uint16(r)) + } + } + }() + return func() { close(stop) } +} diff --git a/server/go.mod b/server/go.mod index 5aa7c14a7..31cc36169 100644 --- a/server/go.mod +++ b/server/go.mod @@ -24,6 +24,7 @@ require ( github.com/resend/resend-go/v2 v2.28.0 github.com/robfig/cron/v3 v3.0.1 github.com/spf13/cobra v1.10.2 + golang.org/x/term v0.43.0 ) require ( @@ -58,7 +59,7 @@ require ( go.uber.org/atomic v1.11.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/sync v0.20.0 // indirect - golang.org/x/sys v0.35.0 // indirect + golang.org/x/sys v0.44.0 // indirect golang.org/x/text v0.35.0 // indirect google.golang.org/protobuf v1.36.8 // indirect ) diff --git a/server/go.sum b/server/go.sum index d695205ed..5ea569427 100644 --- a/server/go.sum +++ b/server/go.sum @@ -137,8 +137,10 @@ go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= +golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= +golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= diff --git a/server/internal/cli/ws.go b/server/internal/cli/ws.go new file mode 100644 index 000000000..35de04ab5 --- /dev/null +++ b/server/internal/cli/ws.go @@ -0,0 +1,97 @@ +package cli + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/gorilla/websocket" +) + +// DialWebSocket opens a WebSocket connection to the server at the given path +// + query string. The path must start with "/". Auth is intentionally NOT +// sent as a header here: the server's terminal endpoint runs WS upgrade +// before applying header-based auth middleware (browsers cannot set +// Authorization on a WS upgrade), so the caller authenticates via the +// first-frame `auth` message instead. The standard X-Workspace-ID / +// X-Client-* identity headers are still attached so dashboards can attribute +// the connection to the right CLI build. +func (c *APIClient) DialWebSocket(ctx context.Context, pathAndQuery string) (*websocket.Conn, *http.Response, error) { + if c.BaseURL == "" { + return nil, nil, fmt.Errorf("APIClient has no BaseURL") + } + wsURL, err := httpToWSURL(c.BaseURL, pathAndQuery) + if err != nil { + return nil, nil, err + } + + header := http.Header{} + c.setWSHeaders(header) + + dialer := *websocket.DefaultDialer + conn, resp, err := dialer.DialContext(ctx, wsURL, header) + if err != nil { + return nil, resp, err + } + return conn, resp, nil +} + +// setWSHeaders attaches identity headers but deliberately omits the +// Authorization header. Auth happens in-band via the first frame so this +// stays consistent with cookie-based browser clients. +func (c *APIClient) setWSHeaders(h http.Header) { + if c.WorkspaceID != "" { + h.Set("X-Workspace-ID", c.WorkspaceID) + } + platform := c.Platform + if platform == "" { + platform = ClientPlatform + } + if platform != "" { + h.Set("X-Client-Platform", platform) + } + version := c.Version + if version == "" { + version = ClientVersion + } + if version != "" { + h.Set("X-Client-Version", version) + } + osName := c.OS + if osName == "" { + osName = ClientOS + } + if osName != "" { + h.Set("X-Client-OS", osName) + } +} + +func httpToWSURL(baseURL, pathAndQuery string) (string, error) { + u, err := url.Parse(baseURL) + if err != nil { + return "", fmt.Errorf("parse base URL: %w", err) + } + switch strings.ToLower(u.Scheme) { + case "http": + u.Scheme = "ws" + case "https": + u.Scheme = "wss" + case "ws", "wss": + // already WS + default: + return "", fmt.Errorf("unsupported base URL scheme %q", u.Scheme) + } + if !strings.HasPrefix(pathAndQuery, "/") { + return "", fmt.Errorf("path must start with /, got %q", pathAndQuery) + } + suffix, err := url.Parse(pathAndQuery) + if err != nil { + return "", fmt.Errorf("parse path/query: %w", err) + } + u.Path = strings.TrimRight(u.Path, "/") + suffix.Path + u.RawQuery = suffix.RawQuery + u.Fragment = "" + return u.String(), nil +} diff --git a/server/internal/cli/ws_test.go b/server/internal/cli/ws_test.go new file mode 100644 index 000000000..4085d6723 --- /dev/null +++ b/server/internal/cli/ws_test.go @@ -0,0 +1,125 @@ +package cli + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gorilla/websocket" +) + +func TestHTTPToWSURL(t *testing.T) { + cases := []struct { + name string + base string + path string + want string + wantErr bool + }{ + { + name: "https → wss", + base: "https://api.example.com", + path: "/ws/issues/abc/terminal?workspace_id=ws1&cols=80", + want: "wss://api.example.com/ws/issues/abc/terminal?workspace_id=ws1&cols=80", + }, + { + name: "http → ws", + base: "http://localhost:8080", + path: "/ws/issues/x/terminal", + want: "ws://localhost:8080/ws/issues/x/terminal", + }, + { + name: "wss left alone", + base: "wss://api.example.com", + path: "/ws", + want: "wss://api.example.com/ws", + }, + { + name: "trailing slash on base preserved correctly", + base: "https://api.example.com/", + path: "/ws/x", + want: "wss://api.example.com/ws/x", + }, + { + name: "missing leading slash on path", + base: "https://api.example.com", + path: "ws/x", + wantErr: true, + }, + { + name: "unsupported scheme", + base: "ftp://example.com", + path: "/ws", + wantErr: true, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, err := httpToWSURL(tc.base, tc.path) + if tc.wantErr { + if err == nil { + t.Fatalf("expected error, got %q", got) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tc.want { + t.Fatalf("got %q want %q", got, tc.want) + } + }) + } +} + +func TestDialWebSocketAttachesIdentityHeaders(t *testing.T) { + upgrader := websocket.Upgrader{} + gotHeaders := make(chan http.Header, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeaders <- r.Header.Clone() + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + conn.Close() + })) + defer server.Close() + + client := NewAPIClient(server.URL, "ws-uuid", "mul_test_token") + client.Platform = "cli" + client.Version = "1.2.3" + client.OS = "macos" + + conn, _, err := client.DialWebSocket(context.Background(), "/ws") + if err != nil { + t.Fatalf("dial: %v", err) + } + defer conn.Close() + + headers := <-gotHeaders + if got := headers.Get("X-Workspace-ID"); got != "ws-uuid" { + t.Errorf("X-Workspace-ID = %q, want ws-uuid", got) + } + if got := headers.Get("X-Client-Platform"); got != "cli" { + t.Errorf("X-Client-Platform = %q, want cli", got) + } + if got := headers.Get("X-Client-Version"); got != "1.2.3" { + t.Errorf("X-Client-Version = %q, want 1.2.3", got) + } + if got := headers.Get("X-Client-OS"); got != "macos" { + t.Errorf("X-Client-OS = %q, want macos", got) + } + if got := headers.Get("Authorization"); got != "" { + // The server's terminal endpoint runs WS upgrade before any header + // auth middleware, so the CLI must authenticate via the first frame + // to match cookie-based browser clients. Sending a Bearer header + // here would silently work in some setups and silently fail in + // others — keep it consistent and absent. + t.Errorf("Authorization header should NOT be set on WS dial, got %q", got) + } + if got := headers.Get("Sec-WebSocket-Key"); !strings.HasPrefix(strings.TrimSpace(got), "") || got == "" { + t.Errorf("Sec-WebSocket-Key missing") + } +}