diff --git a/relay.go b/relay.go index 0d78cd1..fb302b4 100644 --- a/relay.go +++ b/relay.go @@ -186,30 +186,33 @@ func (r *Relay) Connect(ctx context.Context) error { // ping every 29 seconds ticker := time.NewTicker(29 * time.Second) - // queue all messages received from the relay on this - messageHandler := make(chan Envelope) + // this ensures we don't send an event to the Events channel after closing it + eventsChannelCloserMutex := &sync.Mutex{} - // we'll queue all relay actions (handling received messages etc) in a single queue - // such that we can close channels safely without mutex spaghetti + // to be used when the connection is closed + go func() { + <-r.connectionContext.Done() + // close these things when the connection is closed + if r.challenges != nil { + close(r.challenges) + } + if r.notices != nil { + close(r.notices) + } + // stop the ticker + ticker.Stop() + // close all subscriptions + r.Subscriptions.Range(func(_ string, sub *Subscription) bool { + go sub.Unsub() + return true + }) + return + }() + + // queue all write operations here so we don't do mutex spaghetti go func() { for { select { - case <-r.connectionContext.Done(): - // close these things when the connection is closed - if r.challenges != nil { - close(r.challenges) - } - if r.notices != nil { - close(r.notices) - } - // stop the ticker - ticker.Stop() - // close all subscriptions - r.Subscriptions.Range(func(_ string, sub *Subscription) bool { - go sub.Unsub() - return true - }) - return case <-ticker.C: err := wsutil.WriteClientMessage(r.Connection.conn, ws.OpPing, nil) if err != nil { @@ -217,23 +220,27 @@ func (r *Relay) Connect(ctx context.Context) error { r.Close() // this should trigger a context cancelation return } - case envelope := <-messageHandler: - // this will run synchronously in this goroutine - r.HandleRelayMessage(envelope) case writeRequest := <-r.writeQueue: // all write requests will go through this to prevent races if err := r.Connection.WriteMessage(writeRequest.msg); err != nil { writeRequest.answer <- err } close(writeRequest.answer) - case toClose := <-r.subscriptionChannelCloseQueue: - // every time a subscription ends we use this queue to close its Events channel - close(toClose.Events) - toClose.Events = make(chan *Event) } } }() + // every time a subscription ends we use this queue to close its .Events channel + go func() { + for toClose := range r.subscriptionChannelCloseQueue { + eventsChannelCloserMutex.Lock() + close(toClose.Events) + toClose.Events = make(chan *Event) + eventsChannelCloserMutex.Unlock() + } + }() + + // general message reader loop go func() { for { message, err := conn.ReadMessage(r.connectionContext) @@ -249,74 +256,73 @@ func (r *Relay) Connect(ctx context.Context) error { continue } - messageHandler <- envelope + switch env := envelope.(type) { + case *NoticeEnvelope: + // see WithNoticeHandler + if r.notices != nil { + r.notices <- string(*env) + } else { + log.Printf("NOTICE from %s: '%s'\n", r.URL, string(*env)) + } + case *AuthEnvelope: + if env.Challenge == nil { + return + } + // see WithAuthHandler + if r.challenges != nil { + r.challenges <- *env.Challenge + } + case *EventEnvelope: + if env.SubscriptionID == nil { + return + } + if subscription, ok := r.Subscriptions.Load(*env.SubscriptionID); !ok { + // InfoLogger.Printf("{%s} no subscription with id '%s'\n", r.URL, *env.SubscriptionID) + return + } else { + // check if the event matches the desired filter, ignore otherwise + if !subscription.Filters.Match(&env.Event) { + InfoLogger.Printf("{%s} filter does not match: %v ~ %v\n", r.URL, subscription.Filters, env.Event) + return + } + + // check signature, ignore invalid, except from trusted (AssumeValid) relays + if !r.AssumeValid { + if ok, err := env.Event.CheckSignature(); !ok { + errmsg := "" + if err != nil { + errmsg = err.Error() + } + InfoLogger.Printf("{%s} bad signature: %s\n", r.URL, errmsg) + return + } + } + + go func() { + eventsChannelCloserMutex.Lock() + if subscription.live { + subscription.Events <- &env.Event + } + eventsChannelCloserMutex.Unlock() + }() + } + case *EOSEEnvelope: + if subscription, ok := r.Subscriptions.Load(string(*env)); ok { + subscription.emitEose.Do(func() { + close(subscription.EndOfStoredEvents) + }) + } + case *OKEnvelope: + if okCallback, exist := r.okCallbacks.Load(env.EventID); exist { + okCallback(env.OK, *env.Reason) + } + } } }() return nil } -// HandleRelayMessage handles a message received from a relay. -func (r *Relay) HandleRelayMessage(envelope Envelope) { - switch env := envelope.(type) { - case *NoticeEnvelope: - // see WithNoticeHandler - if r.notices != nil { - r.notices <- string(*env) - } else { - log.Printf("NOTICE from %s: '%s'\n", r.URL, string(*env)) - } - case *AuthEnvelope: - if env.Challenge == nil { - return - } - // see WithAuthHandler - if r.challenges != nil { - r.challenges <- *env.Challenge - } - case *EventEnvelope: - if env.SubscriptionID == nil { - return - } - if subscription, ok := r.Subscriptions.Load(*env.SubscriptionID); !ok { - // InfoLogger.Printf("{%s} no subscription with id '%s'\n", r.URL, *env.SubscriptionID) - return - } else { - // check if the event matches the desired filter, ignore otherwise - if !subscription.Filters.Match(&env.Event) { - InfoLogger.Printf("{%s} filter does not match: %v ~ %v\n", r.URL, subscription.Filters, env.Event) - return - } - - // check signature, ignore invalid, except from trusted (AssumeValid) relays - if !r.AssumeValid { - if ok, err := env.Event.CheckSignature(); !ok { - errmsg := "" - if err != nil { - errmsg = err.Error() - } - InfoLogger.Printf("{%s} bad signature: %s\n", r.URL, errmsg) - return - } - } - - if subscription.live { - subscription.Events <- &env.Event - } - } - case *EOSEEnvelope: - if subscription, ok := r.Subscriptions.Load(string(*env)); ok { - subscription.emitEose.Do(func() { - close(subscription.EndOfStoredEvents) - }) - } - case *OKEnvelope: - if okCallback, exist := r.okCallbacks.Load(env.EventID); exist { - okCallback(env.OK, *env.Reason) - } - } -} - // Write queues a message to be sent to the relay. func (r *Relay) Write(msg []byte) <-chan error { ch := make(chan error) diff --git a/subscription.go b/subscription.go index b7fbbf2..0726bdd 100644 --- a/subscription.go +++ b/subscription.go @@ -57,20 +57,25 @@ func (sub *Subscription) GetID() string { // Unsub closes the subscription, sending "CLOSE" to relay as in NIP-01. // Unsub() also closes the channel sub.Events and makes a new one. func (sub *Subscription) Unsub() { - id := sub.GetID() + go sub.Close() + sub.live = false + id := sub.GetID() + sub.Relay.Subscriptions.Delete(id) + + // do this so we don't have the possibility of closing the Events channel and then trying to send to it + sub.Relay.subscriptionChannelCloseQueue <- sub +} + +// Close just sends a CLOSE message. You probably want Unsub() instead. +func (sub *Subscription) Close() { if sub.Relay.IsConnected() { + id := sub.GetID() closeMsg := CloseEnvelope(id) closeb, _ := (&closeMsg).MarshalJSON() debugLog("{%s} sending %v", sub.Relay.URL, closeb) sub.Relay.Write(closeb) } - - sub.live = false - sub.Relay.Subscriptions.Delete(id) - - // do this so we don't have the possibility of closing the Events channel and then trying to send to it - sub.Relay.subscriptionChannelCloseQueue <- sub } // Sub sets sub.Filters and then calls sub.Fire(ctx).