diff --git a/connection.go b/connection.go index a7a0bc0..445a43f 100644 --- a/connection.go +++ b/connection.go @@ -9,7 +9,6 @@ import ( "io" "net" "net/http" - "sync" "github.com/gobwas/httphead" "github.com/gobwas/ws" @@ -26,7 +25,6 @@ type Connection struct { flateWriter *wsflate.Writer writer *wsutil.Writer msgState *wsflate.MessageState - mutex sync.Mutex } func NewConnection(ctx context.Context, url string, requestHeader http.Header) (*Connection, error) { @@ -100,17 +98,7 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) ( }, nil } -func (c *Connection) Ping() error { - c.mutex.Lock() - defer c.mutex.Unlock() - - return wsutil.WriteClientMessage(c.conn, ws.OpPing, nil) -} - func (c *Connection) WriteMessage(data []byte) error { - c.mutex.Lock() - defer c.mutex.Unlock() - if c.msgState.IsCompressed() && c.enableCompression { c.flateWriter.Reset(c.writer) if _, err := io.Copy(c.flateWriter, bytes.NewReader(data)); err != nil { diff --git a/pool.go b/pool.go index b26c868..972e823 100644 --- a/pool.go +++ b/pool.go @@ -35,7 +35,7 @@ func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) { defer pool.mutex.Unlock() relay, ok := pool.Relays[nm] - if ok && relay.connectionContext.Err() == nil { + if ok && relay.IsConnected() { // already connected, unlock and return return relay, nil } else { diff --git a/relay.go b/relay.go index d0df654..448418e 100644 --- a/relay.go +++ b/relay.go @@ -3,11 +3,14 @@ package nostr import ( "context" "fmt" + "log" "net/http" "sync" "sync/atomic" "time" + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" "github.com/puzpuzpuz/xsync" ) @@ -41,41 +44,104 @@ type Relay struct { Connection *Connection Subscriptions *xsync.MapOf[string, *Subscription] - Challenges chan string // NIP-42 Challenges - Notices chan string ConnectionError error connectionContext context.Context // will be canceled when the connection closes connectionContextCancel context.CancelFunc - okCallbacks *xsync.MapOf[string, func(bool, string)] - mutex sync.RWMutex + challenges chan string // NIP-42 challenges + notices chan string // NIP-01 NOTICEs + okCallbacks *xsync.MapOf[string, func(bool, string)] + writeQueue chan writeRequest + subscriptionChannelCloseQueue chan *Subscription // custom things that aren't often used // AssumeValid bool // this will skip verifying signatures for events received from this relay } +type writeRequest struct { + msg []byte + answer chan error +} + // NewRelay returns a new relay. The relay connection will be closed when the context is canceled. -func NewRelay(ctx context.Context, url string) *Relay { +func NewRelay(ctx context.Context, url string, opts ...RelayOption) *Relay { ctx, cancel := context.WithCancel(ctx) - return &Relay{ + r := &Relay{ URL: NormalizeURL(url), connectionContext: ctx, connectionContextCancel: cancel, Subscriptions: xsync.NewMapOf[*Subscription](), okCallbacks: xsync.NewMapOf[func(bool, string)](), + writeQueue: make(chan writeRequest), } + + for _, opt := range opts { + switch o := opt.(type) { + case WithNoticeHandler: + r.notices = make(chan string) + go func() { + for notice := range r.notices { + o(notice) + } + }() + case WithAuthHandler: + r.challenges = make(chan string) + go func() { + for challenge := range r.challenges { + authEvent := Event{ + CreatedAt: Now(), + Kind: 22242, + Tags: Tags{ + Tag{"relay", url}, + Tag{"challenge", challenge}, + }, + Content: "", + } + if ok := o(r.connectionContext, &authEvent); ok { + r.Auth(r.connectionContext, authEvent) + } + } + }() + } + } + + return r } // RelayConnect returns a relay object connected to url. // Once successfully connected, cancelling ctx has no effect. // To close the connection, call r.Close(). -func RelayConnect(ctx context.Context, url string) (*Relay, error) { - r := NewRelay(context.Background(), url) +func RelayConnect(ctx context.Context, url string, opts ...RelayOption) (*Relay, error) { + r := NewRelay(context.Background(), url, opts...) err := r.Connect(ctx) return r, err } +// When instantiating relay connections, some options may be passed. +// RelayOption is the type of the argument passed for that. +// Some examples of this are WithNoticeHandler and WithAuthHandler. +type RelayOption interface { + IsRelayOption() +} + +// WithNoticeHandler just takes notices and is expected to do something with them. +// when not given, defaults to logging the notices. +type WithNoticeHandler func(notice string) + +func (_ WithNoticeHandler) IsRelayOption() {} + +var _ RelayOption = (WithNoticeHandler)(nil) + +// WithAuthHandler takes an auth event and expects it to be signed. +// when not given, AUTH messages from relays are ignored. +type WithAuthHandler func(ctx context.Context, authEvent *Event) (ok bool) + +func (_ WithAuthHandler) IsRelayOption() {} + +var _ RelayOption = (WithAuthHandler)(nil) + +// String() just prints the relay URL. func (r *Relay) String() string { return r.URL } @@ -83,6 +149,9 @@ func (r *Relay) String() string { // Context retrieves the context that is associated with this relay connection. func (r *Relay) Context() context.Context { return r.connectionContext } +// IsConnected returns true if the connection to this relay seems to be active. +func (r *Relay) IsConnected() bool { return r.connectionContext.Err() == nil } + // Connect tries to establish a websocket connection to r.URL. // If the context expires before the connection is complete, an error is returned. // Once successfully connected, context expiration has no effect: call r.Close @@ -113,36 +182,57 @@ func (r *Relay) Connect(ctx context.Context) error { } r.Connection = conn - r.Challenges = make(chan string) - r.Notices = make(chan string) - - // close these channels when the connection is dropped - go func() { - <-r.connectionContext.Done() - r.mutex.Lock() - close(r.Challenges) - close(r.Notices) - r.mutex.Unlock() - }() - // ping every 29 seconds + ticker := time.NewTicker(29 * time.Second) + + // queue all messages received from the relay on this + messageHandler := make(chan Envelope) + + // we'll queue all relay actions (handling received messages etc) in a single queue + // such that we can close channels safely without mutex spaghetti go func() { - ticker := time.NewTicker(29 * time.Second) - defer ticker.Stop() for { select { + case <-r.connectionContext.Done(): + // close these things when the connection is closed + if r.challenges != nil { + close(r.challenges) + } + if r.notices != nil { + close(r.notices) + } + // stop the ticker + ticker.Stop() + // close all subscriptions + r.Subscriptions.Range(func(_ string, sub *Subscription) bool { + sub.Unsub() + return true + }) + return case <-ticker.C: - err := conn.Ping() + err := wsutil.WriteClientMessage(r.Connection.conn, ws.OpPing, nil) if err != nil { InfoLogger.Printf("{%s} error writing ping: %v; closing websocket", r.URL, err) - r.Close() + r.Close() // this should trigger a context cancelation return } + case envelope := <-messageHandler: + // this will run synchronously in this goroutine + r.HandleRelayMessage(envelope) + case writeRequest := <-r.writeQueue: + // all write requests will go through this to prevent races + if err := r.Connection.WriteMessage(writeRequest.msg); err != nil { + writeRequest.answer <- err + } + close(writeRequest.answer) + case toClose := <-r.subscriptionChannelCloseQueue: + // every time a subscription ends we use this queue to close its Events channel + close(toClose.Events) + toClose.Events = make(chan *Event) } } }() - // handling received messages go func() { for { message, err := conn.ReadMessage(r.connectionContext) @@ -151,91 +241,88 @@ func (r *Relay) Connect(ctx context.Context) error { break } + debugLog("{%s} %v\n", r.URL, message) + envelope := ParseMessage(message) if envelope == nil { continue } - switch env := envelope.(type) { - case *NoticeEnvelope: - debugLog("{%s} %v\n", r.URL, message) - // TODO: improve this, otherwise if the application doesn't read the notices - // we'll consume ever more memory with each new notice - go func() { - r.mutex.RLock() - if r.connectionContext.Err() == nil { - r.Notices <- string(*env) - } - r.mutex.RUnlock() - }() - case *AuthEnvelope: - debugLog("{%s} %v\n", r.URL, message) - if env.Challenge == nil { - continue - } - // TODO: same as with NoticeEnvelope - go func() { - r.mutex.RLock() - if r.connectionContext.Err() == nil { - r.Challenges <- *env.Challenge - } - r.mutex.RUnlock() - }() - case *EventEnvelope: - debugLog("{%s} %v\n", r.URL, message) - if env.SubscriptionID == nil { - continue - } - if subscription, ok := r.Subscriptions.Load(*env.SubscriptionID); !ok { - // InfoLogger.Printf("{%s} no subscription with id '%s'\n", r.URL, *env.SubscriptionID) - continue - } else { - func() { - // check if the event matches the desired filter, ignore otherwise - if !subscription.Filters.Match(&env.Event) { - InfoLogger.Printf("{%s} filter does not match: %v ~ %v\n", r.URL, subscription.Filters, env.Event) - return - } - - subscription.mutex.Lock() - defer subscription.mutex.Unlock() - if subscription.stopped { - return - } - - // check signature, ignore invalid, except from trusted (AssumeValid) relays - if !r.AssumeValid { - if ok, err := env.Event.CheckSignature(); !ok { - errmsg := "" - if err != nil { - errmsg = err.Error() - } - InfoLogger.Printf("{%s} bad signature: %s\n", r.URL, errmsg) - return - } - } - - subscription.Events <- &env.Event - }() - } - case *EOSEEnvelope: - debugLog("{%s} %v\n", r.URL, message) - if subscription, ok := r.Subscriptions.Load(string(*env)); ok { - subscription.emitEose.Do(func() { - subscription.EndOfStoredEvents <- struct{}{} - }) - } - case *OKEnvelope: - if okCallback, exist := r.okCallbacks.Load(env.EventID); exist { - okCallback(env.OK, *env.Reason) - } - } + messageHandler <- envelope } }() return nil } +// HandleRelayMessage handles a message received from a relay. +func (r *Relay) HandleRelayMessage(envelope Envelope) { + switch env := envelope.(type) { + case *NoticeEnvelope: + // see WithNoticeHandler + if r.notices != nil { + r.notices <- string(*env) + } else { + log.Printf("NOTICE from %s: '%s'\n", r.URL, string(*env)) + } + case *AuthEnvelope: + if env.Challenge == nil { + return + } + // see WithAuthHandler + if r.challenges != nil { + r.challenges <- *env.Challenge + } + case *EventEnvelope: + if env.SubscriptionID == nil { + return + } + if subscription, ok := r.Subscriptions.Load(*env.SubscriptionID); !ok { + // InfoLogger.Printf("{%s} no subscription with id '%s'\n", r.URL, *env.SubscriptionID) + return + } else { + // check if the event matches the desired filter, ignore otherwise + if !subscription.Filters.Match(&env.Event) { + InfoLogger.Printf("{%s} filter does not match: %v ~ %v\n", r.URL, subscription.Filters, env.Event) + return + } + + // check signature, ignore invalid, except from trusted (AssumeValid) relays + if !r.AssumeValid { + if ok, err := env.Event.CheckSignature(); !ok { + errmsg := "" + if err != nil { + errmsg = err.Error() + } + InfoLogger.Printf("{%s} bad signature: %s\n", r.URL, errmsg) + return + } + } + + if subscription.live { + subscription.Events <- &env.Event + } + } + case *EOSEEnvelope: + if subscription, ok := r.Subscriptions.Load(string(*env)); ok { + subscription.emitEose.Do(func() { + close(subscription.EndOfStoredEvents) + }) + } + case *OKEnvelope: + if okCallback, exist := r.okCallbacks.Load(env.EventID); exist { + okCallback(env.OK, *env.Reason) + } + } +} + +// Write queues a message to be sent to the relay. +func (r *Relay) Write(msg []byte) error { + ch := make(chan error) + r.writeQueue <- writeRequest{msg: msg, answer: ch} + return <-ch +} + // Publish sends an "EVENT" command to the relay r as in NIP-01. // Status can be: success, failed, or sent (no response from relay before ctx times out). func (r *Relay) Publish(ctx context.Context, event Event) (Status, error) { @@ -276,7 +363,7 @@ func (r *Relay) Publish(ctx context.Context, event Event) (Status, error) { envb, _ := EventEnvelope{Event: event}.MarshalJSON() debugLog("{%s} sending %v\n", r.URL, envb) status = PublishStatusSent - if err := r.Connection.WriteMessage(envb); err != nil { + if err := r.Write(envb); err != nil { status = PublishStatusFailed return status, err } @@ -335,7 +422,7 @@ func (r *Relay) Auth(ctx context.Context, event Event) (Status, error) { // send AUTH authResponse, _ := AuthEnvelope{Event: event}.MarshalJSON() debugLog("{%s} sending %v\n", r.URL, authResponse) - if err := r.Connection.WriteMessage(authResponse); err != nil { + if err := r.Write(authResponse); err != nil { // status will be "failed" return status, err } @@ -356,13 +443,8 @@ func (r *Relay) Auth(ctx context.Context, event Event) (Status, error) { // Subscribe sends a "REQ" command to the relay r as in NIP-01. // Events are returned through the channel sub.Events. // The subscription is closed when context ctx is cancelled ("CLOSE" in NIP-01). -func (r *Relay) Subscribe(ctx context.Context, filters Filters) (*Subscription, error) { - if r.Connection == nil { - panic(fmt.Errorf("must call .Connect() first before calling .Subscribe()")) - } - - sub := r.PrepareSubscription(ctx) - sub.Filters = filters +func (r *Relay) Subscribe(ctx context.Context, filters Filters, opts ...SubscriptionOption) (*Subscription, error) { + sub := r.PrepareSubscription(ctx, filters, opts...) if err := sub.Fire(); err != nil { return nil, fmt.Errorf("couldn't subscribe to %v at %s: %w", filters, r.URL, err) @@ -371,8 +453,47 @@ func (r *Relay) Subscribe(ctx context.Context, filters Filters) (*Subscription, return sub, nil } -func (r *Relay) QuerySync(ctx context.Context, filter Filter) ([]*Event, error) { - sub, err := r.Subscribe(ctx, Filters{filter}) +// PrepareSubscription creates a subscription, but doesn't fire it. +func (r *Relay) PrepareSubscription(ctx context.Context, filters Filters, opts ...SubscriptionOption) *Subscription { + if r.Connection == nil { + panic(fmt.Errorf("must call .Connect() first before calling .Subscribe()")) + } + + current := subscriptionIdCounter.Add(1) + ctx, cancel := context.WithCancel(ctx) + + sub := &Subscription{ + Relay: r, + Context: ctx, + cancel: cancel, + conn: r.Connection, + counter: int(current), + Events: make(chan *Event), + EndOfStoredEvents: make(chan struct{}), + Filters: filters, + } + + for _, opt := range opts { + switch o := opt.(type) { + case WithLabel: + sub.label = string(o) + } + } + + id := sub.GetID() + r.Subscriptions.Store(id, sub) + + // the subscription ends once the context is canceled + go func() { + <-sub.Context.Done() + sub.Unsub() + }() + + return sub +} + +func (r *Relay) QuerySync(ctx context.Context, filter Filter, opts ...SubscriptionOption) ([]*Event, error) { + sub, err := r.Subscribe(ctx, Filters{filter}, opts...) if err != nil { return nil, err } @@ -403,21 +524,6 @@ func (r *Relay) QuerySync(ctx context.Context, filter Filter) ([]*Event, error) } } -func (r *Relay) PrepareSubscription(ctx context.Context) *Subscription { - current := subscriptionIdCounter.Add(1) - ctx, cancel := context.WithCancel(ctx) - - return &Subscription{ - Relay: r, - Context: ctx, - cancel: cancel, - conn: r.Connection, - counter: int(current), - Events: make(chan *Event), - EndOfStoredEvents: make(chan struct{}, 1), - } -} - func (r *Relay) Close() error { if r.connectionContextCancel == nil { return fmt.Errorf("relay not connected") diff --git a/subscription.go b/subscription.go index 7c78d19..8a4bb74 100644 --- a/subscription.go +++ b/subscription.go @@ -11,16 +11,22 @@ type Subscription struct { label string counter int conn *Connection - mutex sync.Mutex - Relay *Relay - Filters Filters - Events chan *Event + Relay *Relay + Filters Filters + + // the Events channel emits all EVENTs that come in a Subscription + // will be closed when the subscription ends + Events chan *Event + + // the EndOfStoredEvents channel gets closed when an EOSE comes for that subscription EndOfStoredEvents chan struct{} - Context context.Context - cancel context.CancelFunc - stopped bool + // Context will be .Done() when the subscription ends + Context context.Context + + live bool + cancel context.CancelFunc emitEose sync.Once } @@ -29,35 +35,43 @@ type EventMessage struct { Relay string } -// SetLabel puts a label on the subscription that is prepended to the id that is sent to relays, -// -// it's only useful for debugging and sanity purposes. -func (sub *Subscription) SetLabel(label string) { - sub.label = label +// When instantiating relay connections, some options may be passed. +// SubscriptionOption is the type of the argument passed for that. +// Some examples are WithLabel. +type SubscriptionOption interface { + IsSubscriptionOption() } -// GetID return the Nostr subscription ID as given to the relay, it will be a sequential number, stringified. +// WithLabel puts a label on the subscription (it is prepended to the automatic id) that is sent to relays. +type WithLabel string + +func (_ WithLabel) IsSubscriptionOption() {} + +var _ SubscriptionOption = (WithLabel)("") + +// GetID return the Nostr subscription ID as given to the Relay +// it is a concatenation of the label and a serial number. func (sub *Subscription) GetID() string { return sub.label + ":" + strconv.Itoa(sub.counter) } // Unsub closes the subscription, sending "CLOSE" to relay as in NIP-01. -// Unsub() also closes the channel sub.Events. +// Unsub() also closes the channel sub.Events and makes a new one. func (sub *Subscription) Unsub() { - sub.mutex.Lock() - defer sub.mutex.Unlock() - id := sub.GetID() - closeMsg := CloseEnvelope(id) - closeb, _ := (&closeMsg).MarshalJSON() - debugLog("{%s} sending %v", sub.Relay.URL, closeb) - sub.conn.WriteMessage(closeb) + + if sub.Relay.IsConnected() { + closeMsg := CloseEnvelope(id) + closeb, _ := (&closeMsg).MarshalJSON() + debugLog("{%s} sending %v", sub.Relay.URL, closeb) + sub.Relay.Write(closeb) + } + + sub.live = false sub.Relay.Subscriptions.Delete(id) - if !sub.stopped && sub.Events != nil { - close(sub.Events) - } - sub.stopped = true + // do this so we don't have the possibility of closing the Events channel and then trying to send to it + sub.Relay.subscriptionChannelCloseQueue <- sub } // Sub sets sub.Filters and then calls sub.Fire(ctx). @@ -70,28 +84,15 @@ func (sub *Subscription) Sub(ctx context.Context, filters Filters) { // Fire sends the "REQ" command to the relay. func (sub *Subscription) Fire() error { id := sub.GetID() - sub.Relay.Subscriptions.Store(id, sub) reqb, _ := ReqEnvelope{id, sub.Filters}.MarshalJSON() debugLog("{%s} sending %v", sub.Relay.URL, reqb) - if err := sub.conn.WriteMessage(reqb); err != nil { + + sub.live = true + if err := sub.Relay.Write(reqb); err != nil { sub.cancel() return fmt.Errorf("failed to write: %w", err) } - // the subscription ends once the context is canceled - go func() { - <-sub.Context.Done() - sub.Unsub() - }() - - // or when the relay connection is closed - go func() { - <-sub.Relay.connectionContext.Done() - - // cancel the context -- this will cause the other context cancelation cause above to be called - sub.cancel() - }() - return nil }