get rid of mutexes and use a single loop to prevent races.

in the meantime change the API to makes a little less error-prone.
This commit is contained in:
fiatjaf
2023-06-21 19:55:40 -03:00
parent af4fc98fc2
commit 030c1d1898
4 changed files with 274 additions and 179 deletions

View File

@ -9,7 +9,6 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"sync"
"github.com/gobwas/httphead" "github.com/gobwas/httphead"
"github.com/gobwas/ws" "github.com/gobwas/ws"
@ -26,7 +25,6 @@ type Connection struct {
flateWriter *wsflate.Writer flateWriter *wsflate.Writer
writer *wsutil.Writer writer *wsutil.Writer
msgState *wsflate.MessageState msgState *wsflate.MessageState
mutex sync.Mutex
} }
func NewConnection(ctx context.Context, url string, requestHeader http.Header) (*Connection, error) { func NewConnection(ctx context.Context, url string, requestHeader http.Header) (*Connection, error) {
@ -100,17 +98,7 @@ func NewConnection(ctx context.Context, url string, requestHeader http.Header) (
}, nil }, nil
} }
func (c *Connection) Ping() error {
c.mutex.Lock()
defer c.mutex.Unlock()
return wsutil.WriteClientMessage(c.conn, ws.OpPing, nil)
}
func (c *Connection) WriteMessage(data []byte) error { func (c *Connection) WriteMessage(data []byte) error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.msgState.IsCompressed() && c.enableCompression { if c.msgState.IsCompressed() && c.enableCompression {
c.flateWriter.Reset(c.writer) c.flateWriter.Reset(c.writer)
if _, err := io.Copy(c.flateWriter, bytes.NewReader(data)); err != nil { if _, err := io.Copy(c.flateWriter, bytes.NewReader(data)); err != nil {

View File

@ -35,7 +35,7 @@ func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) {
defer pool.mutex.Unlock() defer pool.mutex.Unlock()
relay, ok := pool.Relays[nm] relay, ok := pool.Relays[nm]
if ok && relay.connectionContext.Err() == nil { if ok && relay.IsConnected() {
// already connected, unlock and return // already connected, unlock and return
return relay, nil return relay, nil
} else { } else {

274
relay.go
View File

@ -3,11 +3,14 @@ package nostr
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"net/http" "net/http"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"github.com/puzpuzpuz/xsync" "github.com/puzpuzpuz/xsync"
) )
@ -41,41 +44,104 @@ type Relay struct {
Connection *Connection Connection *Connection
Subscriptions *xsync.MapOf[string, *Subscription] Subscriptions *xsync.MapOf[string, *Subscription]
Challenges chan string // NIP-42 Challenges
Notices chan string
ConnectionError error ConnectionError error
connectionContext context.Context // will be canceled when the connection closes connectionContext context.Context // will be canceled when the connection closes
connectionContextCancel context.CancelFunc connectionContextCancel context.CancelFunc
challenges chan string // NIP-42 challenges
notices chan string // NIP-01 NOTICEs
okCallbacks *xsync.MapOf[string, func(bool, string)] okCallbacks *xsync.MapOf[string, func(bool, string)]
mutex sync.RWMutex writeQueue chan writeRequest
subscriptionChannelCloseQueue chan *Subscription
// custom things that aren't often used // custom things that aren't often used
// //
AssumeValid bool // this will skip verifying signatures for events received from this relay AssumeValid bool // this will skip verifying signatures for events received from this relay
} }
type writeRequest struct {
msg []byte
answer chan error
}
// NewRelay returns a new relay. The relay connection will be closed when the context is canceled. // NewRelay returns a new relay. The relay connection will be closed when the context is canceled.
func NewRelay(ctx context.Context, url string) *Relay { func NewRelay(ctx context.Context, url string, opts ...RelayOption) *Relay {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
return &Relay{ r := &Relay{
URL: NormalizeURL(url), URL: NormalizeURL(url),
connectionContext: ctx, connectionContext: ctx,
connectionContextCancel: cancel, connectionContextCancel: cancel,
Subscriptions: xsync.NewMapOf[*Subscription](), Subscriptions: xsync.NewMapOf[*Subscription](),
okCallbacks: xsync.NewMapOf[func(bool, string)](), okCallbacks: xsync.NewMapOf[func(bool, string)](),
writeQueue: make(chan writeRequest),
} }
for _, opt := range opts {
switch o := opt.(type) {
case WithNoticeHandler:
r.notices = make(chan string)
go func() {
for notice := range r.notices {
o(notice)
}
}()
case WithAuthHandler:
r.challenges = make(chan string)
go func() {
for challenge := range r.challenges {
authEvent := Event{
CreatedAt: Now(),
Kind: 22242,
Tags: Tags{
Tag{"relay", url},
Tag{"challenge", challenge},
},
Content: "",
}
if ok := o(r.connectionContext, &authEvent); ok {
r.Auth(r.connectionContext, authEvent)
}
}
}()
}
}
return r
} }
// RelayConnect returns a relay object connected to url. // RelayConnect returns a relay object connected to url.
// Once successfully connected, cancelling ctx has no effect. // Once successfully connected, cancelling ctx has no effect.
// To close the connection, call r.Close(). // To close the connection, call r.Close().
func RelayConnect(ctx context.Context, url string) (*Relay, error) { func RelayConnect(ctx context.Context, url string, opts ...RelayOption) (*Relay, error) {
r := NewRelay(context.Background(), url) r := NewRelay(context.Background(), url, opts...)
err := r.Connect(ctx) err := r.Connect(ctx)
return r, err return r, err
} }
// When instantiating relay connections, some options may be passed.
// RelayOption is the type of the argument passed for that.
// Some examples of this are WithNoticeHandler and WithAuthHandler.
type RelayOption interface {
IsRelayOption()
}
// WithNoticeHandler just takes notices and is expected to do something with them.
// when not given, defaults to logging the notices.
type WithNoticeHandler func(notice string)
func (_ WithNoticeHandler) IsRelayOption() {}
var _ RelayOption = (WithNoticeHandler)(nil)
// WithAuthHandler takes an auth event and expects it to be signed.
// when not given, AUTH messages from relays are ignored.
type WithAuthHandler func(ctx context.Context, authEvent *Event) (ok bool)
func (_ WithAuthHandler) IsRelayOption() {}
var _ RelayOption = (WithAuthHandler)(nil)
// String() just prints the relay URL.
func (r *Relay) String() string { func (r *Relay) String() string {
return r.URL return r.URL
} }
@ -83,6 +149,9 @@ func (r *Relay) String() string {
// Context retrieves the context that is associated with this relay connection. // Context retrieves the context that is associated with this relay connection.
func (r *Relay) Context() context.Context { return r.connectionContext } func (r *Relay) Context() context.Context { return r.connectionContext }
// IsConnected returns true if the connection to this relay seems to be active.
func (r *Relay) IsConnected() bool { return r.connectionContext.Err() == nil }
// Connect tries to establish a websocket connection to r.URL. // Connect tries to establish a websocket connection to r.URL.
// If the context expires before the connection is complete, an error is returned. // If the context expires before the connection is complete, an error is returned.
// Once successfully connected, context expiration has no effect: call r.Close // Once successfully connected, context expiration has no effect: call r.Close
@ -113,36 +182,57 @@ func (r *Relay) Connect(ctx context.Context) error {
} }
r.Connection = conn r.Connection = conn
r.Challenges = make(chan string)
r.Notices = make(chan string)
// close these channels when the connection is dropped
go func() {
<-r.connectionContext.Done()
r.mutex.Lock()
close(r.Challenges)
close(r.Notices)
r.mutex.Unlock()
}()
// ping every 29 seconds // ping every 29 seconds
go func() {
ticker := time.NewTicker(29 * time.Second) ticker := time.NewTicker(29 * time.Second)
defer ticker.Stop()
// queue all messages received from the relay on this
messageHandler := make(chan Envelope)
// we'll queue all relay actions (handling received messages etc) in a single queue
// such that we can close channels safely without mutex spaghetti
go func() {
for { for {
select { select {
case <-r.connectionContext.Done():
// close these things when the connection is closed
if r.challenges != nil {
close(r.challenges)
}
if r.notices != nil {
close(r.notices)
}
// stop the ticker
ticker.Stop()
// close all subscriptions
r.Subscriptions.Range(func(_ string, sub *Subscription) bool {
sub.Unsub()
return true
})
return
case <-ticker.C: case <-ticker.C:
err := conn.Ping() err := wsutil.WriteClientMessage(r.Connection.conn, ws.OpPing, nil)
if err != nil { if err != nil {
InfoLogger.Printf("{%s} error writing ping: %v; closing websocket", r.URL, err) InfoLogger.Printf("{%s} error writing ping: %v; closing websocket", r.URL, err)
r.Close() r.Close() // this should trigger a context cancelation
return return
} }
case envelope := <-messageHandler:
// this will run synchronously in this goroutine
r.HandleRelayMessage(envelope)
case writeRequest := <-r.writeQueue:
// all write requests will go through this to prevent races
if err := r.Connection.WriteMessage(writeRequest.msg); err != nil {
writeRequest.answer <- err
}
close(writeRequest.answer)
case toClose := <-r.subscriptionChannelCloseQueue:
// every time a subscription ends we use this queue to close its Events channel
close(toClose.Events)
toClose.Events = make(chan *Event)
} }
} }
}() }()
// handling received messages
go func() { go func() {
for { for {
message, err := conn.ReadMessage(r.connectionContext) message, err := conn.ReadMessage(r.connectionContext)
@ -151,58 +241,52 @@ func (r *Relay) Connect(ctx context.Context) error {
break break
} }
debugLog("{%s} %v\n", r.URL, message)
envelope := ParseMessage(message) envelope := ParseMessage(message)
if envelope == nil { if envelope == nil {
continue continue
} }
messageHandler <- envelope
}
}()
return nil
}
// HandleRelayMessage handles a message received from a relay.
func (r *Relay) HandleRelayMessage(envelope Envelope) {
switch env := envelope.(type) { switch env := envelope.(type) {
case *NoticeEnvelope: case *NoticeEnvelope:
debugLog("{%s} %v\n", r.URL, message) // see WithNoticeHandler
// TODO: improve this, otherwise if the application doesn't read the notices if r.notices != nil {
// we'll consume ever more memory with each new notice r.notices <- string(*env)
go func() { } else {
r.mutex.RLock() log.Printf("NOTICE from %s: '%s'\n", r.URL, string(*env))
if r.connectionContext.Err() == nil {
r.Notices <- string(*env)
} }
r.mutex.RUnlock()
}()
case *AuthEnvelope: case *AuthEnvelope:
debugLog("{%s} %v\n", r.URL, message)
if env.Challenge == nil { if env.Challenge == nil {
continue return
} }
// TODO: same as with NoticeEnvelope // see WithAuthHandler
go func() { if r.challenges != nil {
r.mutex.RLock() r.challenges <- *env.Challenge
if r.connectionContext.Err() == nil {
r.Challenges <- *env.Challenge
} }
r.mutex.RUnlock()
}()
case *EventEnvelope: case *EventEnvelope:
debugLog("{%s} %v\n", r.URL, message)
if env.SubscriptionID == nil { if env.SubscriptionID == nil {
continue return
} }
if subscription, ok := r.Subscriptions.Load(*env.SubscriptionID); !ok { if subscription, ok := r.Subscriptions.Load(*env.SubscriptionID); !ok {
// InfoLogger.Printf("{%s} no subscription with id '%s'\n", r.URL, *env.SubscriptionID) // InfoLogger.Printf("{%s} no subscription with id '%s'\n", r.URL, *env.SubscriptionID)
continue return
} else { } else {
func() {
// check if the event matches the desired filter, ignore otherwise // check if the event matches the desired filter, ignore otherwise
if !subscription.Filters.Match(&env.Event) { if !subscription.Filters.Match(&env.Event) {
InfoLogger.Printf("{%s} filter does not match: %v ~ %v\n", r.URL, subscription.Filters, env.Event) InfoLogger.Printf("{%s} filter does not match: %v ~ %v\n", r.URL, subscription.Filters, env.Event)
return return
} }
subscription.mutex.Lock()
defer subscription.mutex.Unlock()
if subscription.stopped {
return
}
// check signature, ignore invalid, except from trusted (AssumeValid) relays // check signature, ignore invalid, except from trusted (AssumeValid) relays
if !r.AssumeValid { if !r.AssumeValid {
if ok, err := env.Event.CheckSignature(); !ok { if ok, err := env.Event.CheckSignature(); !ok {
@ -215,14 +299,14 @@ func (r *Relay) Connect(ctx context.Context) error {
} }
} }
if subscription.live {
subscription.Events <- &env.Event subscription.Events <- &env.Event
}() }
} }
case *EOSEEnvelope: case *EOSEEnvelope:
debugLog("{%s} %v\n", r.URL, message)
if subscription, ok := r.Subscriptions.Load(string(*env)); ok { if subscription, ok := r.Subscriptions.Load(string(*env)); ok {
subscription.emitEose.Do(func() { subscription.emitEose.Do(func() {
subscription.EndOfStoredEvents <- struct{}{} close(subscription.EndOfStoredEvents)
}) })
} }
case *OKEnvelope: case *OKEnvelope:
@ -230,10 +314,13 @@ func (r *Relay) Connect(ctx context.Context) error {
okCallback(env.OK, *env.Reason) okCallback(env.OK, *env.Reason)
} }
} }
} }
}()
return nil // Write queues a message to be sent to the relay.
func (r *Relay) Write(msg []byte) error {
ch := make(chan error)
r.writeQueue <- writeRequest{msg: msg, answer: ch}
return <-ch
} }
// Publish sends an "EVENT" command to the relay r as in NIP-01. // Publish sends an "EVENT" command to the relay r as in NIP-01.
@ -276,7 +363,7 @@ func (r *Relay) Publish(ctx context.Context, event Event) (Status, error) {
envb, _ := EventEnvelope{Event: event}.MarshalJSON() envb, _ := EventEnvelope{Event: event}.MarshalJSON()
debugLog("{%s} sending %v\n", r.URL, envb) debugLog("{%s} sending %v\n", r.URL, envb)
status = PublishStatusSent status = PublishStatusSent
if err := r.Connection.WriteMessage(envb); err != nil { if err := r.Write(envb); err != nil {
status = PublishStatusFailed status = PublishStatusFailed
return status, err return status, err
} }
@ -335,7 +422,7 @@ func (r *Relay) Auth(ctx context.Context, event Event) (Status, error) {
// send AUTH // send AUTH
authResponse, _ := AuthEnvelope{Event: event}.MarshalJSON() authResponse, _ := AuthEnvelope{Event: event}.MarshalJSON()
debugLog("{%s} sending %v\n", r.URL, authResponse) debugLog("{%s} sending %v\n", r.URL, authResponse)
if err := r.Connection.WriteMessage(authResponse); err != nil { if err := r.Write(authResponse); err != nil {
// status will be "failed" // status will be "failed"
return status, err return status, err
} }
@ -356,13 +443,8 @@ func (r *Relay) Auth(ctx context.Context, event Event) (Status, error) {
// Subscribe sends a "REQ" command to the relay r as in NIP-01. // Subscribe sends a "REQ" command to the relay r as in NIP-01.
// Events are returned through the channel sub.Events. // Events are returned through the channel sub.Events.
// The subscription is closed when context ctx is cancelled ("CLOSE" in NIP-01). // The subscription is closed when context ctx is cancelled ("CLOSE" in NIP-01).
func (r *Relay) Subscribe(ctx context.Context, filters Filters) (*Subscription, error) { func (r *Relay) Subscribe(ctx context.Context, filters Filters, opts ...SubscriptionOption) (*Subscription, error) {
if r.Connection == nil { sub := r.PrepareSubscription(ctx, filters, opts...)
panic(fmt.Errorf("must call .Connect() first before calling .Subscribe()"))
}
sub := r.PrepareSubscription(ctx)
sub.Filters = filters
if err := sub.Fire(); err != nil { if err := sub.Fire(); err != nil {
return nil, fmt.Errorf("couldn't subscribe to %v at %s: %w", filters, r.URL, err) return nil, fmt.Errorf("couldn't subscribe to %v at %s: %w", filters, r.URL, err)
@ -371,8 +453,47 @@ func (r *Relay) Subscribe(ctx context.Context, filters Filters) (*Subscription,
return sub, nil return sub, nil
} }
func (r *Relay) QuerySync(ctx context.Context, filter Filter) ([]*Event, error) { // PrepareSubscription creates a subscription, but doesn't fire it.
sub, err := r.Subscribe(ctx, Filters{filter}) func (r *Relay) PrepareSubscription(ctx context.Context, filters Filters, opts ...SubscriptionOption) *Subscription {
if r.Connection == nil {
panic(fmt.Errorf("must call .Connect() first before calling .Subscribe()"))
}
current := subscriptionIdCounter.Add(1)
ctx, cancel := context.WithCancel(ctx)
sub := &Subscription{
Relay: r,
Context: ctx,
cancel: cancel,
conn: r.Connection,
counter: int(current),
Events: make(chan *Event),
EndOfStoredEvents: make(chan struct{}),
Filters: filters,
}
for _, opt := range opts {
switch o := opt.(type) {
case WithLabel:
sub.label = string(o)
}
}
id := sub.GetID()
r.Subscriptions.Store(id, sub)
// the subscription ends once the context is canceled
go func() {
<-sub.Context.Done()
sub.Unsub()
}()
return sub
}
func (r *Relay) QuerySync(ctx context.Context, filter Filter, opts ...SubscriptionOption) ([]*Event, error) {
sub, err := r.Subscribe(ctx, Filters{filter}, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -403,21 +524,6 @@ func (r *Relay) QuerySync(ctx context.Context, filter Filter) ([]*Event, error)
} }
} }
func (r *Relay) PrepareSubscription(ctx context.Context) *Subscription {
current := subscriptionIdCounter.Add(1)
ctx, cancel := context.WithCancel(ctx)
return &Subscription{
Relay: r,
Context: ctx,
cancel: cancel,
conn: r.Connection,
counter: int(current),
Events: make(chan *Event),
EndOfStoredEvents: make(chan struct{}, 1),
}
}
func (r *Relay) Close() error { func (r *Relay) Close() error {
if r.connectionContextCancel == nil { if r.connectionContextCancel == nil {
return fmt.Errorf("relay not connected") return fmt.Errorf("relay not connected")

View File

@ -11,16 +11,22 @@ type Subscription struct {
label string label string
counter int counter int
conn *Connection conn *Connection
mutex sync.Mutex
Relay *Relay Relay *Relay
Filters Filters Filters Filters
Events chan *Event
EndOfStoredEvents chan struct{}
Context context.Context
cancel context.CancelFunc
stopped bool // the Events channel emits all EVENTs that come in a Subscription
// will be closed when the subscription ends
Events chan *Event
// the EndOfStoredEvents channel gets closed when an EOSE comes for that subscription
EndOfStoredEvents chan struct{}
// Context will be .Done() when the subscription ends
Context context.Context
live bool
cancel context.CancelFunc
emitEose sync.Once emitEose sync.Once
} }
@ -29,35 +35,43 @@ type EventMessage struct {
Relay string Relay string
} }
// SetLabel puts a label on the subscription that is prepended to the id that is sent to relays, // When instantiating relay connections, some options may be passed.
// // SubscriptionOption is the type of the argument passed for that.
// it's only useful for debugging and sanity purposes. // Some examples are WithLabel.
func (sub *Subscription) SetLabel(label string) { type SubscriptionOption interface {
sub.label = label IsSubscriptionOption()
} }
// GetID return the Nostr subscription ID as given to the relay, it will be a sequential number, stringified. // WithLabel puts a label on the subscription (it is prepended to the automatic id) that is sent to relays.
type WithLabel string
func (_ WithLabel) IsSubscriptionOption() {}
var _ SubscriptionOption = (WithLabel)("")
// GetID return the Nostr subscription ID as given to the Relay
// it is a concatenation of the label and a serial number.
func (sub *Subscription) GetID() string { func (sub *Subscription) GetID() string {
return sub.label + ":" + strconv.Itoa(sub.counter) return sub.label + ":" + strconv.Itoa(sub.counter)
} }
// Unsub closes the subscription, sending "CLOSE" to relay as in NIP-01. // Unsub closes the subscription, sending "CLOSE" to relay as in NIP-01.
// Unsub() also closes the channel sub.Events. // Unsub() also closes the channel sub.Events and makes a new one.
func (sub *Subscription) Unsub() { func (sub *Subscription) Unsub() {
sub.mutex.Lock()
defer sub.mutex.Unlock()
id := sub.GetID() id := sub.GetID()
if sub.Relay.IsConnected() {
closeMsg := CloseEnvelope(id) closeMsg := CloseEnvelope(id)
closeb, _ := (&closeMsg).MarshalJSON() closeb, _ := (&closeMsg).MarshalJSON()
debugLog("{%s} sending %v", sub.Relay.URL, closeb) debugLog("{%s} sending %v", sub.Relay.URL, closeb)
sub.conn.WriteMessage(closeb) sub.Relay.Write(closeb)
}
sub.live = false
sub.Relay.Subscriptions.Delete(id) sub.Relay.Subscriptions.Delete(id)
if !sub.stopped && sub.Events != nil { // do this so we don't have the possibility of closing the Events channel and then trying to send to it
close(sub.Events) sub.Relay.subscriptionChannelCloseQueue <- sub
}
sub.stopped = true
} }
// Sub sets sub.Filters and then calls sub.Fire(ctx). // Sub sets sub.Filters and then calls sub.Fire(ctx).
@ -70,28 +84,15 @@ func (sub *Subscription) Sub(ctx context.Context, filters Filters) {
// Fire sends the "REQ" command to the relay. // Fire sends the "REQ" command to the relay.
func (sub *Subscription) Fire() error { func (sub *Subscription) Fire() error {
id := sub.GetID() id := sub.GetID()
sub.Relay.Subscriptions.Store(id, sub)
reqb, _ := ReqEnvelope{id, sub.Filters}.MarshalJSON() reqb, _ := ReqEnvelope{id, sub.Filters}.MarshalJSON()
debugLog("{%s} sending %v", sub.Relay.URL, reqb) debugLog("{%s} sending %v", sub.Relay.URL, reqb)
if err := sub.conn.WriteMessage(reqb); err != nil {
sub.live = true
if err := sub.Relay.Write(reqb); err != nil {
sub.cancel() sub.cancel()
return fmt.Errorf("failed to write: %w", err) return fmt.Errorf("failed to write: %w", err)
} }
// the subscription ends once the context is canceled
go func() {
<-sub.Context.Done()
sub.Unsub()
}()
// or when the relay connection is closed
go func() {
<-sub.Relay.connectionContext.Done()
// cancel the context -- this will cause the other context cancelation cause above to be called
sub.cancel()
}()
return nil return nil
} }