diff --git a/go.mod b/go.mod index d753823..bad5a2f 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/tyler-smith/go-bip39 v1.1.0 github.com/valyala/fastjson v1.6.3 golang.org/x/exp v0.0.0-20221106115401-f9659909a136 + golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 ) require ( diff --git a/go.sum b/go.sum index 08af53f..29bdf5d 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,7 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnk golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20221106115401-f9659909a136 h1:Fq7F/w7MAa1KJ5bt2aJ62ihqp9HDcRuyILskkpIAurw= golang.org/x/exp v0.0.0-20221106115401-f9659909a136/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2ePZzZTUrRacwib7cNsYQ= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/relay.go b/relay.go index f98d246..7cac807 100644 --- a/relay.go +++ b/relay.go @@ -1,6 +1,7 @@ package nostr import ( + "context" "encoding/hex" "encoding/json" "fmt" @@ -45,9 +46,16 @@ type Relay struct { statusChans s.MapOf[string, chan Status] } +// RelayConnect forwards calls to RelayConnectContext with a background context. func RelayConnect(url string) (*Relay, error) { + return RelayConnectContext(context.Background(), url) +} + +// RelayConnectContext creates a new relay client and connects to a canonical +// URL using Relay.ConnectContext, passing ctx as is. +func RelayConnectContext(ctx context.Context, url string) (*Relay, error) { r := &Relay{URL: NormalizeURL(url)} - err := r.Connect() + err := r.ConnectContext(ctx) return r, err } @@ -55,12 +63,21 @@ func (r *Relay) String() string { return r.URL } +// Connect calls ConnectContext with a background context. func (r *Relay) Connect() error { + return r.ConnectContext(context.Background()) +} + +// ConnectContext tries to establish a websocket connection to r.URL. +// If the context expires before the connection is complete, an error is returned. +// Once successfully connected, context expiration has no effect: call r.Close +// to close the connection. +func (r *Relay) ConnectContext(ctx context.Context) error { if r.URL == "" { return fmt.Errorf("invalid relay URL '%s'", r.URL) } - socket, _, err := websocket.DefaultDialer.Dial(r.URL, nil) + socket, _, err := websocket.DefaultDialer.DialContext(ctx, r.URL, nil) if err != nil { return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err) } diff --git a/relay_test.go b/relay_test.go new file mode 100644 index 0000000..f827216 --- /dev/null +++ b/relay_test.go @@ -0,0 +1,72 @@ +package nostr + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "golang.org/x/net/websocket" +) + +func TestConnectContext(t *testing.T) { + // fake relay server + var mu sync.Mutex // guards connected to satisfy go test -race + var connected bool + ws := newWebsocketServer(func(conn *websocket.Conn) { + mu.Lock() + connected = true + mu.Unlock() + io.ReadAll(conn) // discard all input + }) + defer ws.Close() + + // relay client + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + r, err := RelayConnectContext(ctx, ws.URL) + if err != nil { + t.Fatalf("RelayConnectContext: %v", err) + } + defer r.Close() + + mu.Lock() + defer mu.Unlock() + if !connected { + t.Error("fake relay server saw no client connect") + } +} + +func TestConnectContextCanceled(t *testing.T) { + // fake relay server + ws := newWebsocketServer(func(conn *websocket.Conn) { + io.ReadAll(conn) // discard all input + }) + defer ws.Close() + + // relay client + ctx, cancel := context.WithCancel(context.Background()) + cancel() // make ctx expired + _, err := RelayConnectContext(ctx, ws.URL) + if !errors.Is(err, context.Canceled) { + t.Errorf("RelayConnectContext returned %v error; want context.Canceled", err) + } +} + +func newWebsocketServer(handler func(*websocket.Conn)) *httptest.Server { + return httptest.NewServer(&websocket.Server{ + Handshake: anyOriginHandshake, + Handler: handler, + }) +} + +// anyOriginHandshake is an alternative to default in golang.org/x/net/websocket +// which checks for origin. nostr client sends no origin and it makes no difference +// for the tests here anyway. +var anyOriginHandshake = func(conf *websocket.Config, r *http.Request) error { + return nil +} diff --git a/relaypool.go b/relaypool.go index 3c4edd4..ddf557c 100644 --- a/relaypool.go +++ b/relaypool.go @@ -1,6 +1,7 @@ package nostr import ( + "context" "crypto/rand" "encoding/hex" "errors" @@ -58,44 +59,57 @@ func NewRelayPool() *RelayPool { } } -// Add adds a new relay to the pool, if policy is nil, it will be a simple -// read+write policy. -func (r *RelayPool) Add(url string, policy RelayPoolPolicy) chan error { +// Add calls AddContext with background context in a separate goroutine, sending +// any connection error over the returned channel. +// +// The returned channel is closed once the connection is successfully +// established or RelayConnectContext returned an error. +func (r *RelayPool) Add(url string, policy RelayPoolPolicy) <-chan error { + cherr := make(chan error) + go func() { + defer close(cherr) + if err := r.AddContext(context.Background(), url, policy); err != nil { + cherr <- err + } + }() + return cherr +} + +// AddContext connects to a relay at a canonical version specified by the url +// and adds it to the pool. The returned error is non-nil only on connection +// errors, including an expired context before the connection is complete. +// +// Once successfully connected, AddContext returns and the context expiration +// has no effect: call r.Remove to close the connection and delete a relay from the pool. +func (r *RelayPool) AddContext(ctx context.Context, url string, policy RelayPoolPolicy) error { + relay, err := RelayConnectContext(ctx, url) + if err != nil { + return fmt.Errorf("failed to connect to %s: %w", url, err) + } if policy == nil { policy = SimplePolicy{Read: true, Write: true} } + r.addConnected(relay, policy) + return nil +} - cherr := make(chan error) +func (r *RelayPool) addConnected(relay *Relay, policy RelayPoolPolicy) { + r.Policies.Store(relay.URL, policy) + r.Relays.Store(relay.URL, relay) - go func() { - relay, err := RelayConnect(url) - if err != nil { - cherr <- fmt.Errorf("failed to connect to %s: %w", url, err) - return - } + r.subscriptions.Range(func(id string, filters Filters) bool { + sub := relay.prepareSubscription(id) + sub.Sub(filters) + eventStream, _ := r.eventStreams.Load(id) - r.Policies.Store(relay.URL, policy) - r.Relays.Store(relay.URL, relay) + go func(sub *Subscription) { + for evt := range sub.Events { + eventStream <- EventMessage{Relay: relay.URL, Event: evt} + } + }(sub) - r.subscriptions.Range(func(id string, filters Filters) bool { - sub := relay.prepareSubscription(id) - sub.Sub(filters) - eventStream, _ := r.eventStreams.Load(id) - - go func(sub *Subscription) { - for evt := range sub.Events { - eventStream <- EventMessage{Relay: relay.URL, Event: evt} - } - }(sub) - - return true - }) - - cherr <- nil - close(cherr) - }() - - return cherr + return true + }) } // Remove removes a relay from the pool.