mirror of
https://github.com/multica-ai/multica.git
synced 2026-07-05 13:29:44 +02:00
feat(cli): multica issue terminal — attach via Phase 2 WS endpoint (MUL-2295)
Phase 3 of MUL-2295. Adds `multica issue terminal <issue-id>` which dials
the Phase 2 /ws/issues/{id}/terminal endpoint, performs first-frame auth
with the existing PAT/JWT, and runs an interactive PTY through the
daemon-side terminal manager from Phase 1. SIGWINCH on unix /
poll on windows pushes resize frames; ssh-style `<enter>~.` detaches.
Co-authored-by: multica-agent <github@multica.ai>
This commit is contained in:
563
server/cmd/multica/cmd_issue_terminal.go
Normal file
563
server/cmd/multica/cmd_issue_terminal.go
Normal file
@@ -0,0 +1,563 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/multica-ai/multica/server/pkg/protocol"
|
||||
)
|
||||
|
||||
var issueTerminalCmd = &cobra.Command{
|
||||
Use: "terminal <issue-id>",
|
||||
Short: "Attach to the issue's most recent agent task PTY",
|
||||
Long: "Open an interactive shell inside the workdir of the issue's most recent agent task. " +
|
||||
"Reuses the daemon-side PTY manager added in MUL-2295 — the daemon spawns a bash login " +
|
||||
"shell with CLAUDE_SESSION_ID + MULTICA_{WORKSPACE,ISSUE,TASK,USER}_ID injected so you " +
|
||||
"can immediately `claude --resume $CLAUDE_SESSION_ID`.\n\n" +
|
||||
"Detach without closing your shell: type `<enter>~.` (escape sequence). The daemon-side " +
|
||||
"session is currently torn down on disconnect — see RFC follow-up for `--attach`.",
|
||||
Args: exactArgs(1),
|
||||
RunE: runIssueTerminal,
|
||||
}
|
||||
|
||||
const (
|
||||
terminalDefaultCols = 80
|
||||
terminalDefaultRows = 24
|
||||
terminalAuthAckTimeout = 10 * time.Second
|
||||
terminalOpenAckTimeout = 15 * time.Second
|
||||
terminalServerWriteWait = 10 * time.Second
|
||||
terminalServerReadLimit = 1 << 20 // 1 MiB per frame; matches realistic xterm bursts
|
||||
terminalDetachExitMessage = "[multica] detached — daemon session was torn down"
|
||||
)
|
||||
|
||||
func init() {
|
||||
issueCmd.AddCommand(issueTerminalCmd)
|
||||
issueTerminalCmd.Flags().Uint16("cols", 0, "Initial terminal columns (defaults to detected size, or 80 if stdout is not a TTY)")
|
||||
issueTerminalCmd.Flags().Uint16("rows", 0, "Initial terminal rows (defaults to detected size, or 24 if stdout is not a TTY)")
|
||||
issueTerminalCmd.Flags().String("escape-char", "~", "Escape character for detach sequence (`<enter><esc>.` to detach). Empty disables escape detection.")
|
||||
issueTerminalCmd.Flags().Bool("no-raw", false, "Don't put the local TTY into raw mode (mostly for testing / piped input)")
|
||||
}
|
||||
|
||||
func runIssueTerminal(cmd *cobra.Command, args []string) error {
|
||||
client, err := newAPIClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := requireWorkspaceID(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
token := resolveToken(cmd)
|
||||
if token == "" {
|
||||
return fmt.Errorf("not authenticated: run 'multica login'")
|
||||
}
|
||||
|
||||
resolveCtx, cancelResolve := context.WithTimeout(cmd.Context(), 15*time.Second)
|
||||
defer cancelResolve()
|
||||
issueRef, err := resolveIssueRef(resolveCtx, client, args[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve issue: %w", err)
|
||||
}
|
||||
|
||||
// Detect terminal size from stdout (the surface the user actually sees);
|
||||
// fall back to defaults if stdout is piped. Flag overrides win.
|
||||
cols, rows := detectInitialSize(cmd)
|
||||
|
||||
pathAndQuery := buildTerminalPathAndQuery(issueRef.ID, client.WorkspaceID, cols, rows)
|
||||
|
||||
// Use a long-lived context for the WS connection; cancellation is driven
|
||||
// by the proxy goroutines + signals rather than a timeout.
|
||||
conn, _, err := client.DialWebSocket(cmd.Context(), pathAndQuery)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial terminal websocket: %w", err)
|
||||
}
|
||||
|
||||
proxy := newCLITerminalProxy(conn, os.Stdin, os.Stdout, os.Stderr, token, cmd)
|
||||
return proxy.run(cmd.Context(), cols, rows)
|
||||
}
|
||||
|
||||
func detectInitialSize(cmd *cobra.Command) (uint16, uint16) {
|
||||
cols, _ := cmd.Flags().GetUint16("cols")
|
||||
rows, _ := cmd.Flags().GetUint16("rows")
|
||||
if cols > 0 && rows > 0 {
|
||||
return cols, rows
|
||||
}
|
||||
if c, r, err := term.GetSize(int(os.Stdout.Fd())); err == nil && c > 0 && r > 0 {
|
||||
if cols == 0 {
|
||||
cols = uint16(c)
|
||||
}
|
||||
if rows == 0 {
|
||||
rows = uint16(r)
|
||||
}
|
||||
}
|
||||
if cols == 0 {
|
||||
cols = terminalDefaultCols
|
||||
}
|
||||
if rows == 0 {
|
||||
rows = terminalDefaultRows
|
||||
}
|
||||
return cols, rows
|
||||
}
|
||||
|
||||
func buildTerminalPathAndQuery(issueID, workspaceID string, cols, rows uint16) string {
|
||||
q := url.Values{}
|
||||
q.Set("workspace_id", workspaceID)
|
||||
q.Set("cols", strconv.FormatUint(uint64(cols), 10))
|
||||
q.Set("rows", strconv.FormatUint(uint64(rows), 10))
|
||||
return "/ws/issues/" + url.PathEscape(issueID) + "/terminal?" + q.Encode()
|
||||
}
|
||||
|
||||
// cliTerminalProxy mirrors the server-side terminalProxy: one goroutine
|
||||
// owns conn writes, one owns conn reads, plus a stdin reader and resize
|
||||
// watcher. The struct is the only owner of the websocket.Conn; all writes
|
||||
// go through writeFrame() to keep a single point that holds writeMu.
|
||||
type cliTerminalProxy struct {
|
||||
conn *websocket.Conn
|
||||
stdin io.Reader
|
||||
stdout io.Writer
|
||||
stderr io.Writer
|
||||
token string
|
||||
cmd *cobra.Command
|
||||
|
||||
writeMu sync.Mutex
|
||||
|
||||
sessionMu sync.RWMutex
|
||||
sessionID string
|
||||
|
||||
closeOnce sync.Once
|
||||
doneCh chan struct{}
|
||||
|
||||
// exit reporting from the read pump back to the orchestrator.
|
||||
exitCode atomic.Int32 // 0 = unset, see exitCodeUnset / >=1
|
||||
exitMsg atomic.Pointer[string]
|
||||
|
||||
escapeChar byte
|
||||
noRaw bool
|
||||
}
|
||||
|
||||
const exitCodeUnset int32 = -1
|
||||
|
||||
func newCLITerminalProxy(conn *websocket.Conn, stdin io.Reader, stdout, stderr io.Writer, token string, cmd *cobra.Command) *cliTerminalProxy {
|
||||
escape, _ := cmd.Flags().GetString("escape-char")
|
||||
noRaw, _ := cmd.Flags().GetBool("no-raw")
|
||||
var ec byte
|
||||
if len(escape) >= 1 {
|
||||
ec = escape[0]
|
||||
}
|
||||
p := &cliTerminalProxy{
|
||||
conn: conn,
|
||||
stdin: stdin,
|
||||
stdout: stdout,
|
||||
stderr: stderr,
|
||||
token: token,
|
||||
cmd: cmd,
|
||||
doneCh: make(chan struct{}),
|
||||
escapeChar: ec,
|
||||
noRaw: noRaw,
|
||||
}
|
||||
p.exitCode.Store(exitCodeUnset)
|
||||
conn.SetReadLimit(terminalServerReadLimit)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *cliTerminalProxy) run(ctx context.Context, cols, rows uint16) error {
|
||||
defer p.conn.Close()
|
||||
|
||||
if err := p.handshake(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Push our local size right after open in case the server's hardcoded
|
||||
// initial 80x24 didn't match. (Phase 2 server stamps 80x24 on the
|
||||
// daemon-bound terminal.open frame regardless of query string; sending
|
||||
// resize immediately makes the PTY render correctly.)
|
||||
if err := p.sendResize(cols, rows); err != nil {
|
||||
// non-fatal — daemon will just keep the original size
|
||||
fmt.Fprintf(p.stderr, "[multica] warning: initial resize failed: %v\n", err)
|
||||
}
|
||||
|
||||
rawTTY := !p.noRaw && term.IsTerminal(int(os.Stdin.Fd()))
|
||||
var restore func() error
|
||||
if rawTTY {
|
||||
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("enter raw mode: %w", err)
|
||||
}
|
||||
fd := int(os.Stdin.Fd())
|
||||
restore = func() error { return term.Restore(fd, oldState) }
|
||||
defer restore()
|
||||
}
|
||||
|
||||
stopResize := startResizeWatcher(p)
|
||||
defer stopResize()
|
||||
|
||||
go p.readPump()
|
||||
go p.stdinPump(rawTTY)
|
||||
|
||||
select {
|
||||
case <-p.doneCh:
|
||||
case <-ctx.Done():
|
||||
p.shutdown()
|
||||
}
|
||||
|
||||
if restore != nil {
|
||||
_ = restore()
|
||||
}
|
||||
|
||||
if msgPtr := p.exitMsg.Load(); msgPtr != nil && *msgPtr != "" {
|
||||
fmt.Fprintln(p.stderr, *msgPtr)
|
||||
}
|
||||
if code := p.exitCode.Load(); code > 0 {
|
||||
os.Exit(int(code))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handshake performs first-frame auth and waits for terminal.opened.
|
||||
func (p *cliTerminalProxy) handshake() error {
|
||||
authFrame, err := json.Marshal(struct {
|
||||
Type string `json:"type"`
|
||||
Payload map[string]any `json:"payload"`
|
||||
}{
|
||||
Type: "auth",
|
||||
Payload: map[string]any{"token": p.token},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal auth frame: %w", err)
|
||||
}
|
||||
if err := p.writeRawFrame(authFrame); err != nil {
|
||||
return fmt.Errorf("send auth frame: %w", err)
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(terminalAuthAckTimeout)
|
||||
if err := p.conn.SetReadDeadline(deadline); err != nil {
|
||||
return fmt.Errorf("set auth read deadline: %w", err)
|
||||
}
|
||||
for {
|
||||
_, raw, err := p.conn.ReadMessage()
|
||||
if err != nil {
|
||||
return fmt.Errorf("read auth response: %w", err)
|
||||
}
|
||||
var preview struct {
|
||||
Type string `json:"type"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &preview); err == nil {
|
||||
if preview.Error != "" {
|
||||
return fmt.Errorf("auth rejected: %s", preview.Error)
|
||||
}
|
||||
if preview.Type == "auth_ack" {
|
||||
break
|
||||
}
|
||||
}
|
||||
// Tolerate stray frames during handshake (none expected in current
|
||||
// server implementation, but don't lock up if that changes).
|
||||
}
|
||||
|
||||
// After auth_ack the server proxies a terminal.open to the daemon and
|
||||
// waits for terminal.opened or terminal.error. Block until we see one.
|
||||
openDeadline := time.Now().Add(terminalOpenAckTimeout)
|
||||
if err := p.conn.SetReadDeadline(openDeadline); err != nil {
|
||||
return fmt.Errorf("set open read deadline: %w", err)
|
||||
}
|
||||
for {
|
||||
_, raw, err := p.conn.ReadMessage()
|
||||
if err != nil {
|
||||
return fmt.Errorf("waiting for terminal.opened: %w", err)
|
||||
}
|
||||
var env protocol.Message
|
||||
if err := json.Unmarshal(raw, &env); err != nil {
|
||||
continue
|
||||
}
|
||||
switch env.Type {
|
||||
case protocol.MessageTypeTerminalOpened:
|
||||
var op protocol.TerminalOpenedPayload
|
||||
if err := json.Unmarshal(env.Payload, &op); err != nil {
|
||||
return fmt.Errorf("decode terminal.opened: %w", err)
|
||||
}
|
||||
if op.SessionID == "" {
|
||||
return fmt.Errorf("daemon returned empty session_id in terminal.opened")
|
||||
}
|
||||
p.setSessionID(op.SessionID)
|
||||
workDir := op.WorkDir
|
||||
if workDir == "" {
|
||||
workDir = "(unknown)"
|
||||
}
|
||||
fmt.Fprintf(p.stderr, "[multica] attached to %s — escape: %s.\r\n", workDir, escapeHelpString(p.escapeChar))
|
||||
// Restore non-blocking reads for the pumps.
|
||||
if err := p.conn.SetReadDeadline(time.Time{}); err != nil {
|
||||
return fmt.Errorf("clear read deadline: %w", err)
|
||||
}
|
||||
return nil
|
||||
case protocol.MessageTypeTerminalError:
|
||||
var ep protocol.TerminalErrorPayload
|
||||
if err := json.Unmarshal(env.Payload, &ep); err != nil {
|
||||
return fmt.Errorf("daemon returned terminal.error (undecodable)")
|
||||
}
|
||||
return fmt.Errorf("daemon rejected terminal.open: %s (%s)", ep.Message, ep.Code)
|
||||
default:
|
||||
// keep waiting
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func escapeHelpString(b byte) string {
|
||||
if b == 0 {
|
||||
return "(disabled)"
|
||||
}
|
||||
return "<enter>" + string(b) + "."
|
||||
}
|
||||
|
||||
func (p *cliTerminalProxy) readPump() {
|
||||
defer p.shutdown()
|
||||
for {
|
||||
_, raw, err := p.conn.ReadMessage()
|
||||
if err != nil {
|
||||
if !isClosedConnError(err) {
|
||||
msg := fmt.Sprintf("[multica] websocket closed: %v", err)
|
||||
p.exitMsg.CompareAndSwap(nil, &msg)
|
||||
}
|
||||
return
|
||||
}
|
||||
var env protocol.Message
|
||||
if err := json.Unmarshal(raw, &env); err != nil {
|
||||
continue
|
||||
}
|
||||
switch env.Type {
|
||||
case protocol.MessageTypeTerminalData:
|
||||
var pl protocol.TerminalDataPayload
|
||||
if err := json.Unmarshal(env.Payload, &pl); err != nil {
|
||||
continue
|
||||
}
|
||||
data, err := base64.StdEncoding.DecodeString(pl.DataB64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
_, _ = p.stdout.Write(data)
|
||||
case protocol.MessageTypeTerminalExit:
|
||||
var pl protocol.TerminalExitPayload
|
||||
if err := json.Unmarshal(env.Payload, &pl); err != nil {
|
||||
continue
|
||||
}
|
||||
reason := pl.Reason
|
||||
if reason == "" {
|
||||
reason = "child exited"
|
||||
}
|
||||
msg := fmt.Sprintf("\r\n[multica] %s (exit code %d)", reason, pl.ExitCode)
|
||||
p.exitMsg.CompareAndSwap(nil, &msg)
|
||||
if pl.ExitCode > 0 {
|
||||
p.exitCode.Store(int32(pl.ExitCode))
|
||||
}
|
||||
return
|
||||
case protocol.MessageTypeTerminalError:
|
||||
var pl protocol.TerminalErrorPayload
|
||||
if err := json.Unmarshal(env.Payload, &pl); err != nil {
|
||||
continue
|
||||
}
|
||||
msg := fmt.Sprintf("\r\n[multica] error: %s (%s)", pl.Message, pl.Code)
|
||||
p.exitMsg.CompareAndSwap(nil, &msg)
|
||||
p.exitCode.Store(1)
|
||||
return
|
||||
case protocol.MessageTypeTerminalClose:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// stdinPump reads stdin, runs it through the escape-sequence state machine,
|
||||
// and forwards bytes as terminal.data frames. Detach (~.) closes the WS
|
||||
// without sending the bytes.
|
||||
func (p *cliTerminalProxy) stdinPump(rawTTY bool) {
|
||||
defer p.shutdown()
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
// Start in newline state so the very first character can trigger an
|
||||
// escape sequence; mirrors ssh's behavior.
|
||||
state := newlineState{atNewline: true}
|
||||
for {
|
||||
n, err := p.stdin.Read(buf)
|
||||
if n > 0 {
|
||||
toSend, detach := state.process(buf[:n], p.escapeChar)
|
||||
if len(toSend) > 0 {
|
||||
if err := p.sendData(toSend); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if detach {
|
||||
msg := terminalDetachExitMessage
|
||||
p.exitMsg.CompareAndSwap(nil, &msg)
|
||||
_ = p.sendCloseBestEffort("client_detach")
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
msg := fmt.Sprintf("[multica] stdin error: %v", err)
|
||||
p.exitMsg.CompareAndSwap(nil, &msg)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *cliTerminalProxy) sendData(data []byte) error {
|
||||
sid := p.SessionID()
|
||||
if sid == "" {
|
||||
return errors.New("session_id not set")
|
||||
}
|
||||
frame, err := marshalCLITerminalFrame(protocol.MessageTypeTerminalData, protocol.TerminalDataPayload{
|
||||
SessionID: sid,
|
||||
DataB64: base64.StdEncoding.EncodeToString(data),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.writeRawFrame(frame)
|
||||
}
|
||||
|
||||
func (p *cliTerminalProxy) sendResize(cols, rows uint16) error {
|
||||
sid := p.SessionID()
|
||||
if sid == "" {
|
||||
// Pre-handshake resize is sent later by run() once session is known.
|
||||
return nil
|
||||
}
|
||||
frame, err := marshalCLITerminalFrame(protocol.MessageTypeTerminalResize, protocol.TerminalResizePayload{
|
||||
SessionID: sid,
|
||||
Cols: cols,
|
||||
Rows: rows,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.writeRawFrame(frame)
|
||||
}
|
||||
|
||||
func (p *cliTerminalProxy) sendCloseBestEffort(reason string) error {
|
||||
sid := p.SessionID()
|
||||
if sid == "" {
|
||||
return nil
|
||||
}
|
||||
frame, err := marshalCLITerminalFrame(protocol.MessageTypeTerminalClose, protocol.TerminalClosePayload{
|
||||
SessionID: sid,
|
||||
Reason: reason,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.writeRawFrame(frame)
|
||||
}
|
||||
|
||||
func (p *cliTerminalProxy) writeRawFrame(frame []byte) error {
|
||||
p.writeMu.Lock()
|
||||
defer p.writeMu.Unlock()
|
||||
if err := p.conn.SetWriteDeadline(time.Now().Add(terminalServerWriteWait)); err != nil {
|
||||
return err
|
||||
}
|
||||
return p.conn.WriteMessage(websocket.TextMessage, frame)
|
||||
}
|
||||
|
||||
func (p *cliTerminalProxy) SessionID() string {
|
||||
p.sessionMu.RLock()
|
||||
defer p.sessionMu.RUnlock()
|
||||
return p.sessionID
|
||||
}
|
||||
|
||||
func (p *cliTerminalProxy) setSessionID(sid string) {
|
||||
p.sessionMu.Lock()
|
||||
defer p.sessionMu.Unlock()
|
||||
p.sessionID = sid
|
||||
}
|
||||
|
||||
func (p *cliTerminalProxy) shutdown() {
|
||||
p.closeOnce.Do(func() {
|
||||
close(p.doneCh)
|
||||
_ = p.conn.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func marshalCLITerminalFrame(msgType string, payload any) ([]byte, error) {
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(protocol.Message{Type: msgType, Payload: raw})
|
||||
}
|
||||
|
||||
func isClosedConnError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return true
|
||||
}
|
||||
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// --- escape sequence state machine -----------------------------------------
|
||||
//
|
||||
// Mirrors ssh(1)'s `~.` detach: after a newline, a single escape character
|
||||
// followed by `.` detaches; `~~` emits a literal escape; `~?` prints help;
|
||||
// any other byte aborts the escape and forwards both bytes.
|
||||
|
||||
type newlineState struct {
|
||||
atNewline bool
|
||||
gotEscape bool
|
||||
}
|
||||
|
||||
// process consumes a chunk of stdin bytes. Returns the bytes that should
|
||||
// actually be forwarded to the daemon and whether the user requested detach.
|
||||
// The state machine mutates the receiver across calls so multi-byte chunks
|
||||
// straddling escape boundaries (rare, but possible with paste) work.
|
||||
func (s *newlineState) process(in []byte, escape byte) (out []byte, detach bool) {
|
||||
if escape == 0 {
|
||||
// Escape detection disabled — pass through.
|
||||
return in, false
|
||||
}
|
||||
out = make([]byte, 0, len(in))
|
||||
for _, b := range in {
|
||||
switch {
|
||||
case s.gotEscape:
|
||||
s.gotEscape = false
|
||||
switch b {
|
||||
case '.':
|
||||
return out, true
|
||||
case escape:
|
||||
out = append(out, escape)
|
||||
s.atNewline = false
|
||||
case '?':
|
||||
// Help is a local-only signal — not delivered to PTY.
|
||||
// Caller can detect by … actually keep it simple: just
|
||||
// emit a CR for visual feedback so the prompt redraws.
|
||||
out = append(out, '\r')
|
||||
s.atNewline = true
|
||||
default:
|
||||
// Not a recognized escape: forward ESC then this byte.
|
||||
out = append(out, escape, b)
|
||||
s.atNewline = b == '\r' || b == '\n'
|
||||
}
|
||||
case s.atNewline && b == escape:
|
||||
s.gotEscape = true
|
||||
default:
|
||||
out = append(out, b)
|
||||
s.atNewline = b == '\r' || b == '\n'
|
||||
}
|
||||
}
|
||||
return out, false
|
||||
}
|
||||
|
||||
526
server/cmd/multica/cmd_issue_terminal_test.go
Normal file
526
server/cmd/multica/cmd_issue_terminal_test.go
Normal file
@@ -0,0 +1,526 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/multica-ai/multica/server/pkg/protocol"
|
||||
)
|
||||
|
||||
func TestEscapeState_DetachOnFreshLine(t *testing.T) {
|
||||
s := &newlineState{atNewline: true}
|
||||
out, detach := s.process([]byte("~."), '~')
|
||||
if !detach {
|
||||
t.Fatalf("expected detach")
|
||||
}
|
||||
if len(out) != 0 {
|
||||
t.Fatalf("expected no bytes forwarded, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEscapeState_TildeNotAfterNewlineIsLiteral(t *testing.T) {
|
||||
s := &newlineState{atNewline: false}
|
||||
out, detach := s.process([]byte("foo~.bar"), '~')
|
||||
if detach {
|
||||
t.Fatalf("must not detach when ~ is mid-line")
|
||||
}
|
||||
if string(out) != "foo~.bar" {
|
||||
t.Fatalf("got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEscapeState_DoubleTildeEmitsLiteral(t *testing.T) {
|
||||
s := &newlineState{atNewline: true}
|
||||
out, detach := s.process([]byte("~~"), '~')
|
||||
if detach {
|
||||
t.Fatalf("~~ must not detach")
|
||||
}
|
||||
if string(out) != "~" {
|
||||
t.Fatalf("got %q want ~", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEscapeState_StraddledChunks(t *testing.T) {
|
||||
// User pastes/types ~ and . in two separate stdin reads — escape
|
||||
// detection still works because state is preserved across calls.
|
||||
s := &newlineState{atNewline: true}
|
||||
out1, detach1 := s.process([]byte("~"), '~')
|
||||
if detach1 || len(out1) != 0 {
|
||||
t.Fatalf("first chunk: detach=%v out=%q", detach1, out1)
|
||||
}
|
||||
out2, detach2 := s.process([]byte("."), '~')
|
||||
if !detach2 {
|
||||
t.Fatalf("expected detach on second chunk")
|
||||
}
|
||||
if len(out2) != 0 {
|
||||
t.Fatalf("second chunk should forward nothing, got %q", out2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEscapeState_DisabledWhenEscapeIsZero(t *testing.T) {
|
||||
s := &newlineState{atNewline: true}
|
||||
out, detach := s.process([]byte("~."), 0)
|
||||
if detach {
|
||||
t.Fatalf("disabled escape must not detach")
|
||||
}
|
||||
if string(out) != "~." {
|
||||
t.Fatalf("got %q want ~.", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEscapeState_UnknownEscapeForwardsBoth(t *testing.T) {
|
||||
s := &newlineState{atNewline: true}
|
||||
out, _ := s.process([]byte("~x"), '~')
|
||||
if string(out) != "~x" {
|
||||
t.Fatalf("got %q want ~x", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildTerminalPathAndQuery(t *testing.T) {
|
||||
got := buildTerminalPathAndQuery("MUL-2295", "ws-uuid", 120, 40)
|
||||
u, err := url.Parse("http://x" + got)
|
||||
if err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if u.Path != "/ws/issues/MUL-2295/terminal" {
|
||||
t.Errorf("path = %q", u.Path)
|
||||
}
|
||||
q := u.Query()
|
||||
if q.Get("workspace_id") != "ws-uuid" {
|
||||
t.Errorf("workspace_id = %q", q.Get("workspace_id"))
|
||||
}
|
||||
if q.Get("cols") != "120" {
|
||||
t.Errorf("cols = %q", q.Get("cols"))
|
||||
}
|
||||
if q.Get("rows") != "40" {
|
||||
t.Errorf("rows = %q", q.Get("rows"))
|
||||
}
|
||||
}
|
||||
|
||||
// fakeServer simulates the Phase 2 /ws/issues/{id}/terminal handshake plus
|
||||
// a tiny echo loop, so we can drive the CLI proxy through its full lifecycle
|
||||
// in-process without spinning up the real daemon.
|
||||
type fakeServer struct {
|
||||
t *testing.T
|
||||
upgrader websocket.Upgrader
|
||||
gotAuth chan string
|
||||
gotData chan []byte
|
||||
gotClose chan string
|
||||
sessionID string
|
||||
server *httptest.Server
|
||||
connMu sync.Mutex
|
||||
conn *websocket.Conn
|
||||
sendOpenErr *protocol.TerminalErrorPayload // if set, send terminal.error instead of terminal.opened
|
||||
}
|
||||
|
||||
// writeFrame serializes writes from the handler goroutine and any test
|
||||
// goroutine that wants to push a frame to the connected client. Required
|
||||
// because gorilla/websocket allows concurrent read+write but NOT concurrent
|
||||
// writes from different goroutines.
|
||||
func (fs *fakeServer) writeFrame(frame []byte) error {
|
||||
fs.connMu.Lock()
|
||||
defer fs.connMu.Unlock()
|
||||
if fs.conn == nil {
|
||||
return fmt.Errorf("no client")
|
||||
}
|
||||
return fs.conn.WriteMessage(websocket.TextMessage, frame)
|
||||
}
|
||||
|
||||
func newFakeServer(t *testing.T) *fakeServer {
|
||||
fs := &fakeServer{
|
||||
t: t,
|
||||
upgrader: websocket.Upgrader{},
|
||||
gotAuth: make(chan string, 1),
|
||||
gotData: make(chan []byte, 32),
|
||||
gotClose: make(chan string, 1),
|
||||
sessionID: "session-xyz",
|
||||
}
|
||||
fs.server = httptest.NewServer(http.HandlerFunc(fs.handle))
|
||||
return fs
|
||||
}
|
||||
|
||||
func (fs *fakeServer) close() {
|
||||
fs.connMu.Lock()
|
||||
c := fs.conn
|
||||
fs.connMu.Unlock()
|
||||
if c != nil {
|
||||
c.Close()
|
||||
}
|
||||
fs.server.Close()
|
||||
}
|
||||
|
||||
func (fs *fakeServer) baseURL() string { return fs.server.URL }
|
||||
|
||||
func (fs *fakeServer) handle(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := fs.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
fs.t.Errorf("upgrade: %v", err)
|
||||
return
|
||||
}
|
||||
fs.connMu.Lock()
|
||||
fs.conn = conn
|
||||
fs.connMu.Unlock()
|
||||
|
||||
// 1. Auth.
|
||||
_, raw, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var auth struct {
|
||||
Type string `json:"type"`
|
||||
Payload map[string]any `json:"payload"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &auth); err != nil || auth.Type != "auth" {
|
||||
_ = fs.writeFrame([]byte(`{"error":"bad auth"}`))
|
||||
return
|
||||
}
|
||||
tok, _ := auth.Payload["token"].(string)
|
||||
fs.gotAuth <- tok
|
||||
_ = fs.writeFrame([]byte(`{"type":"auth_ack"}`))
|
||||
|
||||
// 2. Open ack.
|
||||
if fs.sendOpenErr != nil {
|
||||
ep := *fs.sendOpenErr
|
||||
frame, _ := marshalCLITerminalFrame(protocol.MessageTypeTerminalError, ep)
|
||||
_ = fs.writeFrame(frame)
|
||||
return
|
||||
}
|
||||
openedFrame, _ := marshalCLITerminalFrame(protocol.MessageTypeTerminalOpened, protocol.TerminalOpenedPayload{
|
||||
SessionID: fs.sessionID,
|
||||
WorkDir: "/tmp/work",
|
||||
Shell: "/bin/bash",
|
||||
})
|
||||
_ = fs.writeFrame(openedFrame)
|
||||
|
||||
// 3. Pump.
|
||||
for {
|
||||
_, raw, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var env protocol.Message
|
||||
if err := json.Unmarshal(raw, &env); err != nil {
|
||||
continue
|
||||
}
|
||||
switch env.Type {
|
||||
case protocol.MessageTypeTerminalData:
|
||||
var pl protocol.TerminalDataPayload
|
||||
if err := json.Unmarshal(env.Payload, &pl); err != nil {
|
||||
continue
|
||||
}
|
||||
data, _ := base64.StdEncoding.DecodeString(pl.DataB64)
|
||||
fs.gotData <- data
|
||||
// Echo back so the CLI's stdout pump has something to do.
|
||||
echo, _ := marshalCLITerminalFrame(protocol.MessageTypeTerminalData, protocol.TerminalDataPayload{
|
||||
SessionID: fs.sessionID,
|
||||
DataB64: pl.DataB64,
|
||||
})
|
||||
_ = fs.writeFrame(echo)
|
||||
case protocol.MessageTypeTerminalClose:
|
||||
var pl protocol.TerminalClosePayload
|
||||
_ = json.Unmarshal(env.Payload, &pl)
|
||||
fs.gotClose <- pl.Reason
|
||||
return
|
||||
case protocol.MessageTypeTerminalResize:
|
||||
// observed but unused in this fake
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newTestCmd() *cobra.Command {
|
||||
c := &cobra.Command{}
|
||||
c.Flags().String("escape-char", "~", "")
|
||||
c.Flags().Bool("no-raw", true, "")
|
||||
return c
|
||||
}
|
||||
|
||||
func TestCLITerminalProxy_HandshakeAndEcho(t *testing.T) {
|
||||
fs := newFakeServer(t)
|
||||
defer fs.close()
|
||||
|
||||
wsURL := strings.Replace(fs.baseURL(), "http://", "ws://", 1) + "/"
|
||||
dialer := *websocket.DefaultDialer
|
||||
conn, _, err := dialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
|
||||
stdinR, stdinW := io.Pipe()
|
||||
stdout := newSafeBuffer()
|
||||
stderr := newSafeBuffer()
|
||||
|
||||
cmd := newTestCmd()
|
||||
p := newCLITerminalProxy(conn, stdinR, stdout, stderr, "mul_test", cmd)
|
||||
|
||||
// Drive handshake explicitly so we can also assert the auth token reached
|
||||
// the fake server.
|
||||
if err := p.handshake(); err != nil {
|
||||
t.Fatalf("handshake: %v", err)
|
||||
}
|
||||
select {
|
||||
case got := <-fs.gotAuth:
|
||||
if got != "mul_test" {
|
||||
t.Errorf("auth token = %q, want mul_test", got)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("server did not receive auth frame")
|
||||
}
|
||||
if p.SessionID() != fs.sessionID {
|
||||
t.Fatalf("session_id = %q, want %q", p.SessionID(), fs.sessionID)
|
||||
}
|
||||
|
||||
// Now run the pumps in a goroutine.
|
||||
pumpsDone := make(chan struct{})
|
||||
go func() {
|
||||
go p.readPump()
|
||||
p.stdinPump(false)
|
||||
close(pumpsDone)
|
||||
}()
|
||||
|
||||
// Send "hello" through stdin; expect server to receive it and echo it
|
||||
// back into stdout.
|
||||
if _, err := stdinW.Write([]byte("hello")); err != nil {
|
||||
t.Fatalf("stdin write: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-fs.gotData:
|
||||
if string(got) != "hello" {
|
||||
t.Fatalf("server got %q, want hello", got)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("server did not receive data")
|
||||
}
|
||||
|
||||
// Wait for the echo to land in stdout.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if strings.Contains(stdout.String(), "hello") {
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
if !strings.Contains(stdout.String(), "hello") {
|
||||
t.Fatalf("stdout missing echo, got %q", stdout.String())
|
||||
}
|
||||
|
||||
// Trigger detach: send "\n~." after a newline. Because stdinPump starts
|
||||
// the state machine at atNewline=true on the very first byte, we need
|
||||
// to walk through a real newline first to make the test realistic.
|
||||
if _, err := stdinW.Write([]byte("\n~.")); err != nil {
|
||||
t.Fatalf("stdin write detach: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-pumpsDone:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("stdin pump did not exit after detach")
|
||||
}
|
||||
|
||||
select {
|
||||
case reason := <-fs.gotClose:
|
||||
if reason != "client_detach" {
|
||||
t.Errorf("close reason = %q, want client_detach", reason)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("server did not receive terminal.close on detach")
|
||||
}
|
||||
|
||||
// run() prints the exit message to stderr; in this lower-level test we
|
||||
// drive the pumps directly, so check the captured exit message.
|
||||
msgPtr := p.exitMsg.Load()
|
||||
if msgPtr == nil || !strings.Contains(*msgPtr, "detached") {
|
||||
got := ""
|
||||
if msgPtr != nil {
|
||||
got = *msgPtr
|
||||
}
|
||||
t.Errorf("exit msg = %q, want detach text", got)
|
||||
}
|
||||
}
|
||||
|
||||
// safeBuffer is a tiny mutex-wrapped bytes.Buffer for tests that read from
|
||||
// the buffer in one goroutine while another writes (race-detector-clean).
|
||||
type safeBuffer struct {
|
||||
mu sync.Mutex
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
func newSafeBuffer() *safeBuffer { return &safeBuffer{} }
|
||||
|
||||
func (b *safeBuffer) Write(p []byte) (int, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.buf.Write(p)
|
||||
}
|
||||
|
||||
func (b *safeBuffer) String() string {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.buf.String()
|
||||
}
|
||||
|
||||
func TestCLITerminalProxy_HandshakeRejectedOnTerminalError(t *testing.T) {
|
||||
fs := newFakeServer(t)
|
||||
fs.sendOpenErr = &protocol.TerminalErrorPayload{
|
||||
Code: protocol.TerminalErrorCodeTaskNotFound,
|
||||
Message: "no agent task on this issue",
|
||||
}
|
||||
defer fs.close()
|
||||
|
||||
wsURL := strings.Replace(fs.baseURL(), "http://", "ws://", 1) + "/"
|
||||
dialer := *websocket.DefaultDialer
|
||||
conn, _, err := dialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
|
||||
cmd := newTestCmd()
|
||||
p := newCLITerminalProxy(conn, strings.NewReader(""), io.Discard, io.Discard, "mul_test", cmd)
|
||||
err = p.handshake()
|
||||
if err == nil {
|
||||
t.Fatal("expected handshake error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), protocol.TerminalErrorCodeTaskNotFound) {
|
||||
t.Errorf("error %q does not mention error code", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCLITerminalProxy_AuthRejected(t *testing.T) {
|
||||
upgrader := websocket.Upgrader{}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
// Read auth frame, reply with error.
|
||||
_, _, _ = conn.ReadMessage()
|
||||
_ = conn.WriteMessage(websocket.TextMessage, []byte(`{"error":"invalid token"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
wsURL := strings.Replace(server.URL, "http://", "ws://", 1) + "/"
|
||||
dialer := *websocket.DefaultDialer
|
||||
conn, _, err := dialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
|
||||
cmd := newTestCmd()
|
||||
p := newCLITerminalProxy(conn, strings.NewReader(""), io.Discard, io.Discard, "mul_test", cmd)
|
||||
err = p.handshake()
|
||||
if err == nil {
|
||||
t.Fatal("expected handshake error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid token") {
|
||||
t.Errorf("error %q does not surface server reason", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCLITerminalProxy_TerminalExitDeliversCode(t *testing.T) {
|
||||
// Driver: open server, advance through handshake, then push a
|
||||
// terminal.exit frame and verify the proxy's exit code state.
|
||||
fs := newFakeServer(t)
|
||||
defer fs.close()
|
||||
|
||||
wsURL := strings.Replace(fs.baseURL(), "http://", "ws://", 1) + "/"
|
||||
dialer := *websocket.DefaultDialer
|
||||
conn, _, err := dialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
|
||||
cmd := newTestCmd()
|
||||
p := newCLITerminalProxy(conn, strings.NewReader(""), io.Discard, io.Discard, "mul_test", cmd)
|
||||
if err := p.handshake(); err != nil {
|
||||
t.Fatalf("handshake: %v", err)
|
||||
}
|
||||
|
||||
exitFrame, _ := marshalCLITerminalFrame(protocol.MessageTypeTerminalExit, protocol.TerminalExitPayload{
|
||||
SessionID: fs.sessionID,
|
||||
ExitCode: 42,
|
||||
Reason: "child exited",
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
p.readPump()
|
||||
}()
|
||||
if err := fs.writeFrame(exitFrame); err != nil {
|
||||
t.Fatalf("server write exit: %v", err)
|
||||
}
|
||||
|
||||
doneAt := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(doneAt) {
|
||||
if p.exitCode.Load() == 42 {
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
wg.Wait()
|
||||
if got := p.exitCode.Load(); got != 42 {
|
||||
t.Fatalf("exit code = %d, want 42", got)
|
||||
}
|
||||
msgPtr := p.exitMsg.Load()
|
||||
if msgPtr == nil || !strings.Contains(*msgPtr, "exit code 42") {
|
||||
got := ""
|
||||
if msgPtr != nil {
|
||||
got = *msgPtr
|
||||
}
|
||||
t.Errorf("exit msg = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Compile-time check: ensure the marshaled frame round-trips through the
|
||||
// real protocol.Message envelope. Catches any drift if the protocol pkg
|
||||
// renames a field.
|
||||
func TestMarshalCLITerminalFrame_EnvelopeShape(t *testing.T) {
|
||||
frame, err := marshalCLITerminalFrame(protocol.MessageTypeTerminalResize, protocol.TerminalResizePayload{
|
||||
SessionID: "sid",
|
||||
Cols: 100,
|
||||
Rows: 30,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var env protocol.Message
|
||||
if err := json.Unmarshal(frame, &env); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if env.Type != protocol.MessageTypeTerminalResize {
|
||||
t.Fatalf("type = %q", env.Type)
|
||||
}
|
||||
var pl protocol.TerminalResizePayload
|
||||
if err := json.Unmarshal(env.Payload, &pl); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if pl.Cols != 100 || pl.Rows != 30 || pl.SessionID != "sid" {
|
||||
t.Fatalf("payload = %+v", pl)
|
||||
}
|
||||
}
|
||||
|
||||
// Sanity check the help string does not crash on a zero escape byte.
|
||||
func TestEscapeHelpString(t *testing.T) {
|
||||
if got := escapeHelpString(0); got != "(disabled)" {
|
||||
t.Errorf("escape disabled hint = %q", got)
|
||||
}
|
||||
if got := escapeHelpString('~'); !strings.Contains(got, "~") {
|
||||
t.Errorf("escape help = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
40
server/cmd/multica/cmd_issue_terminal_unix.go
Normal file
40
server/cmd/multica/cmd_issue_terminal_unix.go
Normal file
@@ -0,0 +1,40 @@
|
||||
//go:build !windows
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// startResizeWatcher installs a SIGWINCH handler that pushes the new local
|
||||
// terminal size to the daemon every time the user resizes their window.
|
||||
// Returns a stop function that uninstalls the handler and exits the
|
||||
// goroutine. On platforms without SIGWINCH (Windows) the windows-tagged
|
||||
// implementation polls term.GetSize on a timer instead.
|
||||
func startResizeWatcher(p *cliTerminalProxy) func() {
|
||||
ch := make(chan os.Signal, 1)
|
||||
signal.Notify(ch, syscall.SIGWINCH)
|
||||
stop := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
if c, r, err := term.GetSize(int(os.Stdout.Fd())); err == nil && c > 0 && r > 0 {
|
||||
_ = p.sendResize(uint16(c), uint16(r))
|
||||
}
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return func() {
|
||||
signal.Stop(ch)
|
||||
close(stop)
|
||||
}
|
||||
}
|
||||
39
server/cmd/multica/cmd_issue_terminal_windows.go
Normal file
39
server/cmd/multica/cmd_issue_terminal_windows.go
Normal file
@@ -0,0 +1,39 @@
|
||||
//go:build windows
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// startResizeWatcher polls the local terminal size on a timer, since
|
||||
// Windows has no SIGWINCH equivalent that is reliable for console resize
|
||||
// events. 500ms is a compromise between responsiveness and CPU cost.
|
||||
func startResizeWatcher(p *cliTerminalProxy) func() {
|
||||
stop := make(chan struct{})
|
||||
go func() {
|
||||
var lastC, lastR int
|
||||
t := time.NewTicker(500 * time.Millisecond)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-t.C:
|
||||
c, r, err := term.GetSize(int(os.Stdout.Fd()))
|
||||
if err != nil || c <= 0 || r <= 0 {
|
||||
continue
|
||||
}
|
||||
if c == lastC && r == lastR {
|
||||
continue
|
||||
}
|
||||
lastC, lastR = c, r
|
||||
_ = p.sendResize(uint16(c), uint16(r))
|
||||
}
|
||||
}
|
||||
}()
|
||||
return func() { close(stop) }
|
||||
}
|
||||
@@ -24,6 +24,7 @@ require (
|
||||
github.com/resend/resend-go/v2 v2.28.0
|
||||
github.com/robfig/cron/v3 v3.0.1
|
||||
github.com/spf13/cobra v1.10.2
|
||||
golang.org/x/term v0.43.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -58,7 +59,7 @@ require (
|
||||
go.uber.org/atomic v1.11.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.35.0 // indirect
|
||||
golang.org/x/sys v0.44.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
google.golang.org/protobuf v1.36.8 // indirect
|
||||
)
|
||||
|
||||
@@ -137,8 +137,10 @@ go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
|
||||
golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
|
||||
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
|
||||
|
||||
97
server/internal/cli/ws.go
Normal file
97
server/internal/cli/ws.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// DialWebSocket opens a WebSocket connection to the server at the given path
|
||||
// + query string. The path must start with "/". Auth is intentionally NOT
|
||||
// sent as a header here: the server's terminal endpoint runs WS upgrade
|
||||
// before applying header-based auth middleware (browsers cannot set
|
||||
// Authorization on a WS upgrade), so the caller authenticates via the
|
||||
// first-frame `auth` message instead. The standard X-Workspace-ID /
|
||||
// X-Client-* identity headers are still attached so dashboards can attribute
|
||||
// the connection to the right CLI build.
|
||||
func (c *APIClient) DialWebSocket(ctx context.Context, pathAndQuery string) (*websocket.Conn, *http.Response, error) {
|
||||
if c.BaseURL == "" {
|
||||
return nil, nil, fmt.Errorf("APIClient has no BaseURL")
|
||||
}
|
||||
wsURL, err := httpToWSURL(c.BaseURL, pathAndQuery)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
header := http.Header{}
|
||||
c.setWSHeaders(header)
|
||||
|
||||
dialer := *websocket.DefaultDialer
|
||||
conn, resp, err := dialer.DialContext(ctx, wsURL, header)
|
||||
if err != nil {
|
||||
return nil, resp, err
|
||||
}
|
||||
return conn, resp, nil
|
||||
}
|
||||
|
||||
// setWSHeaders attaches identity headers but deliberately omits the
|
||||
// Authorization header. Auth happens in-band via the first frame so this
|
||||
// stays consistent with cookie-based browser clients.
|
||||
func (c *APIClient) setWSHeaders(h http.Header) {
|
||||
if c.WorkspaceID != "" {
|
||||
h.Set("X-Workspace-ID", c.WorkspaceID)
|
||||
}
|
||||
platform := c.Platform
|
||||
if platform == "" {
|
||||
platform = ClientPlatform
|
||||
}
|
||||
if platform != "" {
|
||||
h.Set("X-Client-Platform", platform)
|
||||
}
|
||||
version := c.Version
|
||||
if version == "" {
|
||||
version = ClientVersion
|
||||
}
|
||||
if version != "" {
|
||||
h.Set("X-Client-Version", version)
|
||||
}
|
||||
osName := c.OS
|
||||
if osName == "" {
|
||||
osName = ClientOS
|
||||
}
|
||||
if osName != "" {
|
||||
h.Set("X-Client-OS", osName)
|
||||
}
|
||||
}
|
||||
|
||||
func httpToWSURL(baseURL, pathAndQuery string) (string, error) {
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse base URL: %w", err)
|
||||
}
|
||||
switch strings.ToLower(u.Scheme) {
|
||||
case "http":
|
||||
u.Scheme = "ws"
|
||||
case "https":
|
||||
u.Scheme = "wss"
|
||||
case "ws", "wss":
|
||||
// already WS
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported base URL scheme %q", u.Scheme)
|
||||
}
|
||||
if !strings.HasPrefix(pathAndQuery, "/") {
|
||||
return "", fmt.Errorf("path must start with /, got %q", pathAndQuery)
|
||||
}
|
||||
suffix, err := url.Parse(pathAndQuery)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse path/query: %w", err)
|
||||
}
|
||||
u.Path = strings.TrimRight(u.Path, "/") + suffix.Path
|
||||
u.RawQuery = suffix.RawQuery
|
||||
u.Fragment = ""
|
||||
return u.String(), nil
|
||||
}
|
||||
125
server/internal/cli/ws_test.go
Normal file
125
server/internal/cli/ws_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func TestHTTPToWSURL(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
base string
|
||||
path string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "https → wss",
|
||||
base: "https://api.example.com",
|
||||
path: "/ws/issues/abc/terminal?workspace_id=ws1&cols=80",
|
||||
want: "wss://api.example.com/ws/issues/abc/terminal?workspace_id=ws1&cols=80",
|
||||
},
|
||||
{
|
||||
name: "http → ws",
|
||||
base: "http://localhost:8080",
|
||||
path: "/ws/issues/x/terminal",
|
||||
want: "ws://localhost:8080/ws/issues/x/terminal",
|
||||
},
|
||||
{
|
||||
name: "wss left alone",
|
||||
base: "wss://api.example.com",
|
||||
path: "/ws",
|
||||
want: "wss://api.example.com/ws",
|
||||
},
|
||||
{
|
||||
name: "trailing slash on base preserved correctly",
|
||||
base: "https://api.example.com/",
|
||||
path: "/ws/x",
|
||||
want: "wss://api.example.com/ws/x",
|
||||
},
|
||||
{
|
||||
name: "missing leading slash on path",
|
||||
base: "https://api.example.com",
|
||||
path: "ws/x",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unsupported scheme",
|
||||
base: "ftp://example.com",
|
||||
path: "/ws",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := httpToWSURL(tc.base, tc.path)
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got %q", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != tc.want {
|
||||
t.Fatalf("got %q want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialWebSocketAttachesIdentityHeaders(t *testing.T) {
|
||||
upgrader := websocket.Upgrader{}
|
||||
gotHeaders := make(chan http.Header, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotHeaders <- r.Header.Clone()
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn.Close()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewAPIClient(server.URL, "ws-uuid", "mul_test_token")
|
||||
client.Platform = "cli"
|
||||
client.Version = "1.2.3"
|
||||
client.OS = "macos"
|
||||
|
||||
conn, _, err := client.DialWebSocket(context.Background(), "/ws")
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
headers := <-gotHeaders
|
||||
if got := headers.Get("X-Workspace-ID"); got != "ws-uuid" {
|
||||
t.Errorf("X-Workspace-ID = %q, want ws-uuid", got)
|
||||
}
|
||||
if got := headers.Get("X-Client-Platform"); got != "cli" {
|
||||
t.Errorf("X-Client-Platform = %q, want cli", got)
|
||||
}
|
||||
if got := headers.Get("X-Client-Version"); got != "1.2.3" {
|
||||
t.Errorf("X-Client-Version = %q, want 1.2.3", got)
|
||||
}
|
||||
if got := headers.Get("X-Client-OS"); got != "macos" {
|
||||
t.Errorf("X-Client-OS = %q, want macos", got)
|
||||
}
|
||||
if got := headers.Get("Authorization"); got != "" {
|
||||
// The server's terminal endpoint runs WS upgrade before any header
|
||||
// auth middleware, so the CLI must authenticate via the first frame
|
||||
// to match cookie-based browser clients. Sending a Bearer header
|
||||
// here would silently work in some setups and silently fail in
|
||||
// others — keep it consistent and absent.
|
||||
t.Errorf("Authorization header should NOT be set on WS dial, got %q", got)
|
||||
}
|
||||
if got := headers.Get("Sec-WebSocket-Key"); !strings.HasPrefix(strings.TrimSpace(got), "") || got == "" {
|
||||
t.Errorf("Sec-WebSocket-Key missing")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user