mirror of
https://github.com/multica-ai/multica.git
synced 2026-06-26 17:09:14 +02:00
Compare commits
2 Commits
agent/lamb
...
agent/j/c4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a7c6cbdf16 | ||
|
|
c706c7e744 |
@@ -387,6 +387,7 @@ func NewRouterWithOptions(pool *pgxpool.Pool, hub *realtime.Hub, bus *events.Bus
|
||||
r.Route("/api/tokens", func(r chi.Router) {
|
||||
r.Get("/", h.ListPersonalAccessTokens)
|
||||
r.Post("/", h.CreatePersonalAccessToken)
|
||||
r.Post("/current/renew", h.RenewCurrentPersonalAccessToken)
|
||||
r.Delete("/{id}", h.RevokePersonalAccessToken)
|
||||
})
|
||||
|
||||
|
||||
@@ -55,6 +55,17 @@ func isTaskNotFoundError(err error) bool {
|
||||
return strings.Contains(strings.ToLower(reqErr.Body), "task not found")
|
||||
}
|
||||
|
||||
// isUnauthorizedError returns true if the error is a 401 from the server.
|
||||
// Used by the token-renewal loop to surface a clear "re-login required"
|
||||
// message instead of a generic transport-level retry.
|
||||
func isUnauthorizedError(err error) bool {
|
||||
var reqErr *requestError
|
||||
if !errors.As(err, &reqErr) {
|
||||
return false
|
||||
}
|
||||
return reqErr.StatusCode == http.StatusUnauthorized
|
||||
}
|
||||
|
||||
// isRuntimeNotFoundError returns true if the error is a 404 with "runtime not
|
||||
// found" body. The daemon uses this to detect that the runtime row was deleted
|
||||
// server-side (UI Delete, 7-day offline GC) while the daemon was still
|
||||
@@ -315,6 +326,27 @@ type WorkspaceInfo struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// RenewTokenResponse mirrors handler.RenewPATResponse — kept loose (string +
|
||||
// bool) because the daemon never parses the timestamp itself; it just logs it
|
||||
// for operator visibility.
|
||||
type RenewTokenResponse struct {
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
Renewed bool `json:"renewed"`
|
||||
}
|
||||
|
||||
// RenewToken asks the server to extend the daemon's current PAT in place when
|
||||
// it's within the server-side renewal window. The server is authoritative on
|
||||
// the threshold — the daemon doesn't know the token's expires_at locally —
|
||||
// so this is safe to call on any cadence; the only thing extra calls cost is
|
||||
// one round trip and one cheap SELECT.
|
||||
func (c *Client) RenewToken(ctx context.Context) (*RenewTokenResponse, error) {
|
||||
var resp RenewTokenResponse
|
||||
if err := c.postJSON(ctx, "/api/tokens/current/renew", map[string]any{}, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListWorkspaces fetches all workspaces the authenticated user belongs to.
|
||||
func (c *Client) ListWorkspaces(ctx context.Context) ([]WorkspaceInfo, error) {
|
||||
var workspaces []WorkspaceInfo
|
||||
|
||||
@@ -619,13 +619,12 @@ func (d *Daemon) Run(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Fetch all user workspaces from the API and register runtimes for any
|
||||
// that exist. Zero workspaces is a valid state — a newly-signed-up user
|
||||
// may start the daemon before creating their first workspace. The
|
||||
// workspaceSyncLoop below polls every 30s and will register runtimes
|
||||
// when a workspace appears, so the daemon stays useful as a long-lived
|
||||
// background process rather than crashing at startup.
|
||||
if err := d.syncWorkspacesFromAPI(ctx); err != nil {
|
||||
// Renew the PAT before the first API call, then do the initial
|
||||
// workspace sync. Both steps live in preflightAuth so the ordering
|
||||
// invariant (renew first) is enforced at one site instead of
|
||||
// scattered into Run, and tests can exercise the failure paths
|
||||
// without the full Run setup.
|
||||
if err := d.preflightAuth(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -640,8 +639,9 @@ func (d *Daemon) Run(ctx context.Context) error {
|
||||
go d.heartbeatLoop(ctx)
|
||||
go d.gcLoop(ctx)
|
||||
go d.autoUpdateLoop(ctx)
|
||||
go d.tokenRenewalLoop(ctx)
|
||||
go d.serveHealth(ctx, healthLn, time.Now())
|
||||
d.logger.Debug("background loops launched (workspace-sync, task-wakeup, heartbeat, gc, auto-update, health)")
|
||||
d.logger.Debug("background loops launched (workspace-sync, task-wakeup, heartbeat, gc, auto-update, token-renewal, health)")
|
||||
err = d.pollLoop(ctx, taskWakeups)
|
||||
d.logger.Debug("daemon main loop returning", "error", err)
|
||||
return err
|
||||
@@ -1025,6 +1025,89 @@ func (d *Daemon) ensureRepoReady(ctx context.Context, workspaceID, repoURL strin
|
||||
return fmt.Errorf("repo is configured but not synced")
|
||||
}
|
||||
|
||||
// DefaultTokenRenewalInterval is how often the daemon asks the server to
|
||||
// extend its PAT. The server-side threshold is 7 days of remaining lifetime;
|
||||
// polling every ~3 days gives at least two chances to renew before the
|
||||
// window closes, so a single failed call (network blip, server restart) does
|
||||
// not push the token out of the renewal window.
|
||||
const DefaultTokenRenewalInterval = 3 * 24 * time.Hour
|
||||
|
||||
// preflightAuth runs the two auth-sensitive startup steps in their
|
||||
// required order: a synchronous PAT renewal first, then the initial
|
||||
// workspace sync. The order matters — running tryRenewToken before any
|
||||
// other API call is what surfaces a user-actionable "run multica login"
|
||||
// WARN when the PAT is already revoked or expired. If we let the
|
||||
// workspace sync go first, its 401 would short-circuit Run before the
|
||||
// renewal loop's first tick ever fires, and the operator would see only
|
||||
// a generic auth failure in the workspace-sync log with no hint that
|
||||
// re-login is the fix.
|
||||
//
|
||||
// The renewal is best-effort: tryRenewToken logs and returns, never
|
||||
// propagating errors. preflightAuth's exit status is driven entirely by
|
||||
// the workspace sync — so a transient renewal failure (network blip,
|
||||
// 500) does not by itself block startup. A successful sync with zero
|
||||
// workspaces is fine: a newly-signed-up user may start the daemon
|
||||
// before creating their first workspace, and workspaceSyncLoop will
|
||||
// register runtimes once one appears.
|
||||
func (d *Daemon) preflightAuth(ctx context.Context) error {
|
||||
d.tryRenewToken(ctx)
|
||||
return d.syncWorkspacesFromAPI(ctx)
|
||||
}
|
||||
|
||||
// tokenRenewalLoop keeps the daemon's PAT alive by periodically asking the
|
||||
// server to extend its expires_at in-place. The startup renewal happens
|
||||
// synchronously in preflightAuth so a daemon coming back online after a
|
||||
// week of downtime gets a fresh expiry before its next heartbeat could
|
||||
// 401; this loop owns the long-running ~3-day cadence after that.
|
||||
//
|
||||
// The server is authoritative on the renewal threshold (it sees expires_at;
|
||||
// we don't), so this loop is intentionally dumb: call, log, sleep, repeat.
|
||||
// On 401 we surface a clear "re-login required" warning because the daemon
|
||||
// has no way to recover automatically — but we keep the loop running so the
|
||||
// user sees the same warning on every cycle until they fix it, rather than
|
||||
// silently exiting and forcing them to read scrollback to find the cause.
|
||||
func (d *Daemon) tokenRenewalLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(DefaultTokenRenewalInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
d.tryRenewToken(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tryRenewToken performs one renewal round-trip with a short, isolated
|
||||
// timeout. Errors are logged but never propagated — there is no caller to
|
||||
// handle them. Failures are debug-level except for 401, which gets a
|
||||
// user-actionable warning.
|
||||
func (d *Daemon) tryRenewToken(ctx context.Context) {
|
||||
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := d.client.RenewToken(reqCtx)
|
||||
if err != nil {
|
||||
if isUnauthorizedError(err) {
|
||||
loginHint := "'multica login'"
|
||||
if d.cfg.Profile != "" {
|
||||
loginHint = fmt.Sprintf("'multica login --profile %s'", d.cfg.Profile)
|
||||
}
|
||||
d.logger.Warn("auth token rejected by server — run "+loginHint+" to re-authenticate, then restart the daemon", "error", err)
|
||||
return
|
||||
}
|
||||
d.logger.Debug("token renewal failed; will retry on next cycle", "error", err)
|
||||
return
|
||||
}
|
||||
if resp.Renewed {
|
||||
d.logger.Info("auth token renewed", "expires_at", resp.ExpiresAt)
|
||||
} else {
|
||||
d.logger.Debug("auth token not yet eligible for renewal", "expires_at", resp.ExpiresAt)
|
||||
}
|
||||
}
|
||||
|
||||
// workspaceSyncLoop periodically fetches the user's workspaces from the API
|
||||
// and registers runtimes for any new ones.
|
||||
func (d *Daemon) workspaceSyncLoop(ctx context.Context) {
|
||||
|
||||
325
server/internal/daemon/token_renewal_test.go
Normal file
325
server/internal/daemon/token_renewal_test.go
Normal file
@@ -0,0 +1,325 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// captureLogger returns a *slog.Logger whose output lands in buf, so tests
|
||||
// can assert on the daemon's user-facing warning text without scraping
|
||||
// stderr.
|
||||
func captureLogger(buf *bytes.Buffer) *slog.Logger {
|
||||
return slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
}
|
||||
|
||||
func TestClient_RenewToken_PostsToCorrectEndpoint(t *testing.T) {
|
||||
var called atomic.Int32
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called.Add(1)
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/api/tokens/current/renew" {
|
||||
t.Errorf("expected /api/tokens/current/renew, got %s", r.URL.Path)
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer mul_abc" {
|
||||
t.Errorf("expected Bearer mul_abc, got %q", got)
|
||||
}
|
||||
// Body must be valid JSON — postJSON marshals an empty object when
|
||||
// reqBody is a non-nil map[string]any{}.
|
||||
var body map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Errorf("decode body: %v", err)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"expires_at": "2099-01-02T03:04:05Z",
|
||||
"renewed": true,
|
||||
})
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
c := NewClient(srv.URL)
|
||||
c.SetToken("mul_abc")
|
||||
|
||||
resp, err := c.RenewToken(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("RenewToken: %v", err)
|
||||
}
|
||||
if called.Load() != 1 {
|
||||
t.Fatalf("expected 1 server call, got %d", called.Load())
|
||||
}
|
||||
if !resp.Renewed {
|
||||
t.Fatal("expected renewed=true")
|
||||
}
|
||||
if resp.ExpiresAt != "2099-01-02T03:04:05Z" {
|
||||
t.Fatalf("expected expires_at to round-trip, got %q", resp.ExpiresAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryRenewToken_LogsRenewalOnSuccess(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"expires_at": "2099-01-02T03:04:05Z",
|
||||
"renewed": true,
|
||||
})
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
var buf bytes.Buffer
|
||||
d := &Daemon{client: NewClient(srv.URL), logger: captureLogger(&buf)}
|
||||
d.tryRenewToken(context.Background())
|
||||
|
||||
out := buf.String()
|
||||
if !strings.Contains(out, "auth token renewed") {
|
||||
t.Fatalf("expected 'auth token renewed' log, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, "2099-01-02T03:04:05Z") {
|
||||
t.Fatalf("expected new expiry in log, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryRenewToken_LogsNotEligibleOnNoOp(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"expires_at": "2099-01-02T03:04:05Z",
|
||||
"renewed": false,
|
||||
})
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
var buf bytes.Buffer
|
||||
d := &Daemon{client: NewClient(srv.URL), logger: captureLogger(&buf)}
|
||||
d.tryRenewToken(context.Background())
|
||||
|
||||
out := buf.String()
|
||||
// Non-renewal must NOT emit the warning that an operator would interpret
|
||||
// as "something is wrong" — it's the normal steady-state for tokens with
|
||||
// plenty of life left.
|
||||
if strings.Contains(out, "WARN") {
|
||||
t.Fatalf("no-op renewal should not log at WARN, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryRenewToken_SurfacesReloginWarningOn401(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid token"}`))
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
var buf bytes.Buffer
|
||||
d := &Daemon{client: NewClient(srv.URL), logger: captureLogger(&buf)}
|
||||
d.tryRenewToken(context.Background())
|
||||
|
||||
out := buf.String()
|
||||
if !strings.Contains(out, "level=WARN") {
|
||||
t.Fatalf("401 must surface as WARN, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, "multica login") {
|
||||
t.Fatalf("401 warning must tell the user to run 'multica login', got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryRenewToken_SurfacesReloginWarningOn401_WithProfile(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid token"}`))
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
var buf bytes.Buffer
|
||||
d := &Daemon{
|
||||
client: NewClient(srv.URL),
|
||||
logger: captureLogger(&buf),
|
||||
cfg: Config{Profile: "staging"},
|
||||
}
|
||||
d.tryRenewToken(context.Background())
|
||||
|
||||
out := buf.String()
|
||||
if !strings.Contains(out, "--profile staging") {
|
||||
t.Fatalf("profile-aware login hint missing, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryRenewToken_TransientErrorIsDebugNotWarn(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(`{"error":"db down"}`))
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
var buf bytes.Buffer
|
||||
d := &Daemon{client: NewClient(srv.URL), logger: captureLogger(&buf)}
|
||||
d.tryRenewToken(context.Background())
|
||||
|
||||
out := buf.String()
|
||||
// A 500 is transient — the next tick will retry, so the operator should
|
||||
// NOT see a re-login warning that doesn't reflect the actual cause.
|
||||
if strings.Contains(out, "level=WARN") {
|
||||
t.Fatalf("transient 500 should not log at WARN, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, "token renewal failed") {
|
||||
t.Fatalf("expected debug log about renewal failure, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPreflightAuth_RenewsBeforeWorkspaceSyncOnExpiredToken locks in the
|
||||
// must-fix from MUL-2744 review: when the daemon starts with an already-
|
||||
// revoked or expired PAT, the renewal call has to happen BEFORE the first
|
||||
// workspace sync, because the workspace sync's 401 would short-circuit Run
|
||||
// and the operator would never see a "run multica login" hint.
|
||||
func TestPreflightAuth_RenewsBeforeWorkspaceSyncOnExpiredToken(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
var seen []string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
seen = append(seen, r.URL.Path)
|
||||
mu.Unlock()
|
||||
// Both endpoints 401 — this is the "PAT already revoked/expired
|
||||
// before the daemon even started" failure mode.
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid token"}`))
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
var buf bytes.Buffer
|
||||
d := &Daemon{client: NewClient(srv.URL), logger: captureLogger(&buf)}
|
||||
d.client.SetToken("mul_already_revoked")
|
||||
|
||||
err := d.preflightAuth(context.Background())
|
||||
if err == nil {
|
||||
t.Fatal("expected workspace sync to fail with 401")
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(seen) < 2 {
|
||||
t.Fatalf("expected both endpoints to be called; got %v", seen)
|
||||
}
|
||||
if seen[0] != "/api/tokens/current/renew" {
|
||||
t.Fatalf("renew must be the first API call so the WARN fires before the sync 401s; got order %v", seen)
|
||||
}
|
||||
if seen[1] != "/api/workspaces" {
|
||||
t.Fatalf("workspace sync should follow renew; got order %v", seen)
|
||||
}
|
||||
out := buf.String()
|
||||
if !strings.Contains(out, "level=WARN") {
|
||||
t.Fatalf("expected re-login WARN, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, "multica login") {
|
||||
t.Fatalf("expected the actionable 'run multica login' hint in the WARN, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPreflightAuth_SyncProceedsWhenRenewIsNoOp covers the steady-state
|
||||
// startup: a PAT well outside the renewal window returns renewed=false,
|
||||
// and preflightAuth must still go on to do the workspace sync. The
|
||||
// renewal is best-effort and must not gate startup.
|
||||
func TestPreflightAuth_SyncProceedsWhenRenewIsNoOp(t *testing.T) {
|
||||
var syncCalled atomic.Bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/tokens/current/renew":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"expires_at": "2099-01-02T03:04:05Z",
|
||||
"renewed": false,
|
||||
})
|
||||
case "/api/workspaces":
|
||||
syncCalled.Store(true)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`[]`))
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
var buf bytes.Buffer
|
||||
d := &Daemon{client: NewClient(srv.URL), logger: captureLogger(&buf)}
|
||||
d.client.SetToken("mul_healthy")
|
||||
|
||||
if err := d.preflightAuth(context.Background()); err != nil {
|
||||
t.Fatalf("preflightAuth returned error on healthy startup: %v", err)
|
||||
}
|
||||
if !syncCalled.Load() {
|
||||
t.Fatal("preflightAuth must run the workspace sync after a no-op renewal")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPreflightAuth_TransientRenewFailureDoesNotBlockStartup covers the
|
||||
// "renewal endpoint is briefly down" path. The renewal failure must not
|
||||
// kill the daemon — the workspace sync still happens, and the daemon is
|
||||
// up and serving. The background renewal loop will retry later.
|
||||
func TestPreflightAuth_TransientRenewFailureDoesNotBlockStartup(t *testing.T) {
|
||||
var syncCalled atomic.Bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/tokens/current/renew":
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(`{"error":"db down"}`))
|
||||
case "/api/workspaces":
|
||||
syncCalled.Store(true)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`[]`))
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
var buf bytes.Buffer
|
||||
d := &Daemon{client: NewClient(srv.URL), logger: captureLogger(&buf)}
|
||||
d.client.SetToken("mul_healthy")
|
||||
|
||||
if err := d.preflightAuth(context.Background()); err != nil {
|
||||
t.Fatalf("preflightAuth must not surface transient renew failures: %v", err)
|
||||
}
|
||||
if !syncCalled.Load() {
|
||||
t.Fatal("transient renew failure must not skip the workspace sync")
|
||||
}
|
||||
if strings.Contains(buf.String(), "level=WARN") {
|
||||
t.Fatalf("transient 500 must not emit the re-login WARN, got: %s", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryRenewToken_RespectsContextTimeout(t *testing.T) {
|
||||
// Server that never responds — the per-call 15s timeout inside
|
||||
// tryRenewToken is too long for a unit test, so cancel the parent
|
||||
// context immediately and verify the call returns.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
_, _ = io.Copy(io.Discard, r.Body)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
var buf bytes.Buffer
|
||||
d := &Daemon{client: NewClient(srv.URL), logger: captureLogger(&buf)}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
d.tryRenewToken(ctx)
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
// Expected: tryRenewToken returns once the cancelled ctx propagates
|
||||
// through the HTTP client.
|
||||
case <-context.Background().Done():
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@@ -13,6 +14,19 @@ import (
|
||||
db "github.com/multica-ai/multica/server/pkg/db/generated"
|
||||
)
|
||||
|
||||
// PATRenewThreshold is the remaining-lifetime window at which a PAT becomes
|
||||
// eligible for an in-place renewal. The daemon polls every ~3 days, so a 7-day
|
||||
// threshold guarantees at least one renewal attempt while the token still has
|
||||
// ≥ 4 days of validity left — enough margin to absorb a transient network
|
||||
// failure before the user actually has to re-run `multica login`.
|
||||
const PATRenewThreshold = 7 * 24 * time.Hour
|
||||
|
||||
// PATRenewExtension is how far into the future a renewed PAT's expires_at is
|
||||
// pushed. Matches the initial issuance window in CreatePersonalAccessToken
|
||||
// (90 days) so renewed tokens converge on the same lifetime as freshly minted
|
||||
// ones — no second-class renewed tokens.
|
||||
const PATRenewExtension = 90 * 24 * time.Hour
|
||||
|
||||
type PersonalAccessTokenResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
@@ -115,6 +129,129 @@ func (h *Handler) ListPersonalAccessTokens(w http.ResponseWriter, r *http.Reques
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// RenewPATResponse is the body returned by RenewCurrentPersonalAccessToken.
|
||||
//
|
||||
// Renewed=false is a no-op, not an error — it just means the caller polled
|
||||
// before the token entered the renewal window. Callers should always read
|
||||
// ExpiresAt for the authoritative expiry rather than assuming the old value
|
||||
// is still current.
|
||||
type RenewPATResponse struct {
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
Renewed bool `json:"renewed"`
|
||||
}
|
||||
|
||||
// RenewCurrentPersonalAccessToken extends the expires_at of the PAT used to
|
||||
// authenticate this request, in-place, when it is inside the renewal window.
|
||||
//
|
||||
// The endpoint deliberately does NOT mint a new token — that would require
|
||||
// either rotating the raw secret (breaks the CLI/daemon multi-process model,
|
||||
// where a single PAT is shared by every process started from the same CLI
|
||||
// config) or returning the raw token over the wire on every poll (a needless
|
||||
// exposure since the daemon already holds it). Instead we extend the row's
|
||||
// expires_at atomically; the cached PAT entry's TTL is short enough
|
||||
// (auth.AuthCacheTTL ≤ 10m) that the cache catches up to the new expiry on
|
||||
// the next cache miss without an explicit invalidation.
|
||||
//
|
||||
// Only mul_ PATs may be renewed: a cookie/JWT session has no PAT row to
|
||||
// extend, and an mat_ task token is single-purpose and short-lived. mcn_
|
||||
// cloud-node PATs are owned by Multica Cloud Fleet, not us — we don't even
|
||||
// see the expiry locally.
|
||||
func (h *Handler) RenewCurrentPersonalAccessToken(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := requireUserID(w, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Re-read the raw token from the Authorization header — the upstream Auth
|
||||
// middleware resolves it to a userID but doesn't pass the hash forward,
|
||||
// and we need the row, not just the user.
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
rawToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if rawToken == "" || rawToken == authHeader || !strings.HasPrefix(rawToken, "mul_") {
|
||||
writeError(w, http.StatusBadRequest, "only personal access tokens can be renewed")
|
||||
return
|
||||
}
|
||||
|
||||
hash := auth.HashToken(rawToken)
|
||||
pat, err := h.Queries.GetPersonalAccessTokenByHash(r.Context(), hash)
|
||||
if err != nil {
|
||||
// The Auth middleware already validated the token, so reaching here
|
||||
// with no row means the PAT was revoked or expired in the gap between
|
||||
// the middleware's cache hit and this DB read. Surface a 401 so the
|
||||
// daemon's 401 branch fires the same "please re-login" message it
|
||||
// would for any other auth failure, instead of a generic 500.
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
writeError(w, http.StatusUnauthorized, "token is no longer valid")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "failed to look up token")
|
||||
return
|
||||
}
|
||||
|
||||
// Defense in depth: the middleware already set X-User-ID from the same
|
||||
// PAT row, so this mismatch should be impossible. If it ever fires, it
|
||||
// means a header was forged past the middleware and we MUST refuse to
|
||||
// renew on someone else's behalf — fail loudly.
|
||||
if uuidToString(pat.UserID) != userID {
|
||||
writeError(w, http.StatusUnauthorized, "token does not belong to caller")
|
||||
return
|
||||
}
|
||||
|
||||
// PATs minted before this code existed may have a NULL expires_at (the
|
||||
// "never expires" case). There is nothing to extend — return the current
|
||||
// (absent) expiry and let the caller treat this as a permanent token.
|
||||
if !pat.ExpiresAt.Valid {
|
||||
writeJSON(w, http.StatusOK, RenewPATResponse{ExpiresAt: "", Renewed: false})
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
remaining := pat.ExpiresAt.Time.Sub(now)
|
||||
if remaining > PATRenewThreshold {
|
||||
writeJSON(w, http.StatusOK, RenewPATResponse{
|
||||
ExpiresAt: timestampToString(pat.ExpiresAt),
|
||||
Renewed: false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
newExpiresAt := pgtype.Timestamptz{Time: now.Add(PATRenewExtension), Valid: true}
|
||||
// Pass the renewal threshold as the CAS predicate: only update if the
|
||||
// row's existing expires_at is still inside this window. After the
|
||||
// first writer succeeds the row sits at now+90d, which is well past
|
||||
// now+7d, so any concurrent renewer hits the WHERE and sees ErrNoRows.
|
||||
renewThreshold := pgtype.Timestamptz{Time: now.Add(PATRenewThreshold), Valid: true}
|
||||
updated, err := h.Queries.ExtendPersonalAccessTokenExpiry(r.Context(), db.ExtendPersonalAccessTokenExpiryParams{
|
||||
ID: pat.ID,
|
||||
NewExpiresAt: newExpiresAt,
|
||||
RenewThresholdAt: renewThreshold,
|
||||
})
|
||||
switch {
|
||||
case err == nil:
|
||||
writeJSON(w, http.StatusOK, RenewPATResponse{
|
||||
ExpiresAt: timestampToString(updated),
|
||||
Renewed: true,
|
||||
})
|
||||
case errors.Is(err, pgx.ErrNoRows):
|
||||
// A concurrent renew (or revoke) won the race. Re-read the current
|
||||
// row and report what's there now — the daemon's only correctness
|
||||
// guarantee is "after a successful call, expires_at is fresh enough
|
||||
// to last until the next poll", and a parallel writer already
|
||||
// satisfied that, so this is success from the caller's POV.
|
||||
current, getErr := h.Queries.GetPersonalAccessTokenByHash(r.Context(), hash)
|
||||
if getErr != nil {
|
||||
writeError(w, http.StatusUnauthorized, "token is no longer valid")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, RenewPATResponse{
|
||||
ExpiresAt: timestampToString(current.ExpiresAt),
|
||||
Renewed: false,
|
||||
})
|
||||
default:
|
||||
writeError(w, http.StatusInternalServerError, "failed to renew token")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) RevokePersonalAccessToken(w http.ResponseWriter, r *http.Request) {
|
||||
userID, ok := requireUserID(w, r)
|
||||
if !ok {
|
||||
|
||||
367
server/internal/handler/personal_access_token_test.go
Normal file
367
server/internal/handler/personal_access_token_test.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/multica-ai/multica/server/internal/auth"
|
||||
db "github.com/multica-ai/multica/server/pkg/db/generated"
|
||||
)
|
||||
|
||||
// insertTestPAT creates a PAT row for the shared test user with the given
|
||||
// expiry and returns (rawToken, patID). Each call generates a fresh raw token
|
||||
// so a test can hold many independent rows without colliding on token_hash.
|
||||
// The row is auto-cleaned at test end.
|
||||
func insertTestPAT(t *testing.T, expiresAt time.Time) (string, string) {
|
||||
t.Helper()
|
||||
raw, err := auth.GeneratePATToken()
|
||||
if err != nil {
|
||||
t.Fatalf("generate pat: %v", err)
|
||||
}
|
||||
prefix := raw
|
||||
if len(prefix) > 12 {
|
||||
prefix = prefix[:12]
|
||||
}
|
||||
pat, err := testHandler.Queries.CreatePersonalAccessToken(context.Background(), db.CreatePersonalAccessTokenParams{
|
||||
UserID: parseUUID(testUserID),
|
||||
Name: "renew-test",
|
||||
TokenHash: auth.HashToken(raw),
|
||||
TokenPrefix: prefix,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: expiresAt, Valid: !expiresAt.IsZero()},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create pat: %v", err)
|
||||
}
|
||||
patID := uuidToString(pat.ID)
|
||||
t.Cleanup(func() {
|
||||
testPool.Exec(context.Background(), `DELETE FROM personal_access_token WHERE id = $1`, parseUUID(patID))
|
||||
})
|
||||
return raw, patID
|
||||
}
|
||||
|
||||
// newRenewRequest builds a POST /api/tokens/current/renew request with both
|
||||
// the X-User-ID and Authorization headers set, so the handler can resolve
|
||||
// the PAT row in addition to the caller's user.
|
||||
func newRenewRequest(rawToken string) *http.Request {
|
||||
req := newRequest("POST", "/api/tokens/current/renew", nil)
|
||||
if rawToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+rawToken)
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func decodeRenewResponse(t *testing.T, body *httptest.ResponseRecorder) RenewPATResponse {
|
||||
t.Helper()
|
||||
var resp RenewPATResponse
|
||||
if err := json.NewDecoder(body.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode renew response: %v (body: %s)", err, body.Body.String())
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func TestRenewPAT_ExtendsWhenInsideRenewalWindow(t *testing.T) {
|
||||
// 3 days remaining — well inside the 7-day threshold.
|
||||
oldExpiry := time.Now().Add(3 * 24 * time.Hour)
|
||||
raw, patID := insertTestPAT(t, oldExpiry)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
testHandler.RenewCurrentPersonalAccessToken(w, newRenewRequest(raw))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
resp := decodeRenewResponse(t, w)
|
||||
if !resp.Renewed {
|
||||
t.Fatalf("expected renewed=true, got false (expires_at=%s)", resp.ExpiresAt)
|
||||
}
|
||||
|
||||
var actual time.Time
|
||||
if err := testPool.QueryRow(context.Background(),
|
||||
`SELECT expires_at FROM personal_access_token WHERE id = $1`, parseUUID(patID),
|
||||
).Scan(&actual); err != nil {
|
||||
t.Fatalf("readback: %v", err)
|
||||
}
|
||||
// Renewed expiry should be roughly now + PATRenewExtension (90 days),
|
||||
// well past the old expiry. Use a wide window — the test only needs to
|
||||
// know the row was bumped, not the exact instant.
|
||||
if !actual.After(oldExpiry.Add(24 * time.Hour)) {
|
||||
t.Fatalf("expected new expiry to be far past old %v, got %v", oldExpiry, actual)
|
||||
}
|
||||
wantAround := time.Now().Add(PATRenewExtension)
|
||||
if actual.Before(wantAround.Add(-time.Hour)) || actual.After(wantAround.Add(time.Hour)) {
|
||||
t.Fatalf("expected new expiry near %v, got %v", wantAround, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenewPAT_NoOpWhenOutsideRenewalWindow(t *testing.T) {
|
||||
// 30 days remaining — well outside the 7-day threshold.
|
||||
oldExpiry := time.Now().Add(30 * 24 * time.Hour).Truncate(time.Second)
|
||||
raw, patID := insertTestPAT(t, oldExpiry)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
testHandler.RenewCurrentPersonalAccessToken(w, newRenewRequest(raw))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
resp := decodeRenewResponse(t, w)
|
||||
if resp.Renewed {
|
||||
t.Fatalf("expected renewed=false, got true (expires_at=%s)", resp.ExpiresAt)
|
||||
}
|
||||
|
||||
var actual time.Time
|
||||
if err := testPool.QueryRow(context.Background(),
|
||||
`SELECT expires_at FROM personal_access_token WHERE id = $1`, parseUUID(patID),
|
||||
).Scan(&actual); err != nil {
|
||||
t.Fatalf("readback: %v", err)
|
||||
}
|
||||
if !actual.Equal(oldExpiry) {
|
||||
t.Fatalf("no-op should not change expires_at; old=%v new=%v", oldExpiry, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenewPAT_RejectsExpiredToken(t *testing.T) {
|
||||
raw, _ := insertTestPAT(t, time.Now().Add(-time.Hour))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
testHandler.RenewCurrentPersonalAccessToken(w, newRenewRequest(raw))
|
||||
// Expired tokens are filtered by GetPersonalAccessTokenByHash, so the
|
||||
// handler reports 401 — the auth middleware in production would already
|
||||
// have rejected the request, but the handler defends in depth.
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401 for expired token, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenewPAT_RejectsRevokedToken(t *testing.T) {
|
||||
raw, patID := insertTestPAT(t, time.Now().Add(3*24*time.Hour))
|
||||
if _, err := testPool.Exec(context.Background(),
|
||||
`UPDATE personal_access_token SET revoked = TRUE WHERE id = $1`, parseUUID(patID),
|
||||
); err != nil {
|
||||
t.Fatalf("revoke: %v", err)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
testHandler.RenewCurrentPersonalAccessToken(w, newRenewRequest(raw))
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401 for revoked token, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenewPAT_RejectsNonPATAuthHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
}{
|
||||
{"empty", ""},
|
||||
{"missing bearer prefix", "mul_abc123"},
|
||||
{"wrong prefix", "Bearer mdt_abc123"},
|
||||
{"jwt", "Bearer eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0.sig"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := newRequest("POST", "/api/tokens/current/renew", nil)
|
||||
if tt.header != "" {
|
||||
req.Header.Set("Authorization", tt.header)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
testHandler.RenewCurrentPersonalAccessToken(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenewPAT_HandlesNullExpiresAt(t *testing.T) {
|
||||
// Pre-existing PATs may carry NULL expires_at; the handler returns
|
||||
// renewed=false with an empty expires_at field rather than failing.
|
||||
raw, _ := insertTestPAT(t, time.Time{})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
testHandler.RenewCurrentPersonalAccessToken(w, newRenewRequest(raw))
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
resp := decodeRenewResponse(t, w)
|
||||
if resp.Renewed {
|
||||
t.Fatalf("expected renewed=false for NULL expiry, got true")
|
||||
}
|
||||
if resp.ExpiresAt != "" {
|
||||
t.Fatalf("expected empty expires_at for NULL expiry, got %q", resp.ExpiresAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenewPAT_ConcurrentRenewIsIdempotent(t *testing.T) {
|
||||
// Two callers race to extend the same PAT. The WHERE clause on
|
||||
// ExtendPersonalAccessTokenExpiry guarantees only one UPDATE actually
|
||||
// bumps the row; the loser sees pgx.ErrNoRows and reports renewed=false
|
||||
// with the already-extended expires_at. Both calls return 200.
|
||||
raw, patID := insertTestPAT(t, time.Now().Add(2*24*time.Hour))
|
||||
|
||||
w1 := httptest.NewRecorder()
|
||||
testHandler.RenewCurrentPersonalAccessToken(w1, newRenewRequest(raw))
|
||||
if w1.Code != http.StatusOK {
|
||||
t.Fatalf("first renew: expected 200, got %d: %s", w1.Code, w1.Body.String())
|
||||
}
|
||||
resp1 := decodeRenewResponse(t, w1)
|
||||
if !resp1.Renewed {
|
||||
t.Fatal("first renew should have extended the row")
|
||||
}
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
testHandler.RenewCurrentPersonalAccessToken(w2, newRenewRequest(raw))
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Fatalf("second renew: expected 200, got %d: %s", w2.Code, w2.Body.String())
|
||||
}
|
||||
resp2 := decodeRenewResponse(t, w2)
|
||||
if resp2.Renewed {
|
||||
t.Fatal("second renew should be a no-op (token now far in the future)")
|
||||
}
|
||||
if resp2.ExpiresAt != resp1.ExpiresAt {
|
||||
t.Fatalf("second renew should report same expires_at as first; got %q vs %q",
|
||||
resp2.ExpiresAt, resp1.ExpiresAt)
|
||||
}
|
||||
|
||||
// And the DB only carries the single extended value.
|
||||
var actual time.Time
|
||||
if err := testPool.QueryRow(context.Background(),
|
||||
`SELECT expires_at FROM personal_access_token WHERE id = $1`, parseUUID(patID),
|
||||
).Scan(&actual); err != nil {
|
||||
t.Fatalf("readback: %v", err)
|
||||
}
|
||||
wantAround := time.Now().Add(PATRenewExtension)
|
||||
if actual.Before(wantAround.Add(-time.Hour)) || actual.After(wantAround.Add(time.Hour)) {
|
||||
t.Fatalf("expected expiry near %v, got %v", wantAround, actual)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRenewPAT_ParallelRenewExtendsExactlyOnce locks in the SQL-level
|
||||
// idempotency that the MUL-2744 review flagged: when N callers race to
|
||||
// renew the same in-window PAT, the WHERE clause must ensure only one
|
||||
// UPDATE actually bumps the row. The previous condition (`expires_at < $2`)
|
||||
// silently let every caller win — each computed a slightly larger
|
||||
// `$2 = now + 90d`, so the second writer's $2 always exceeded the first
|
||||
// writer's row value and the UPDATE re-matched. Pinning the CAS to the
|
||||
// renewal threshold instead (`expires_at <= $3`) means after the first
|
||||
// writer pushes expires_at to now + 90d, all subsequent writers see a
|
||||
// row already past the threshold and the UPDATE matches zero rows.
|
||||
//
|
||||
// We verify the database side by counting how many times the row's
|
||||
// expires_at column was actually moved across N parallel calls.
|
||||
func TestRenewPAT_ParallelRenewExtendsExactlyOnce(t *testing.T) {
|
||||
const concurrency = 8
|
||||
|
||||
// Token has 2 days remaining — comfortably inside the 7-day window so
|
||||
// every caller passes the handler's threshold pre-check and all of
|
||||
// them get a chance to fight at the SQL layer.
|
||||
oldExpiry := time.Now().Add(2 * 24 * time.Hour)
|
||||
raw, patID := insertTestPAT(t, oldExpiry)
|
||||
|
||||
type result struct {
|
||||
code int
|
||||
expiresAt string
|
||||
renewed bool
|
||||
}
|
||||
results := make([]result, concurrency)
|
||||
|
||||
start := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrency)
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
w := httptest.NewRecorder()
|
||||
testHandler.RenewCurrentPersonalAccessToken(w, newRenewRequest(raw))
|
||||
var resp RenewPATResponse
|
||||
_ = json.NewDecoder(w.Body).Decode(&resp)
|
||||
results[i] = result{code: w.Code, expiresAt: resp.ExpiresAt, renewed: resp.Renewed}
|
||||
}(i)
|
||||
}
|
||||
close(start)
|
||||
wg.Wait()
|
||||
|
||||
var winners int
|
||||
var winnerExpiry string
|
||||
for _, r := range results {
|
||||
if r.code != http.StatusOK {
|
||||
t.Fatalf("concurrent renew should never return non-200; got %d (renewed=%v expires_at=%q)", r.code, r.renewed, r.expiresAt)
|
||||
}
|
||||
if r.renewed {
|
||||
winners++
|
||||
winnerExpiry = r.expiresAt
|
||||
}
|
||||
}
|
||||
if winners != 1 {
|
||||
t.Fatalf("expected exactly one caller to flip renewed=true; got %d winners across %d calls", winners, concurrency)
|
||||
}
|
||||
|
||||
// All losing callers report the same already-extended expires_at, and
|
||||
// the DB carries that same value. If the old (buggy) condition were
|
||||
// still in place, several callers would have re-bumped the row to
|
||||
// strictly-larger now+90d values and the final expiry would not match
|
||||
// the first winner's response.
|
||||
var finalExpiry time.Time
|
||||
if err := testPool.QueryRow(context.Background(),
|
||||
`SELECT expires_at FROM personal_access_token WHERE id = $1`, parseUUID(patID),
|
||||
).Scan(&finalExpiry); err != nil {
|
||||
t.Fatalf("readback: %v", err)
|
||||
}
|
||||
finalAsString := timestampToString(pgtype.Timestamptz{Time: finalExpiry, Valid: true})
|
||||
if winnerExpiry != "" && finalAsString != winnerExpiry {
|
||||
t.Fatalf("DB expires_at must match the winner's response (no double-bump); db=%q winner=%q", finalAsString, winnerExpiry)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenewPAT_RejectsTokenBelongingToDifferentUser(t *testing.T) {
|
||||
// Mint a PAT for a different user, then send a request that pairs that
|
||||
// PAT's Authorization header with our shared test user's X-User-ID
|
||||
// (simulating a forged identity past the middleware). The handler MUST
|
||||
// refuse to renew on the wrong user's behalf.
|
||||
ctx := context.Background()
|
||||
var otherUserID string
|
||||
if err := testPool.QueryRow(ctx, `
|
||||
INSERT INTO "user" (name, email)
|
||||
VALUES ($1, $2)
|
||||
RETURNING id
|
||||
`, "Other User", "other-renew@multica.ai").Scan(&otherUserID); err != nil {
|
||||
t.Fatalf("create other user: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
testPool.Exec(ctx, `DELETE FROM "user" WHERE id = $1`, parseUUID(otherUserID))
|
||||
})
|
||||
|
||||
raw, err := auth.GeneratePATToken()
|
||||
if err != nil {
|
||||
t.Fatalf("generate pat: %v", err)
|
||||
}
|
||||
prefix := raw
|
||||
if len(prefix) > 12 {
|
||||
prefix = prefix[:12]
|
||||
}
|
||||
pat, err := testHandler.Queries.CreatePersonalAccessToken(ctx, db.CreatePersonalAccessTokenParams{
|
||||
UserID: parseUUID(otherUserID),
|
||||
Name: "other-renew",
|
||||
TokenHash: auth.HashToken(raw),
|
||||
TokenPrefix: prefix,
|
||||
ExpiresAt: pgtype.Timestamptz{Time: time.Now().Add(3 * 24 * time.Hour), Valid: true},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create other pat: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
testPool.Exec(ctx, `DELETE FROM personal_access_token WHERE id = $1`, pat.ID)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
// newRequest sets X-User-ID = testUserID, but the bearer is otherUser's PAT.
|
||||
testHandler.RenewCurrentPersonalAccessToken(w, newRenewRequest(raw))
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401 on user mismatch, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
@@ -48,6 +48,41 @@ func (q *Queries) CreatePersonalAccessToken(ctx context.Context, arg CreatePerso
|
||||
return i, err
|
||||
}
|
||||
|
||||
const extendPersonalAccessTokenExpiry = `-- name: ExtendPersonalAccessTokenExpiry :one
|
||||
UPDATE personal_access_token
|
||||
SET expires_at = $1
|
||||
WHERE id = $2
|
||||
AND revoked = FALSE
|
||||
AND expires_at IS NOT NULL
|
||||
AND expires_at > now()
|
||||
AND expires_at <= $3
|
||||
RETURNING expires_at
|
||||
`
|
||||
|
||||
type ExtendPersonalAccessTokenExpiryParams struct {
|
||||
NewExpiresAt pgtype.Timestamptz `json:"new_expires_at"`
|
||||
ID pgtype.UUID `json:"id"`
|
||||
RenewThresholdAt pgtype.Timestamptz `json:"renew_threshold_at"`
|
||||
}
|
||||
|
||||
// In-place renew: only bumps expires_at when the token is still valid
|
||||
// (not revoked, not already expired) AND the existing expires_at is
|
||||
// still inside the renewal threshold ($3, e.g. now + 7d). Phrasing the
|
||||
// CAS this way — "is the row still renewable?" rather than "is the
|
||||
// requested new expiry larger than the current one?" — makes concurrent
|
||||
// renews idempotent: once writer A bumps expires_at past the threshold,
|
||||
// writer B's UPDATE matches zero rows (sqlc :one returns pgx.ErrNoRows,
|
||||
// which the caller treats as "already renewed"). A naive `expires_at <
|
||||
// $2` would still match because two callers race-computing
|
||||
// `$2 = now + 90d` produce strictly-different values and the second
|
||||
// one's $2 is always greater than the row A just wrote.
|
||||
func (q *Queries) ExtendPersonalAccessTokenExpiry(ctx context.Context, arg ExtendPersonalAccessTokenExpiryParams) (pgtype.Timestamptz, error) {
|
||||
row := q.db.QueryRow(ctx, extendPersonalAccessTokenExpiry, arg.NewExpiresAt, arg.ID, arg.RenewThresholdAt)
|
||||
var expires_at pgtype.Timestamptz
|
||||
err := row.Scan(&expires_at)
|
||||
return expires_at, err
|
||||
}
|
||||
|
||||
const getPersonalAccessTokenByHash = `-- name: GetPersonalAccessTokenByHash :one
|
||||
SELECT id, user_id, name, token_hash, token_prefix, expires_at, last_used_at, revoked, created_at FROM personal_access_token
|
||||
WHERE token_hash = $1
|
||||
|
||||
@@ -25,3 +25,24 @@ RETURNING token_hash;
|
||||
UPDATE personal_access_token
|
||||
SET last_used_at = now()
|
||||
WHERE id = $1;
|
||||
|
||||
-- name: ExtendPersonalAccessTokenExpiry :one
|
||||
-- In-place renew: only bumps expires_at when the token is still valid
|
||||
-- (not revoked, not already expired) AND the existing expires_at is
|
||||
-- still inside the renewal threshold ($3, e.g. now + 7d). Phrasing the
|
||||
-- CAS this way — "is the row still renewable?" rather than "is the
|
||||
-- requested new expiry larger than the current one?" — makes concurrent
|
||||
-- renews idempotent: once writer A bumps expires_at past the threshold,
|
||||
-- writer B's UPDATE matches zero rows (sqlc :one returns pgx.ErrNoRows,
|
||||
-- which the caller treats as "already renewed"). A naive `expires_at <
|
||||
-- $2` would still match because two callers race-computing
|
||||
-- `$2 = now + 90d` produce strictly-different values and the second
|
||||
-- one's $2 is always greater than the row A just wrote.
|
||||
UPDATE personal_access_token
|
||||
SET expires_at = sqlc.arg(new_expires_at)
|
||||
WHERE id = sqlc.arg(id)
|
||||
AND revoked = FALSE
|
||||
AND expires_at IS NOT NULL
|
||||
AND expires_at > now()
|
||||
AND expires_at <= sqlc.arg(renew_threshold_at)
|
||||
RETURNING expires_at;
|
||||
|
||||
Reference in New Issue
Block a user