use different contexts for the relay connection lifetime and the Connect() call.

fixes https://github.com/nbd-wtf/go-nostr/pull/86
This commit is contained in:
fiatjaf
2023-05-09 17:08:04 -03:00
parent ccbb44989f
commit 9dc674bc02
4 changed files with 24 additions and 12 deletions

View File

@ -34,7 +34,7 @@ func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) {
defer pool.mutex.Unlock() defer pool.mutex.Unlock()
relay, ok := pool.Relays[nm] relay, ok := pool.Relays[nm]
if ok && relay.ConnectionContext.Err() == nil { if ok && relay.connectionContext.Err() == nil {
// already connected, unlock and return // already connected, unlock and return
return relay, nil return relay, nil
} else { } else {

View File

@ -43,7 +43,7 @@ type Relay struct {
Challenges chan string // NIP-42 Challenges Challenges chan string // NIP-42 Challenges
Notices chan string Notices chan string
ConnectionError error ConnectionError error
ConnectionContext context.Context // will be canceled when the connection closes connectionContext context.Context // will be canceled when the connection closes
connectionContextCancel context.CancelFunc connectionContextCancel context.CancelFunc
okCallbacks s.MapOf[string, func(bool, string)] okCallbacks s.MapOf[string, func(bool, string)]
@ -54,11 +54,16 @@ type Relay struct {
AssumeValid bool // this will skip verifying signatures for events received from this relay AssumeValid bool // this will skip verifying signatures for events received from this relay
} }
// NewRelay returns a new relay. The relay connection will be closed when the context is canceled.
func NewRelay(ctx context.Context, url string) *Relay {
return &Relay{URL: NormalizeURL(url), connectionContext: ctx}
}
// RelayConnect returns a relay object connected to url. // RelayConnect returns a relay object connected to url.
// Once successfully connected, cancelling ctx has no effect. // Once successfully connected, cancelling ctx has no effect.
// To close the connection, call r.Close(). // To close the connection, call r.Close().
func RelayConnect(ctx context.Context, url string) (*Relay, error) { func RelayConnect(ctx context.Context, url string) (*Relay, error) {
r := &Relay{URL: NormalizeURL(url)} r := NewRelay(context.Background(), url)
err := r.Connect(ctx) err := r.Connect(ctx)
return r, err return r, err
} }
@ -67,13 +72,20 @@ func (r *Relay) String() string {
return r.URL return r.URL
} }
// Context retrieves the context that is associated with this relay connection.
func (r *Relay) Context() context.Context { return r.connectionContext }
// Connect tries to establish a websocket connection to r.URL. // Connect tries to establish a websocket connection to r.URL.
// If the context expires before the connection is complete, an error is returned. // If the context expires before the connection is complete, an error is returned.
// Once successfully connected, context expiration has no effect: call r.Close // Once successfully connected, context expiration has no effect: call r.Close
// to close the connection. // to close the connection.
//
// The underlying relay connection will use a background context. If you want to
// pass a custom context to the underlying relay connection, use NewRelay() and
// then Relay.Connect().
func (r *Relay) Connect(ctx context.Context) error { func (r *Relay) Connect(ctx context.Context) error {
connectionContext, cancel := context.WithCancel(ctx) connectionContext, cancel := context.WithCancel(ctx)
r.ConnectionContext = connectionContext r.connectionContext = connectionContext
r.connectionContextCancel = cancel r.connectionContextCancel = cancel
if r.URL == "" { if r.URL == "" {
@ -100,7 +112,7 @@ func (r *Relay) Connect(ctx context.Context) error {
// close these channels when the connection is dropped // close these channels when the connection is dropped
go func() { go func() {
<-r.ConnectionContext.Done() <-r.connectionContext.Done()
r.mutex.Lock() r.mutex.Lock()
close(r.Challenges) close(r.Challenges)
close(r.Notices) close(r.Notices)
@ -128,7 +140,7 @@ func (r *Relay) Connect(ctx context.Context) error {
go func() { go func() {
defer cancel() defer cancel()
for { for {
message, err := conn.ReadMessage(r.ConnectionContext) message, err := conn.ReadMessage(r.connectionContext)
if err != nil { if err != nil {
r.ConnectionError = err r.ConnectionError = err
break break
@ -146,7 +158,7 @@ func (r *Relay) Connect(ctx context.Context) error {
// we'll consume ever more memory with each new notice // we'll consume ever more memory with each new notice
go func() { go func() {
r.mutex.RLock() r.mutex.RLock()
if r.ConnectionContext.Err() == nil { if r.connectionContext.Err() == nil {
r.Notices <- string(*env) r.Notices <- string(*env)
} }
r.mutex.RUnlock() r.mutex.RUnlock()
@ -159,7 +171,7 @@ func (r *Relay) Connect(ctx context.Context) error {
// TODO: same as with NoticeEnvelope // TODO: same as with NoticeEnvelope
go func() { go func() {
r.mutex.RLock() r.mutex.RLock()
if r.ConnectionContext.Err() == nil { if r.connectionContext.Err() == nil {
r.Challenges <- *env.Challenge r.Challenges <- *env.Challenge
} }
r.mutex.RUnlock() r.mutex.RUnlock()
@ -276,7 +288,7 @@ func (r *Relay) Publish(ctx context.Context, event Event) (Status, error) {
// but if it happens because okCallback was called then it might be "succeeded" // but if it happens because okCallback was called then it might be "succeeded"
// do not return if okCallback is in process // do not return if okCallback is in process
return status, err return status, err
case <-r.ConnectionContext.Done(): case <-r.connectionContext.Done():
// same as above, but when the relay loses connectivity entirely // same as above, but when the relay loses connectivity entirely
return status, err return status, err
case <-time.After(4 * time.Second): case <-time.After(4 * time.Second):

View File

@ -84,7 +84,7 @@ func (sub *Subscription) Fire() error {
// or when the relay connection is closed // or when the relay connection is closed
go func() { go func() {
<-sub.Relay.ConnectionContext.Done() <-sub.Relay.connectionContext.Done()
// cancel the context -- this will cause the other context cancelation cause above to be called // cancel the context -- this will cause the other context cancelation cause above to be called
sub.cancel() sub.cancel()

View File

@ -29,8 +29,8 @@ func TestSubscribe(t *testing.T) {
events++ events++
case <-sub.EndOfStoredEvents: case <-sub.EndOfStoredEvents:
goto end goto end
case <-rl.ConnectionContext.Done(): case <-rl.Context().Done():
t.Errorf("connection closed: %v", rl.ConnectionContext.Err()) t.Errorf("connection closed: %v", rl.Context().Err())
goto end goto end
case <-timeout: case <-timeout:
t.Errorf("timeout") t.Errorf("timeout")