mirror of
https://github.com/nbd-wtf/go-nostr.git
synced 2025-07-01 11:02:47 +02:00
get rid of mutexes and use a single loop to prevent races.
in the meantime change the API to makes a little less error-prone.
This commit is contained in:
356
relay.go
356
relay.go
@ -3,11 +3,14 @@ package nostr
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
"github.com/puzpuzpuz/xsync"
|
||||
)
|
||||
|
||||
@ -41,41 +44,104 @@ type Relay struct {
|
||||
Connection *Connection
|
||||
Subscriptions *xsync.MapOf[string, *Subscription]
|
||||
|
||||
Challenges chan string // NIP-42 Challenges
|
||||
Notices chan string
|
||||
ConnectionError error
|
||||
connectionContext context.Context // will be canceled when the connection closes
|
||||
connectionContextCancel context.CancelFunc
|
||||
|
||||
okCallbacks *xsync.MapOf[string, func(bool, string)]
|
||||
mutex sync.RWMutex
|
||||
challenges chan string // NIP-42 challenges
|
||||
notices chan string // NIP-01 NOTICEs
|
||||
okCallbacks *xsync.MapOf[string, func(bool, string)]
|
||||
writeQueue chan writeRequest
|
||||
subscriptionChannelCloseQueue chan *Subscription
|
||||
|
||||
// custom things that aren't often used
|
||||
//
|
||||
AssumeValid bool // this will skip verifying signatures for events received from this relay
|
||||
}
|
||||
|
||||
type writeRequest struct {
|
||||
msg []byte
|
||||
answer chan error
|
||||
}
|
||||
|
||||
// NewRelay returns a new relay. The relay connection will be closed when the context is canceled.
|
||||
func NewRelay(ctx context.Context, url string) *Relay {
|
||||
func NewRelay(ctx context.Context, url string, opts ...RelayOption) *Relay {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Relay{
|
||||
r := &Relay{
|
||||
URL: NormalizeURL(url),
|
||||
connectionContext: ctx,
|
||||
connectionContextCancel: cancel,
|
||||
Subscriptions: xsync.NewMapOf[*Subscription](),
|
||||
okCallbacks: xsync.NewMapOf[func(bool, string)](),
|
||||
writeQueue: make(chan writeRequest),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
switch o := opt.(type) {
|
||||
case WithNoticeHandler:
|
||||
r.notices = make(chan string)
|
||||
go func() {
|
||||
for notice := range r.notices {
|
||||
o(notice)
|
||||
}
|
||||
}()
|
||||
case WithAuthHandler:
|
||||
r.challenges = make(chan string)
|
||||
go func() {
|
||||
for challenge := range r.challenges {
|
||||
authEvent := Event{
|
||||
CreatedAt: Now(),
|
||||
Kind: 22242,
|
||||
Tags: Tags{
|
||||
Tag{"relay", url},
|
||||
Tag{"challenge", challenge},
|
||||
},
|
||||
Content: "",
|
||||
}
|
||||
if ok := o(r.connectionContext, &authEvent); ok {
|
||||
r.Auth(r.connectionContext, authEvent)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// RelayConnect returns a relay object connected to url.
|
||||
// Once successfully connected, cancelling ctx has no effect.
|
||||
// To close the connection, call r.Close().
|
||||
func RelayConnect(ctx context.Context, url string) (*Relay, error) {
|
||||
r := NewRelay(context.Background(), url)
|
||||
func RelayConnect(ctx context.Context, url string, opts ...RelayOption) (*Relay, error) {
|
||||
r := NewRelay(context.Background(), url, opts...)
|
||||
err := r.Connect(ctx)
|
||||
return r, err
|
||||
}
|
||||
|
||||
// When instantiating relay connections, some options may be passed.
|
||||
// RelayOption is the type of the argument passed for that.
|
||||
// Some examples of this are WithNoticeHandler and WithAuthHandler.
|
||||
type RelayOption interface {
|
||||
IsRelayOption()
|
||||
}
|
||||
|
||||
// WithNoticeHandler just takes notices and is expected to do something with them.
|
||||
// when not given, defaults to logging the notices.
|
||||
type WithNoticeHandler func(notice string)
|
||||
|
||||
func (_ WithNoticeHandler) IsRelayOption() {}
|
||||
|
||||
var _ RelayOption = (WithNoticeHandler)(nil)
|
||||
|
||||
// WithAuthHandler takes an auth event and expects it to be signed.
|
||||
// when not given, AUTH messages from relays are ignored.
|
||||
type WithAuthHandler func(ctx context.Context, authEvent *Event) (ok bool)
|
||||
|
||||
func (_ WithAuthHandler) IsRelayOption() {}
|
||||
|
||||
var _ RelayOption = (WithAuthHandler)(nil)
|
||||
|
||||
// String() just prints the relay URL.
|
||||
func (r *Relay) String() string {
|
||||
return r.URL
|
||||
}
|
||||
@ -83,6 +149,9 @@ func (r *Relay) String() string {
|
||||
// Context retrieves the context that is associated with this relay connection.
|
||||
func (r *Relay) Context() context.Context { return r.connectionContext }
|
||||
|
||||
// IsConnected returns true if the connection to this relay seems to be active.
|
||||
func (r *Relay) IsConnected() bool { return r.connectionContext.Err() == nil }
|
||||
|
||||
// Connect tries to establish a websocket connection to r.URL.
|
||||
// If the context expires before the connection is complete, an error is returned.
|
||||
// Once successfully connected, context expiration has no effect: call r.Close
|
||||
@ -113,36 +182,57 @@ func (r *Relay) Connect(ctx context.Context) error {
|
||||
}
|
||||
r.Connection = conn
|
||||
|
||||
r.Challenges = make(chan string)
|
||||
r.Notices = make(chan string)
|
||||
|
||||
// close these channels when the connection is dropped
|
||||
go func() {
|
||||
<-r.connectionContext.Done()
|
||||
r.mutex.Lock()
|
||||
close(r.Challenges)
|
||||
close(r.Notices)
|
||||
r.mutex.Unlock()
|
||||
}()
|
||||
|
||||
// ping every 29 seconds
|
||||
ticker := time.NewTicker(29 * time.Second)
|
||||
|
||||
// queue all messages received from the relay on this
|
||||
messageHandler := make(chan Envelope)
|
||||
|
||||
// we'll queue all relay actions (handling received messages etc) in a single queue
|
||||
// such that we can close channels safely without mutex spaghetti
|
||||
go func() {
|
||||
ticker := time.NewTicker(29 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-r.connectionContext.Done():
|
||||
// close these things when the connection is closed
|
||||
if r.challenges != nil {
|
||||
close(r.challenges)
|
||||
}
|
||||
if r.notices != nil {
|
||||
close(r.notices)
|
||||
}
|
||||
// stop the ticker
|
||||
ticker.Stop()
|
||||
// close all subscriptions
|
||||
r.Subscriptions.Range(func(_ string, sub *Subscription) bool {
|
||||
sub.Unsub()
|
||||
return true
|
||||
})
|
||||
return
|
||||
case <-ticker.C:
|
||||
err := conn.Ping()
|
||||
err := wsutil.WriteClientMessage(r.Connection.conn, ws.OpPing, nil)
|
||||
if err != nil {
|
||||
InfoLogger.Printf("{%s} error writing ping: %v; closing websocket", r.URL, err)
|
||||
r.Close()
|
||||
r.Close() // this should trigger a context cancelation
|
||||
return
|
||||
}
|
||||
case envelope := <-messageHandler:
|
||||
// this will run synchronously in this goroutine
|
||||
r.HandleRelayMessage(envelope)
|
||||
case writeRequest := <-r.writeQueue:
|
||||
// all write requests will go through this to prevent races
|
||||
if err := r.Connection.WriteMessage(writeRequest.msg); err != nil {
|
||||
writeRequest.answer <- err
|
||||
}
|
||||
close(writeRequest.answer)
|
||||
case toClose := <-r.subscriptionChannelCloseQueue:
|
||||
// every time a subscription ends we use this queue to close its Events channel
|
||||
close(toClose.Events)
|
||||
toClose.Events = make(chan *Event)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// handling received messages
|
||||
go func() {
|
||||
for {
|
||||
message, err := conn.ReadMessage(r.connectionContext)
|
||||
@ -151,91 +241,88 @@ func (r *Relay) Connect(ctx context.Context) error {
|
||||
break
|
||||
}
|
||||
|
||||
debugLog("{%s} %v\n", r.URL, message)
|
||||
|
||||
envelope := ParseMessage(message)
|
||||
if envelope == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch env := envelope.(type) {
|
||||
case *NoticeEnvelope:
|
||||
debugLog("{%s} %v\n", r.URL, message)
|
||||
// TODO: improve this, otherwise if the application doesn't read the notices
|
||||
// we'll consume ever more memory with each new notice
|
||||
go func() {
|
||||
r.mutex.RLock()
|
||||
if r.connectionContext.Err() == nil {
|
||||
r.Notices <- string(*env)
|
||||
}
|
||||
r.mutex.RUnlock()
|
||||
}()
|
||||
case *AuthEnvelope:
|
||||
debugLog("{%s} %v\n", r.URL, message)
|
||||
if env.Challenge == nil {
|
||||
continue
|
||||
}
|
||||
// TODO: same as with NoticeEnvelope
|
||||
go func() {
|
||||
r.mutex.RLock()
|
||||
if r.connectionContext.Err() == nil {
|
||||
r.Challenges <- *env.Challenge
|
||||
}
|
||||
r.mutex.RUnlock()
|
||||
}()
|
||||
case *EventEnvelope:
|
||||
debugLog("{%s} %v\n", r.URL, message)
|
||||
if env.SubscriptionID == nil {
|
||||
continue
|
||||
}
|
||||
if subscription, ok := r.Subscriptions.Load(*env.SubscriptionID); !ok {
|
||||
// InfoLogger.Printf("{%s} no subscription with id '%s'\n", r.URL, *env.SubscriptionID)
|
||||
continue
|
||||
} else {
|
||||
func() {
|
||||
// check if the event matches the desired filter, ignore otherwise
|
||||
if !subscription.Filters.Match(&env.Event) {
|
||||
InfoLogger.Printf("{%s} filter does not match: %v ~ %v\n", r.URL, subscription.Filters, env.Event)
|
||||
return
|
||||
}
|
||||
|
||||
subscription.mutex.Lock()
|
||||
defer subscription.mutex.Unlock()
|
||||
if subscription.stopped {
|
||||
return
|
||||
}
|
||||
|
||||
// check signature, ignore invalid, except from trusted (AssumeValid) relays
|
||||
if !r.AssumeValid {
|
||||
if ok, err := env.Event.CheckSignature(); !ok {
|
||||
errmsg := ""
|
||||
if err != nil {
|
||||
errmsg = err.Error()
|
||||
}
|
||||
InfoLogger.Printf("{%s} bad signature: %s\n", r.URL, errmsg)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
subscription.Events <- &env.Event
|
||||
}()
|
||||
}
|
||||
case *EOSEEnvelope:
|
||||
debugLog("{%s} %v\n", r.URL, message)
|
||||
if subscription, ok := r.Subscriptions.Load(string(*env)); ok {
|
||||
subscription.emitEose.Do(func() {
|
||||
subscription.EndOfStoredEvents <- struct{}{}
|
||||
})
|
||||
}
|
||||
case *OKEnvelope:
|
||||
if okCallback, exist := r.okCallbacks.Load(env.EventID); exist {
|
||||
okCallback(env.OK, *env.Reason)
|
||||
}
|
||||
}
|
||||
messageHandler <- envelope
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleRelayMessage handles a message received from a relay.
|
||||
func (r *Relay) HandleRelayMessage(envelope Envelope) {
|
||||
switch env := envelope.(type) {
|
||||
case *NoticeEnvelope:
|
||||
// see WithNoticeHandler
|
||||
if r.notices != nil {
|
||||
r.notices <- string(*env)
|
||||
} else {
|
||||
log.Printf("NOTICE from %s: '%s'\n", r.URL, string(*env))
|
||||
}
|
||||
case *AuthEnvelope:
|
||||
if env.Challenge == nil {
|
||||
return
|
||||
}
|
||||
// see WithAuthHandler
|
||||
if r.challenges != nil {
|
||||
r.challenges <- *env.Challenge
|
||||
}
|
||||
case *EventEnvelope:
|
||||
if env.SubscriptionID == nil {
|
||||
return
|
||||
}
|
||||
if subscription, ok := r.Subscriptions.Load(*env.SubscriptionID); !ok {
|
||||
// InfoLogger.Printf("{%s} no subscription with id '%s'\n", r.URL, *env.SubscriptionID)
|
||||
return
|
||||
} else {
|
||||
// check if the event matches the desired filter, ignore otherwise
|
||||
if !subscription.Filters.Match(&env.Event) {
|
||||
InfoLogger.Printf("{%s} filter does not match: %v ~ %v\n", r.URL, subscription.Filters, env.Event)
|
||||
return
|
||||
}
|
||||
|
||||
// check signature, ignore invalid, except from trusted (AssumeValid) relays
|
||||
if !r.AssumeValid {
|
||||
if ok, err := env.Event.CheckSignature(); !ok {
|
||||
errmsg := ""
|
||||
if err != nil {
|
||||
errmsg = err.Error()
|
||||
}
|
||||
InfoLogger.Printf("{%s} bad signature: %s\n", r.URL, errmsg)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if subscription.live {
|
||||
subscription.Events <- &env.Event
|
||||
}
|
||||
}
|
||||
case *EOSEEnvelope:
|
||||
if subscription, ok := r.Subscriptions.Load(string(*env)); ok {
|
||||
subscription.emitEose.Do(func() {
|
||||
close(subscription.EndOfStoredEvents)
|
||||
})
|
||||
}
|
||||
case *OKEnvelope:
|
||||
if okCallback, exist := r.okCallbacks.Load(env.EventID); exist {
|
||||
okCallback(env.OK, *env.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write queues a message to be sent to the relay.
|
||||
func (r *Relay) Write(msg []byte) error {
|
||||
ch := make(chan error)
|
||||
r.writeQueue <- writeRequest{msg: msg, answer: ch}
|
||||
return <-ch
|
||||
}
|
||||
|
||||
// Publish sends an "EVENT" command to the relay r as in NIP-01.
|
||||
// Status can be: success, failed, or sent (no response from relay before ctx times out).
|
||||
func (r *Relay) Publish(ctx context.Context, event Event) (Status, error) {
|
||||
@ -276,7 +363,7 @@ func (r *Relay) Publish(ctx context.Context, event Event) (Status, error) {
|
||||
envb, _ := EventEnvelope{Event: event}.MarshalJSON()
|
||||
debugLog("{%s} sending %v\n", r.URL, envb)
|
||||
status = PublishStatusSent
|
||||
if err := r.Connection.WriteMessage(envb); err != nil {
|
||||
if err := r.Write(envb); err != nil {
|
||||
status = PublishStatusFailed
|
||||
return status, err
|
||||
}
|
||||
@ -335,7 +422,7 @@ func (r *Relay) Auth(ctx context.Context, event Event) (Status, error) {
|
||||
// send AUTH
|
||||
authResponse, _ := AuthEnvelope{Event: event}.MarshalJSON()
|
||||
debugLog("{%s} sending %v\n", r.URL, authResponse)
|
||||
if err := r.Connection.WriteMessage(authResponse); err != nil {
|
||||
if err := r.Write(authResponse); err != nil {
|
||||
// status will be "failed"
|
||||
return status, err
|
||||
}
|
||||
@ -356,13 +443,8 @@ func (r *Relay) Auth(ctx context.Context, event Event) (Status, error) {
|
||||
// Subscribe sends a "REQ" command to the relay r as in NIP-01.
|
||||
// Events are returned through the channel sub.Events.
|
||||
// The subscription is closed when context ctx is cancelled ("CLOSE" in NIP-01).
|
||||
func (r *Relay) Subscribe(ctx context.Context, filters Filters) (*Subscription, error) {
|
||||
if r.Connection == nil {
|
||||
panic(fmt.Errorf("must call .Connect() first before calling .Subscribe()"))
|
||||
}
|
||||
|
||||
sub := r.PrepareSubscription(ctx)
|
||||
sub.Filters = filters
|
||||
func (r *Relay) Subscribe(ctx context.Context, filters Filters, opts ...SubscriptionOption) (*Subscription, error) {
|
||||
sub := r.PrepareSubscription(ctx, filters, opts...)
|
||||
|
||||
if err := sub.Fire(); err != nil {
|
||||
return nil, fmt.Errorf("couldn't subscribe to %v at %s: %w", filters, r.URL, err)
|
||||
@ -371,8 +453,47 @@ func (r *Relay) Subscribe(ctx context.Context, filters Filters) (*Subscription,
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
func (r *Relay) QuerySync(ctx context.Context, filter Filter) ([]*Event, error) {
|
||||
sub, err := r.Subscribe(ctx, Filters{filter})
|
||||
// PrepareSubscription creates a subscription, but doesn't fire it.
|
||||
func (r *Relay) PrepareSubscription(ctx context.Context, filters Filters, opts ...SubscriptionOption) *Subscription {
|
||||
if r.Connection == nil {
|
||||
panic(fmt.Errorf("must call .Connect() first before calling .Subscribe()"))
|
||||
}
|
||||
|
||||
current := subscriptionIdCounter.Add(1)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
sub := &Subscription{
|
||||
Relay: r,
|
||||
Context: ctx,
|
||||
cancel: cancel,
|
||||
conn: r.Connection,
|
||||
counter: int(current),
|
||||
Events: make(chan *Event),
|
||||
EndOfStoredEvents: make(chan struct{}),
|
||||
Filters: filters,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
switch o := opt.(type) {
|
||||
case WithLabel:
|
||||
sub.label = string(o)
|
||||
}
|
||||
}
|
||||
|
||||
id := sub.GetID()
|
||||
r.Subscriptions.Store(id, sub)
|
||||
|
||||
// the subscription ends once the context is canceled
|
||||
go func() {
|
||||
<-sub.Context.Done()
|
||||
sub.Unsub()
|
||||
}()
|
||||
|
||||
return sub
|
||||
}
|
||||
|
||||
func (r *Relay) QuerySync(ctx context.Context, filter Filter, opts ...SubscriptionOption) ([]*Event, error) {
|
||||
sub, err := r.Subscribe(ctx, Filters{filter}, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -403,21 +524,6 @@ func (r *Relay) QuerySync(ctx context.Context, filter Filter) ([]*Event, error)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Relay) PrepareSubscription(ctx context.Context) *Subscription {
|
||||
current := subscriptionIdCounter.Add(1)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
return &Subscription{
|
||||
Relay: r,
|
||||
Context: ctx,
|
||||
cancel: cancel,
|
||||
conn: r.Connection,
|
||||
counter: int(current),
|
||||
Events: make(chan *Event),
|
||||
EndOfStoredEvents: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Relay) Close() error {
|
||||
if r.connectionContextCancel == nil {
|
||||
return fmt.Errorf("relay not connected")
|
||||
|
Reference in New Issue
Block a user