diff --git a/relay.go b/relay.go index 2000a08..d1f7ba4 100644 --- a/relay.go +++ b/relay.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "log" + "net/http" "sync" "time" @@ -36,7 +37,8 @@ func (s Status) String() string { } type Relay struct { - URL string + URL string + RequestHeader http.Header // e.g. for origin header Connection *Connection subscriptions s.MapOf[string, *Subscription] @@ -77,7 +79,7 @@ func (r *Relay) Connect(ctx context.Context) error { defer cancel() } - socket, _, err := websocket.DefaultDialer.DialContext(ctx, r.URL, nil) + socket, _, err := websocket.DefaultDialer.DialContext(ctx, r.URL, r.RequestHeader) if err != nil { return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err) } diff --git a/relay_test.go b/relay_test.go index 31168d3..55267ca 100644 --- a/relay_test.go +++ b/relay_test.go @@ -121,9 +121,7 @@ func TestConnectContext(t *testing.T) { func TestConnectContextCanceled(t *testing.T) { // fake relay server - ws := newWebsocketServer(func(conn *websocket.Conn) { - io.ReadAll(conn) // discard all input - }) + ws := newWebsocketServer(discardingHandler) defer ws.Close() // relay client @@ -135,6 +133,26 @@ func TestConnectContextCanceled(t *testing.T) { } } +func TestConnectWithOrigin(t *testing.T) { + // fake relay server + // default handler requires origin golang.org/x/net/websocket + ws := httptest.NewServer(websocket.Handler(discardingHandler)) + defer ws.Close() + + // relay client + r := &Relay{URL: NormalizeURL(ws.URL), RequestHeader: http.Header{"origin": {"https://example.com"}}} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + err := r.Connect(ctx) + if err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +func discardingHandler(conn *websocket.Conn) { + io.ReadAll(conn) // discard all input +} + func newWebsocketServer(handler func(*websocket.Conn)) *httptest.Server { return httptest.NewServer(&websocket.Server{ Handshake: anyOriginHandshake,