feat(cli): multica issue terminal — attach via Phase 2 WS endpoint (MUL-2295)

Phase 3 of MUL-2295. Adds `multica issue terminal <issue-id>` which dials
the Phase 2 /ws/issues/{id}/terminal endpoint, performs first-frame auth
with the existing PAT/JWT, and runs an interactive PTY through the
daemon-side terminal manager from Phase 1. SIGWINCH on unix /
poll on windows pushes resize frames; ssh-style `<enter>~.` detaches.

Co-authored-by: multica-agent <github@multica.ai>
This commit is contained in:
Jiayuan Zhang
2026-05-16 18:32:00 +08:00
parent f675f03fbb
commit cd414a52ea
8 changed files with 1396 additions and 3 deletions

View File

@@ -0,0 +1,563 @@
package main
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/url"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
"github.com/spf13/cobra"
"golang.org/x/term"
"github.com/multica-ai/multica/server/pkg/protocol"
)
var issueTerminalCmd = &cobra.Command{
Use: "terminal <issue-id>",
Short: "Attach to the issue's most recent agent task PTY",
Long: "Open an interactive shell inside the workdir of the issue's most recent agent task. " +
"Reuses the daemon-side PTY manager added in MUL-2295 — the daemon spawns a bash login " +
"shell with CLAUDE_SESSION_ID + MULTICA_{WORKSPACE,ISSUE,TASK,USER}_ID injected so you " +
"can immediately `claude --resume $CLAUDE_SESSION_ID`.\n\n" +
"Detach without closing your shell: type `<enter>~.` (escape sequence). The daemon-side " +
"session is currently torn down on disconnect — see RFC follow-up for `--attach`.",
Args: exactArgs(1),
RunE: runIssueTerminal,
}
const (
terminalDefaultCols = 80
terminalDefaultRows = 24
terminalAuthAckTimeout = 10 * time.Second
terminalOpenAckTimeout = 15 * time.Second
terminalServerWriteWait = 10 * time.Second
terminalServerReadLimit = 1 << 20 // 1 MiB per frame; matches realistic xterm bursts
terminalDetachExitMessage = "[multica] detached — daemon session was torn down"
)
func init() {
issueCmd.AddCommand(issueTerminalCmd)
issueTerminalCmd.Flags().Uint16("cols", 0, "Initial terminal columns (defaults to detected size, or 80 if stdout is not a TTY)")
issueTerminalCmd.Flags().Uint16("rows", 0, "Initial terminal rows (defaults to detected size, or 24 if stdout is not a TTY)")
issueTerminalCmd.Flags().String("escape-char", "~", "Escape character for detach sequence (`<enter><esc>.` to detach). Empty disables escape detection.")
issueTerminalCmd.Flags().Bool("no-raw", false, "Don't put the local TTY into raw mode (mostly for testing / piped input)")
}
func runIssueTerminal(cmd *cobra.Command, args []string) error {
client, err := newAPIClient(cmd)
if err != nil {
return err
}
if _, err := requireWorkspaceID(cmd); err != nil {
return err
}
token := resolveToken(cmd)
if token == "" {
return fmt.Errorf("not authenticated: run 'multica login'")
}
resolveCtx, cancelResolve := context.WithTimeout(cmd.Context(), 15*time.Second)
defer cancelResolve()
issueRef, err := resolveIssueRef(resolveCtx, client, args[0])
if err != nil {
return fmt.Errorf("resolve issue: %w", err)
}
// Detect terminal size from stdout (the surface the user actually sees);
// fall back to defaults if stdout is piped. Flag overrides win.
cols, rows := detectInitialSize(cmd)
pathAndQuery := buildTerminalPathAndQuery(issueRef.ID, client.WorkspaceID, cols, rows)
// Use a long-lived context for the WS connection; cancellation is driven
// by the proxy goroutines + signals rather than a timeout.
conn, _, err := client.DialWebSocket(cmd.Context(), pathAndQuery)
if err != nil {
return fmt.Errorf("dial terminal websocket: %w", err)
}
proxy := newCLITerminalProxy(conn, os.Stdin, os.Stdout, os.Stderr, token, cmd)
return proxy.run(cmd.Context(), cols, rows)
}
func detectInitialSize(cmd *cobra.Command) (uint16, uint16) {
cols, _ := cmd.Flags().GetUint16("cols")
rows, _ := cmd.Flags().GetUint16("rows")
if cols > 0 && rows > 0 {
return cols, rows
}
if c, r, err := term.GetSize(int(os.Stdout.Fd())); err == nil && c > 0 && r > 0 {
if cols == 0 {
cols = uint16(c)
}
if rows == 0 {
rows = uint16(r)
}
}
if cols == 0 {
cols = terminalDefaultCols
}
if rows == 0 {
rows = terminalDefaultRows
}
return cols, rows
}
func buildTerminalPathAndQuery(issueID, workspaceID string, cols, rows uint16) string {
q := url.Values{}
q.Set("workspace_id", workspaceID)
q.Set("cols", strconv.FormatUint(uint64(cols), 10))
q.Set("rows", strconv.FormatUint(uint64(rows), 10))
return "/ws/issues/" + url.PathEscape(issueID) + "/terminal?" + q.Encode()
}
// cliTerminalProxy mirrors the server-side terminalProxy: one goroutine
// owns conn writes, one owns conn reads, plus a stdin reader and resize
// watcher. The struct is the only owner of the websocket.Conn; all writes
// go through writeFrame() to keep a single point that holds writeMu.
type cliTerminalProxy struct {
conn *websocket.Conn
stdin io.Reader
stdout io.Writer
stderr io.Writer
token string
cmd *cobra.Command
writeMu sync.Mutex
sessionMu sync.RWMutex
sessionID string
closeOnce sync.Once
doneCh chan struct{}
// exit reporting from the read pump back to the orchestrator.
exitCode atomic.Int32 // 0 = unset, see exitCodeUnset / >=1
exitMsg atomic.Pointer[string]
escapeChar byte
noRaw bool
}
const exitCodeUnset int32 = -1
func newCLITerminalProxy(conn *websocket.Conn, stdin io.Reader, stdout, stderr io.Writer, token string, cmd *cobra.Command) *cliTerminalProxy {
escape, _ := cmd.Flags().GetString("escape-char")
noRaw, _ := cmd.Flags().GetBool("no-raw")
var ec byte
if len(escape) >= 1 {
ec = escape[0]
}
p := &cliTerminalProxy{
conn: conn,
stdin: stdin,
stdout: stdout,
stderr: stderr,
token: token,
cmd: cmd,
doneCh: make(chan struct{}),
escapeChar: ec,
noRaw: noRaw,
}
p.exitCode.Store(exitCodeUnset)
conn.SetReadLimit(terminalServerReadLimit)
return p
}
func (p *cliTerminalProxy) run(ctx context.Context, cols, rows uint16) error {
defer p.conn.Close()
if err := p.handshake(); err != nil {
return err
}
// Push our local size right after open in case the server's hardcoded
// initial 80x24 didn't match. (Phase 2 server stamps 80x24 on the
// daemon-bound terminal.open frame regardless of query string; sending
// resize immediately makes the PTY render correctly.)
if err := p.sendResize(cols, rows); err != nil {
// non-fatal — daemon will just keep the original size
fmt.Fprintf(p.stderr, "[multica] warning: initial resize failed: %v\n", err)
}
rawTTY := !p.noRaw && term.IsTerminal(int(os.Stdin.Fd()))
var restore func() error
if rawTTY {
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
if err != nil {
return fmt.Errorf("enter raw mode: %w", err)
}
fd := int(os.Stdin.Fd())
restore = func() error { return term.Restore(fd, oldState) }
defer restore()
}
stopResize := startResizeWatcher(p)
defer stopResize()
go p.readPump()
go p.stdinPump(rawTTY)
select {
case <-p.doneCh:
case <-ctx.Done():
p.shutdown()
}
if restore != nil {
_ = restore()
}
if msgPtr := p.exitMsg.Load(); msgPtr != nil && *msgPtr != "" {
fmt.Fprintln(p.stderr, *msgPtr)
}
if code := p.exitCode.Load(); code > 0 {
os.Exit(int(code))
}
return nil
}
// handshake performs first-frame auth and waits for terminal.opened.
func (p *cliTerminalProxy) handshake() error {
authFrame, err := json.Marshal(struct {
Type string `json:"type"`
Payload map[string]any `json:"payload"`
}{
Type: "auth",
Payload: map[string]any{"token": p.token},
})
if err != nil {
return fmt.Errorf("marshal auth frame: %w", err)
}
if err := p.writeRawFrame(authFrame); err != nil {
return fmt.Errorf("send auth frame: %w", err)
}
deadline := time.Now().Add(terminalAuthAckTimeout)
if err := p.conn.SetReadDeadline(deadline); err != nil {
return fmt.Errorf("set auth read deadline: %w", err)
}
for {
_, raw, err := p.conn.ReadMessage()
if err != nil {
return fmt.Errorf("read auth response: %w", err)
}
var preview struct {
Type string `json:"type"`
Error string `json:"error"`
}
if err := json.Unmarshal(raw, &preview); err == nil {
if preview.Error != "" {
return fmt.Errorf("auth rejected: %s", preview.Error)
}
if preview.Type == "auth_ack" {
break
}
}
// Tolerate stray frames during handshake (none expected in current
// server implementation, but don't lock up if that changes).
}
// After auth_ack the server proxies a terminal.open to the daemon and
// waits for terminal.opened or terminal.error. Block until we see one.
openDeadline := time.Now().Add(terminalOpenAckTimeout)
if err := p.conn.SetReadDeadline(openDeadline); err != nil {
return fmt.Errorf("set open read deadline: %w", err)
}
for {
_, raw, err := p.conn.ReadMessage()
if err != nil {
return fmt.Errorf("waiting for terminal.opened: %w", err)
}
var env protocol.Message
if err := json.Unmarshal(raw, &env); err != nil {
continue
}
switch env.Type {
case protocol.MessageTypeTerminalOpened:
var op protocol.TerminalOpenedPayload
if err := json.Unmarshal(env.Payload, &op); err != nil {
return fmt.Errorf("decode terminal.opened: %w", err)
}
if op.SessionID == "" {
return fmt.Errorf("daemon returned empty session_id in terminal.opened")
}
p.setSessionID(op.SessionID)
workDir := op.WorkDir
if workDir == "" {
workDir = "(unknown)"
}
fmt.Fprintf(p.stderr, "[multica] attached to %s — escape: %s.\r\n", workDir, escapeHelpString(p.escapeChar))
// Restore non-blocking reads for the pumps.
if err := p.conn.SetReadDeadline(time.Time{}); err != nil {
return fmt.Errorf("clear read deadline: %w", err)
}
return nil
case protocol.MessageTypeTerminalError:
var ep protocol.TerminalErrorPayload
if err := json.Unmarshal(env.Payload, &ep); err != nil {
return fmt.Errorf("daemon returned terminal.error (undecodable)")
}
return fmt.Errorf("daemon rejected terminal.open: %s (%s)", ep.Message, ep.Code)
default:
// keep waiting
}
}
}
func escapeHelpString(b byte) string {
if b == 0 {
return "(disabled)"
}
return "<enter>" + string(b) + "."
}
func (p *cliTerminalProxy) readPump() {
defer p.shutdown()
for {
_, raw, err := p.conn.ReadMessage()
if err != nil {
if !isClosedConnError(err) {
msg := fmt.Sprintf("[multica] websocket closed: %v", err)
p.exitMsg.CompareAndSwap(nil, &msg)
}
return
}
var env protocol.Message
if err := json.Unmarshal(raw, &env); err != nil {
continue
}
switch env.Type {
case protocol.MessageTypeTerminalData:
var pl protocol.TerminalDataPayload
if err := json.Unmarshal(env.Payload, &pl); err != nil {
continue
}
data, err := base64.StdEncoding.DecodeString(pl.DataB64)
if err != nil {
continue
}
_, _ = p.stdout.Write(data)
case protocol.MessageTypeTerminalExit:
var pl protocol.TerminalExitPayload
if err := json.Unmarshal(env.Payload, &pl); err != nil {
continue
}
reason := pl.Reason
if reason == "" {
reason = "child exited"
}
msg := fmt.Sprintf("\r\n[multica] %s (exit code %d)", reason, pl.ExitCode)
p.exitMsg.CompareAndSwap(nil, &msg)
if pl.ExitCode > 0 {
p.exitCode.Store(int32(pl.ExitCode))
}
return
case protocol.MessageTypeTerminalError:
var pl protocol.TerminalErrorPayload
if err := json.Unmarshal(env.Payload, &pl); err != nil {
continue
}
msg := fmt.Sprintf("\r\n[multica] error: %s (%s)", pl.Message, pl.Code)
p.exitMsg.CompareAndSwap(nil, &msg)
p.exitCode.Store(1)
return
case protocol.MessageTypeTerminalClose:
return
}
}
}
// stdinPump reads stdin, runs it through the escape-sequence state machine,
// and forwards bytes as terminal.data frames. Detach (~.) closes the WS
// without sending the bytes.
func (p *cliTerminalProxy) stdinPump(rawTTY bool) {
defer p.shutdown()
buf := make([]byte, 4096)
// Start in newline state so the very first character can trigger an
// escape sequence; mirrors ssh's behavior.
state := newlineState{atNewline: true}
for {
n, err := p.stdin.Read(buf)
if n > 0 {
toSend, detach := state.process(buf[:n], p.escapeChar)
if len(toSend) > 0 {
if err := p.sendData(toSend); err != nil {
return
}
}
if detach {
msg := terminalDetachExitMessage
p.exitMsg.CompareAndSwap(nil, &msg)
_ = p.sendCloseBestEffort("client_detach")
return
}
}
if err != nil {
if !errors.Is(err, io.EOF) {
msg := fmt.Sprintf("[multica] stdin error: %v", err)
p.exitMsg.CompareAndSwap(nil, &msg)
}
return
}
}
}
func (p *cliTerminalProxy) sendData(data []byte) error {
sid := p.SessionID()
if sid == "" {
return errors.New("session_id not set")
}
frame, err := marshalCLITerminalFrame(protocol.MessageTypeTerminalData, protocol.TerminalDataPayload{
SessionID: sid,
DataB64: base64.StdEncoding.EncodeToString(data),
})
if err != nil {
return err
}
return p.writeRawFrame(frame)
}
func (p *cliTerminalProxy) sendResize(cols, rows uint16) error {
sid := p.SessionID()
if sid == "" {
// Pre-handshake resize is sent later by run() once session is known.
return nil
}
frame, err := marshalCLITerminalFrame(protocol.MessageTypeTerminalResize, protocol.TerminalResizePayload{
SessionID: sid,
Cols: cols,
Rows: rows,
})
if err != nil {
return err
}
return p.writeRawFrame(frame)
}
func (p *cliTerminalProxy) sendCloseBestEffort(reason string) error {
sid := p.SessionID()
if sid == "" {
return nil
}
frame, err := marshalCLITerminalFrame(protocol.MessageTypeTerminalClose, protocol.TerminalClosePayload{
SessionID: sid,
Reason: reason,
})
if err != nil {
return err
}
return p.writeRawFrame(frame)
}
func (p *cliTerminalProxy) writeRawFrame(frame []byte) error {
p.writeMu.Lock()
defer p.writeMu.Unlock()
if err := p.conn.SetWriteDeadline(time.Now().Add(terminalServerWriteWait)); err != nil {
return err
}
return p.conn.WriteMessage(websocket.TextMessage, frame)
}
func (p *cliTerminalProxy) SessionID() string {
p.sessionMu.RLock()
defer p.sessionMu.RUnlock()
return p.sessionID
}
func (p *cliTerminalProxy) setSessionID(sid string) {
p.sessionMu.Lock()
defer p.sessionMu.Unlock()
p.sessionID = sid
}
func (p *cliTerminalProxy) shutdown() {
p.closeOnce.Do(func() {
close(p.doneCh)
_ = p.conn.Close()
})
}
func marshalCLITerminalFrame(msgType string, payload any) ([]byte, error) {
raw, err := json.Marshal(payload)
if err != nil {
return nil, err
}
return json.Marshal(protocol.Message{Type: msgType, Payload: raw})
}
func isClosedConnError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, io.EOF) {
return true
}
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
return true
}
return false
}
// --- escape sequence state machine -----------------------------------------
//
// Mirrors ssh(1)'s `~.` detach: after a newline, a single escape character
// followed by `.` detaches; `~~` emits a literal escape; `~?` prints help;
// any other byte aborts the escape and forwards both bytes.
type newlineState struct {
atNewline bool
gotEscape bool
}
// process consumes a chunk of stdin bytes. Returns the bytes that should
// actually be forwarded to the daemon and whether the user requested detach.
// The state machine mutates the receiver across calls so multi-byte chunks
// straddling escape boundaries (rare, but possible with paste) work.
func (s *newlineState) process(in []byte, escape byte) (out []byte, detach bool) {
if escape == 0 {
// Escape detection disabled — pass through.
return in, false
}
out = make([]byte, 0, len(in))
for _, b := range in {
switch {
case s.gotEscape:
s.gotEscape = false
switch b {
case '.':
return out, true
case escape:
out = append(out, escape)
s.atNewline = false
case '?':
// Help is a local-only signal — not delivered to PTY.
// Caller can detect by … actually keep it simple: just
// emit a CR for visual feedback so the prompt redraws.
out = append(out, '\r')
s.atNewline = true
default:
// Not a recognized escape: forward ESC then this byte.
out = append(out, escape, b)
s.atNewline = b == '\r' || b == '\n'
}
case s.atNewline && b == escape:
s.gotEscape = true
default:
out = append(out, b)
s.atNewline = b == '\r' || b == '\n'
}
}
return out, false
}

View File

@@ -0,0 +1,526 @@
package main
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/spf13/cobra"
"github.com/multica-ai/multica/server/pkg/protocol"
)
func TestEscapeState_DetachOnFreshLine(t *testing.T) {
s := &newlineState{atNewline: true}
out, detach := s.process([]byte("~."), '~')
if !detach {
t.Fatalf("expected detach")
}
if len(out) != 0 {
t.Fatalf("expected no bytes forwarded, got %q", out)
}
}
func TestEscapeState_TildeNotAfterNewlineIsLiteral(t *testing.T) {
s := &newlineState{atNewline: false}
out, detach := s.process([]byte("foo~.bar"), '~')
if detach {
t.Fatalf("must not detach when ~ is mid-line")
}
if string(out) != "foo~.bar" {
t.Fatalf("got %q", out)
}
}
func TestEscapeState_DoubleTildeEmitsLiteral(t *testing.T) {
s := &newlineState{atNewline: true}
out, detach := s.process([]byte("~~"), '~')
if detach {
t.Fatalf("~~ must not detach")
}
if string(out) != "~" {
t.Fatalf("got %q want ~", out)
}
}
func TestEscapeState_StraddledChunks(t *testing.T) {
// User pastes/types ~ and . in two separate stdin reads — escape
// detection still works because state is preserved across calls.
s := &newlineState{atNewline: true}
out1, detach1 := s.process([]byte("~"), '~')
if detach1 || len(out1) != 0 {
t.Fatalf("first chunk: detach=%v out=%q", detach1, out1)
}
out2, detach2 := s.process([]byte("."), '~')
if !detach2 {
t.Fatalf("expected detach on second chunk")
}
if len(out2) != 0 {
t.Fatalf("second chunk should forward nothing, got %q", out2)
}
}
func TestEscapeState_DisabledWhenEscapeIsZero(t *testing.T) {
s := &newlineState{atNewline: true}
out, detach := s.process([]byte("~."), 0)
if detach {
t.Fatalf("disabled escape must not detach")
}
if string(out) != "~." {
t.Fatalf("got %q want ~.", out)
}
}
func TestEscapeState_UnknownEscapeForwardsBoth(t *testing.T) {
s := &newlineState{atNewline: true}
out, _ := s.process([]byte("~x"), '~')
if string(out) != "~x" {
t.Fatalf("got %q want ~x", out)
}
}
func TestBuildTerminalPathAndQuery(t *testing.T) {
got := buildTerminalPathAndQuery("MUL-2295", "ws-uuid", 120, 40)
u, err := url.Parse("http://x" + got)
if err != nil {
t.Fatalf("parse: %v", err)
}
if u.Path != "/ws/issues/MUL-2295/terminal" {
t.Errorf("path = %q", u.Path)
}
q := u.Query()
if q.Get("workspace_id") != "ws-uuid" {
t.Errorf("workspace_id = %q", q.Get("workspace_id"))
}
if q.Get("cols") != "120" {
t.Errorf("cols = %q", q.Get("cols"))
}
if q.Get("rows") != "40" {
t.Errorf("rows = %q", q.Get("rows"))
}
}
// fakeServer simulates the Phase 2 /ws/issues/{id}/terminal handshake plus
// a tiny echo loop, so we can drive the CLI proxy through its full lifecycle
// in-process without spinning up the real daemon.
type fakeServer struct {
t *testing.T
upgrader websocket.Upgrader
gotAuth chan string
gotData chan []byte
gotClose chan string
sessionID string
server *httptest.Server
connMu sync.Mutex
conn *websocket.Conn
sendOpenErr *protocol.TerminalErrorPayload // if set, send terminal.error instead of terminal.opened
}
// writeFrame serializes writes from the handler goroutine and any test
// goroutine that wants to push a frame to the connected client. Required
// because gorilla/websocket allows concurrent read+write but NOT concurrent
// writes from different goroutines.
func (fs *fakeServer) writeFrame(frame []byte) error {
fs.connMu.Lock()
defer fs.connMu.Unlock()
if fs.conn == nil {
return fmt.Errorf("no client")
}
return fs.conn.WriteMessage(websocket.TextMessage, frame)
}
func newFakeServer(t *testing.T) *fakeServer {
fs := &fakeServer{
t: t,
upgrader: websocket.Upgrader{},
gotAuth: make(chan string, 1),
gotData: make(chan []byte, 32),
gotClose: make(chan string, 1),
sessionID: "session-xyz",
}
fs.server = httptest.NewServer(http.HandlerFunc(fs.handle))
return fs
}
func (fs *fakeServer) close() {
fs.connMu.Lock()
c := fs.conn
fs.connMu.Unlock()
if c != nil {
c.Close()
}
fs.server.Close()
}
func (fs *fakeServer) baseURL() string { return fs.server.URL }
func (fs *fakeServer) handle(w http.ResponseWriter, r *http.Request) {
conn, err := fs.upgrader.Upgrade(w, r, nil)
if err != nil {
fs.t.Errorf("upgrade: %v", err)
return
}
fs.connMu.Lock()
fs.conn = conn
fs.connMu.Unlock()
// 1. Auth.
_, raw, err := conn.ReadMessage()
if err != nil {
return
}
var auth struct {
Type string `json:"type"`
Payload map[string]any `json:"payload"`
}
if err := json.Unmarshal(raw, &auth); err != nil || auth.Type != "auth" {
_ = fs.writeFrame([]byte(`{"error":"bad auth"}`))
return
}
tok, _ := auth.Payload["token"].(string)
fs.gotAuth <- tok
_ = fs.writeFrame([]byte(`{"type":"auth_ack"}`))
// 2. Open ack.
if fs.sendOpenErr != nil {
ep := *fs.sendOpenErr
frame, _ := marshalCLITerminalFrame(protocol.MessageTypeTerminalError, ep)
_ = fs.writeFrame(frame)
return
}
openedFrame, _ := marshalCLITerminalFrame(protocol.MessageTypeTerminalOpened, protocol.TerminalOpenedPayload{
SessionID: fs.sessionID,
WorkDir: "/tmp/work",
Shell: "/bin/bash",
})
_ = fs.writeFrame(openedFrame)
// 3. Pump.
for {
_, raw, err := conn.ReadMessage()
if err != nil {
return
}
var env protocol.Message
if err := json.Unmarshal(raw, &env); err != nil {
continue
}
switch env.Type {
case protocol.MessageTypeTerminalData:
var pl protocol.TerminalDataPayload
if err := json.Unmarshal(env.Payload, &pl); err != nil {
continue
}
data, _ := base64.StdEncoding.DecodeString(pl.DataB64)
fs.gotData <- data
// Echo back so the CLI's stdout pump has something to do.
echo, _ := marshalCLITerminalFrame(protocol.MessageTypeTerminalData, protocol.TerminalDataPayload{
SessionID: fs.sessionID,
DataB64: pl.DataB64,
})
_ = fs.writeFrame(echo)
case protocol.MessageTypeTerminalClose:
var pl protocol.TerminalClosePayload
_ = json.Unmarshal(env.Payload, &pl)
fs.gotClose <- pl.Reason
return
case protocol.MessageTypeTerminalResize:
// observed but unused in this fake
}
}
}
func newTestCmd() *cobra.Command {
c := &cobra.Command{}
c.Flags().String("escape-char", "~", "")
c.Flags().Bool("no-raw", true, "")
return c
}
func TestCLITerminalProxy_HandshakeAndEcho(t *testing.T) {
fs := newFakeServer(t)
defer fs.close()
wsURL := strings.Replace(fs.baseURL(), "http://", "ws://", 1) + "/"
dialer := *websocket.DefaultDialer
conn, _, err := dialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial: %v", err)
}
stdinR, stdinW := io.Pipe()
stdout := newSafeBuffer()
stderr := newSafeBuffer()
cmd := newTestCmd()
p := newCLITerminalProxy(conn, stdinR, stdout, stderr, "mul_test", cmd)
// Drive handshake explicitly so we can also assert the auth token reached
// the fake server.
if err := p.handshake(); err != nil {
t.Fatalf("handshake: %v", err)
}
select {
case got := <-fs.gotAuth:
if got != "mul_test" {
t.Errorf("auth token = %q, want mul_test", got)
}
case <-time.After(2 * time.Second):
t.Fatal("server did not receive auth frame")
}
if p.SessionID() != fs.sessionID {
t.Fatalf("session_id = %q, want %q", p.SessionID(), fs.sessionID)
}
// Now run the pumps in a goroutine.
pumpsDone := make(chan struct{})
go func() {
go p.readPump()
p.stdinPump(false)
close(pumpsDone)
}()
// Send "hello" through stdin; expect server to receive it and echo it
// back into stdout.
if _, err := stdinW.Write([]byte("hello")); err != nil {
t.Fatalf("stdin write: %v", err)
}
select {
case got := <-fs.gotData:
if string(got) != "hello" {
t.Fatalf("server got %q, want hello", got)
}
case <-time.After(2 * time.Second):
t.Fatal("server did not receive data")
}
// Wait for the echo to land in stdout.
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if strings.Contains(stdout.String(), "hello") {
break
}
time.Sleep(10 * time.Millisecond)
}
if !strings.Contains(stdout.String(), "hello") {
t.Fatalf("stdout missing echo, got %q", stdout.String())
}
// Trigger detach: send "\n~." after a newline. Because stdinPump starts
// the state machine at atNewline=true on the very first byte, we need
// to walk through a real newline first to make the test realistic.
if _, err := stdinW.Write([]byte("\n~.")); err != nil {
t.Fatalf("stdin write detach: %v", err)
}
select {
case <-pumpsDone:
case <-time.After(3 * time.Second):
t.Fatal("stdin pump did not exit after detach")
}
select {
case reason := <-fs.gotClose:
if reason != "client_detach" {
t.Errorf("close reason = %q, want client_detach", reason)
}
case <-time.After(2 * time.Second):
t.Fatal("server did not receive terminal.close on detach")
}
// run() prints the exit message to stderr; in this lower-level test we
// drive the pumps directly, so check the captured exit message.
msgPtr := p.exitMsg.Load()
if msgPtr == nil || !strings.Contains(*msgPtr, "detached") {
got := ""
if msgPtr != nil {
got = *msgPtr
}
t.Errorf("exit msg = %q, want detach text", got)
}
}
// safeBuffer is a tiny mutex-wrapped bytes.Buffer for tests that read from
// the buffer in one goroutine while another writes (race-detector-clean).
type safeBuffer struct {
mu sync.Mutex
buf bytes.Buffer
}
func newSafeBuffer() *safeBuffer { return &safeBuffer{} }
func (b *safeBuffer) Write(p []byte) (int, error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.buf.Write(p)
}
func (b *safeBuffer) String() string {
b.mu.Lock()
defer b.mu.Unlock()
return b.buf.String()
}
func TestCLITerminalProxy_HandshakeRejectedOnTerminalError(t *testing.T) {
fs := newFakeServer(t)
fs.sendOpenErr = &protocol.TerminalErrorPayload{
Code: protocol.TerminalErrorCodeTaskNotFound,
Message: "no agent task on this issue",
}
defer fs.close()
wsURL := strings.Replace(fs.baseURL(), "http://", "ws://", 1) + "/"
dialer := *websocket.DefaultDialer
conn, _, err := dialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial: %v", err)
}
cmd := newTestCmd()
p := newCLITerminalProxy(conn, strings.NewReader(""), io.Discard, io.Discard, "mul_test", cmd)
err = p.handshake()
if err == nil {
t.Fatal("expected handshake error, got nil")
}
if !strings.Contains(err.Error(), protocol.TerminalErrorCodeTaskNotFound) {
t.Errorf("error %q does not mention error code", err)
}
}
func TestCLITerminalProxy_AuthRejected(t *testing.T) {
upgrader := websocket.Upgrader{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
// Read auth frame, reply with error.
_, _, _ = conn.ReadMessage()
_ = conn.WriteMessage(websocket.TextMessage, []byte(`{"error":"invalid token"}`))
}))
defer server.Close()
wsURL := strings.Replace(server.URL, "http://", "ws://", 1) + "/"
dialer := *websocket.DefaultDialer
conn, _, err := dialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial: %v", err)
}
cmd := newTestCmd()
p := newCLITerminalProxy(conn, strings.NewReader(""), io.Discard, io.Discard, "mul_test", cmd)
err = p.handshake()
if err == nil {
t.Fatal("expected handshake error, got nil")
}
if !strings.Contains(err.Error(), "invalid token") {
t.Errorf("error %q does not surface server reason", err)
}
}
func TestCLITerminalProxy_TerminalExitDeliversCode(t *testing.T) {
// Driver: open server, advance through handshake, then push a
// terminal.exit frame and verify the proxy's exit code state.
fs := newFakeServer(t)
defer fs.close()
wsURL := strings.Replace(fs.baseURL(), "http://", "ws://", 1) + "/"
dialer := *websocket.DefaultDialer
conn, _, err := dialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial: %v", err)
}
cmd := newTestCmd()
p := newCLITerminalProxy(conn, strings.NewReader(""), io.Discard, io.Discard, "mul_test", cmd)
if err := p.handshake(); err != nil {
t.Fatalf("handshake: %v", err)
}
exitFrame, _ := marshalCLITerminalFrame(protocol.MessageTypeTerminalExit, protocol.TerminalExitPayload{
SessionID: fs.sessionID,
ExitCode: 42,
Reason: "child exited",
})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
p.readPump()
}()
if err := fs.writeFrame(exitFrame); err != nil {
t.Fatalf("server write exit: %v", err)
}
doneAt := time.Now().Add(2 * time.Second)
for time.Now().Before(doneAt) {
if p.exitCode.Load() == 42 {
break
}
time.Sleep(10 * time.Millisecond)
}
wg.Wait()
if got := p.exitCode.Load(); got != 42 {
t.Fatalf("exit code = %d, want 42", got)
}
msgPtr := p.exitMsg.Load()
if msgPtr == nil || !strings.Contains(*msgPtr, "exit code 42") {
got := ""
if msgPtr != nil {
got = *msgPtr
}
t.Errorf("exit msg = %q", got)
}
}
// Compile-time check: ensure the marshaled frame round-trips through the
// real protocol.Message envelope. Catches any drift if the protocol pkg
// renames a field.
func TestMarshalCLITerminalFrame_EnvelopeShape(t *testing.T) {
frame, err := marshalCLITerminalFrame(protocol.MessageTypeTerminalResize, protocol.TerminalResizePayload{
SessionID: "sid",
Cols: 100,
Rows: 30,
})
if err != nil {
t.Fatal(err)
}
var env protocol.Message
if err := json.Unmarshal(frame, &env); err != nil {
t.Fatal(err)
}
if env.Type != protocol.MessageTypeTerminalResize {
t.Fatalf("type = %q", env.Type)
}
var pl protocol.TerminalResizePayload
if err := json.Unmarshal(env.Payload, &pl); err != nil {
t.Fatal(err)
}
if pl.Cols != 100 || pl.Rows != 30 || pl.SessionID != "sid" {
t.Fatalf("payload = %+v", pl)
}
}
// Sanity check the help string does not crash on a zero escape byte.
func TestEscapeHelpString(t *testing.T) {
if got := escapeHelpString(0); got != "(disabled)" {
t.Errorf("escape disabled hint = %q", got)
}
if got := escapeHelpString('~'); !strings.Contains(got, "~") {
t.Errorf("escape help = %q", got)
}
}

View File

@@ -0,0 +1,40 @@
//go:build !windows
package main
import (
"os"
"os/signal"
"syscall"
"golang.org/x/term"
)
// startResizeWatcher installs a SIGWINCH handler that pushes the new local
// terminal size to the daemon every time the user resizes their window.
// Returns a stop function that uninstalls the handler and exits the
// goroutine. On platforms without SIGWINCH (Windows) the windows-tagged
// implementation polls term.GetSize on a timer instead.
func startResizeWatcher(p *cliTerminalProxy) func() {
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGWINCH)
stop := make(chan struct{})
go func() {
for {
select {
case <-ch:
if c, r, err := term.GetSize(int(os.Stdout.Fd())); err == nil && c > 0 && r > 0 {
_ = p.sendResize(uint16(c), uint16(r))
}
case <-stop:
return
}
}
}()
return func() {
signal.Stop(ch)
close(stop)
}
}

View File

@@ -0,0 +1,39 @@
//go:build windows
package main
import (
"os"
"time"
"golang.org/x/term"
)
// startResizeWatcher polls the local terminal size on a timer, since
// Windows has no SIGWINCH equivalent that is reliable for console resize
// events. 500ms is a compromise between responsiveness and CPU cost.
func startResizeWatcher(p *cliTerminalProxy) func() {
stop := make(chan struct{})
go func() {
var lastC, lastR int
t := time.NewTicker(500 * time.Millisecond)
defer t.Stop()
for {
select {
case <-stop:
return
case <-t.C:
c, r, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil || c <= 0 || r <= 0 {
continue
}
if c == lastC && r == lastR {
continue
}
lastC, lastR = c, r
_ = p.sendResize(uint16(c), uint16(r))
}
}
}()
return func() { close(stop) }
}

View File

@@ -24,6 +24,7 @@ require (
github.com/resend/resend-go/v2 v2.28.0
github.com/robfig/cron/v3 v3.0.1
github.com/spf13/cobra v1.10.2
golang.org/x/term v0.43.0
)
require (
@@ -58,7 +59,7 @@ require (
go.uber.org/atomic v1.11.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/sys v0.44.0 // indirect
golang.org/x/text v0.35.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect
)

View File

@@ -137,8 +137,10 @@ go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=

97
server/internal/cli/ws.go Normal file
View File

@@ -0,0 +1,97 @@
package cli
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/gorilla/websocket"
)
// DialWebSocket opens a WebSocket connection to the server at the given path
// + query string. The path must start with "/". Auth is intentionally NOT
// sent as a header here: the server's terminal endpoint runs WS upgrade
// before applying header-based auth middleware (browsers cannot set
// Authorization on a WS upgrade), so the caller authenticates via the
// first-frame `auth` message instead. The standard X-Workspace-ID /
// X-Client-* identity headers are still attached so dashboards can attribute
// the connection to the right CLI build.
func (c *APIClient) DialWebSocket(ctx context.Context, pathAndQuery string) (*websocket.Conn, *http.Response, error) {
if c.BaseURL == "" {
return nil, nil, fmt.Errorf("APIClient has no BaseURL")
}
wsURL, err := httpToWSURL(c.BaseURL, pathAndQuery)
if err != nil {
return nil, nil, err
}
header := http.Header{}
c.setWSHeaders(header)
dialer := *websocket.DefaultDialer
conn, resp, err := dialer.DialContext(ctx, wsURL, header)
if err != nil {
return nil, resp, err
}
return conn, resp, nil
}
// setWSHeaders attaches identity headers but deliberately omits the
// Authorization header. Auth happens in-band via the first frame so this
// stays consistent with cookie-based browser clients.
func (c *APIClient) setWSHeaders(h http.Header) {
if c.WorkspaceID != "" {
h.Set("X-Workspace-ID", c.WorkspaceID)
}
platform := c.Platform
if platform == "" {
platform = ClientPlatform
}
if platform != "" {
h.Set("X-Client-Platform", platform)
}
version := c.Version
if version == "" {
version = ClientVersion
}
if version != "" {
h.Set("X-Client-Version", version)
}
osName := c.OS
if osName == "" {
osName = ClientOS
}
if osName != "" {
h.Set("X-Client-OS", osName)
}
}
func httpToWSURL(baseURL, pathAndQuery string) (string, error) {
u, err := url.Parse(baseURL)
if err != nil {
return "", fmt.Errorf("parse base URL: %w", err)
}
switch strings.ToLower(u.Scheme) {
case "http":
u.Scheme = "ws"
case "https":
u.Scheme = "wss"
case "ws", "wss":
// already WS
default:
return "", fmt.Errorf("unsupported base URL scheme %q", u.Scheme)
}
if !strings.HasPrefix(pathAndQuery, "/") {
return "", fmt.Errorf("path must start with /, got %q", pathAndQuery)
}
suffix, err := url.Parse(pathAndQuery)
if err != nil {
return "", fmt.Errorf("parse path/query: %w", err)
}
u.Path = strings.TrimRight(u.Path, "/") + suffix.Path
u.RawQuery = suffix.RawQuery
u.Fragment = ""
return u.String(), nil
}

View File

@@ -0,0 +1,125 @@
package cli
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gorilla/websocket"
)
func TestHTTPToWSURL(t *testing.T) {
cases := []struct {
name string
base string
path string
want string
wantErr bool
}{
{
name: "https → wss",
base: "https://api.example.com",
path: "/ws/issues/abc/terminal?workspace_id=ws1&cols=80",
want: "wss://api.example.com/ws/issues/abc/terminal?workspace_id=ws1&cols=80",
},
{
name: "http → ws",
base: "http://localhost:8080",
path: "/ws/issues/x/terminal",
want: "ws://localhost:8080/ws/issues/x/terminal",
},
{
name: "wss left alone",
base: "wss://api.example.com",
path: "/ws",
want: "wss://api.example.com/ws",
},
{
name: "trailing slash on base preserved correctly",
base: "https://api.example.com/",
path: "/ws/x",
want: "wss://api.example.com/ws/x",
},
{
name: "missing leading slash on path",
base: "https://api.example.com",
path: "ws/x",
wantErr: true,
},
{
name: "unsupported scheme",
base: "ftp://example.com",
path: "/ws",
wantErr: true,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got, err := httpToWSURL(tc.base, tc.path)
if tc.wantErr {
if err == nil {
t.Fatalf("expected error, got %q", got)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != tc.want {
t.Fatalf("got %q want %q", got, tc.want)
}
})
}
}
func TestDialWebSocketAttachesIdentityHeaders(t *testing.T) {
upgrader := websocket.Upgrader{}
gotHeaders := make(chan http.Header, 1)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders <- r.Header.Clone()
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
conn.Close()
}))
defer server.Close()
client := NewAPIClient(server.URL, "ws-uuid", "mul_test_token")
client.Platform = "cli"
client.Version = "1.2.3"
client.OS = "macos"
conn, _, err := client.DialWebSocket(context.Background(), "/ws")
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
headers := <-gotHeaders
if got := headers.Get("X-Workspace-ID"); got != "ws-uuid" {
t.Errorf("X-Workspace-ID = %q, want ws-uuid", got)
}
if got := headers.Get("X-Client-Platform"); got != "cli" {
t.Errorf("X-Client-Platform = %q, want cli", got)
}
if got := headers.Get("X-Client-Version"); got != "1.2.3" {
t.Errorf("X-Client-Version = %q, want 1.2.3", got)
}
if got := headers.Get("X-Client-OS"); got != "macos" {
t.Errorf("X-Client-OS = %q, want macos", got)
}
if got := headers.Get("Authorization"); got != "" {
// The server's terminal endpoint runs WS upgrade before any header
// auth middleware, so the CLI must authenticate via the first frame
// to match cookie-based browser clients. Sending a Bearer header
// here would silently work in some setups and silently fail in
// others — keep it consistent and absent.
t.Errorf("Authorization header should NOT be set on WS dial, got %q", got)
}
if got := headers.Get("Sec-WebSocket-Key"); !strings.HasPrefix(strings.TrimSpace(got), "") || got == "" {
t.Errorf("Sec-WebSocket-Key missing")
}
}