diff --git a/relay.go b/relay.go new file mode 100644 index 0000000..361950d --- /dev/null +++ b/relay.go @@ -0,0 +1,184 @@ +package nostr + +import ( + "encoding/hex" + "encoding/json" + "fmt" + "log" + "math/rand" + "time" + + s "github.com/SaveTheRbtz/generic-sync-map-go" + "github.com/gorilla/websocket" +) + +type Status int + +const ( + PublishStatusSent Status = 0 + PublishStatusFailed Status = -1 + PublishStatusSucceeded Status = 1 +) + +func (s Status) String() string { + switch s { + case PublishStatusSent: + return "sent" + case PublishStatusFailed: + return "failed" + case PublishStatusSucceeded: + return "success" + } + + return "unknown" +} + +type Relay struct { + URL string + + Connection *Connection + subscriptions s.MapOf[string, *Subscription] + + Notices chan string +} + +func NewRelay(url string) *Relay { + return &Relay{ + URL: NormalizeURL(url), + subscriptions: s.MapOf[string, *Subscription]{}, + } +} + +func (r *Relay) Connect() error { + if r.URL == "" { + return fmt.Errorf("invalid relay URL '%s'", r.URL) + } + + socket, _, err := websocket.DefaultDialer.Dial(r.URL, nil) + if err != nil { + return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err) + } + + conn := NewConnection(socket) + + for { + typ, message, err := conn.socket.ReadMessage() + if err != nil { + return fmt.Errorf("read error: %w", err) + } + if typ == websocket.PingMessage { + conn.WriteMessage(websocket.PongMessage, nil) + continue + } + + if typ != websocket.TextMessage || len(message) == 0 || message[0] != '[' { + continue + } + + var jsonMessage []json.RawMessage + err = json.Unmarshal(message, &jsonMessage) + if err != nil { + continue + } + + if len(jsonMessage) < 2 { + continue + } + + var label string + json.Unmarshal(jsonMessage[0], &label) + + switch label { + case "NOTICE": + var content string + json.Unmarshal(jsonMessage[1], &content) + r.Notices <- content + case "EVENT": + if len(jsonMessage) < 3 { + continue + } + + var channel string + json.Unmarshal(jsonMessage[1], &channel) + if subscription, ok := r.subscriptions.Load(channel); ok { + var event Event + json.Unmarshal(jsonMessage[2], &event) + + // check signature of all received events, ignore invalid + ok, err := event.CheckSignature() + if !ok { + errmsg := "" + if err != nil { + errmsg = err.Error() + } + log.Printf("bad signature: %s", errmsg) + continue + } + + // check if the event matches the desired filter, ignore otherwise + if !subscription.filters.Match(&event) { + continue + } + + if !subscription.stopped { + subscription.Events <- event + } + } + } + } +} + +func (r Relay) Publish(event Event) chan Status { + statusChan := make(chan Status) + + go func() { + err := r.Connection.WriteJSON([]interface{}{"EVENT", event}) + if err != nil { + statusChan <- PublishStatusFailed + close(statusChan) + } + statusChan <- PublishStatusSent + + sub := r.Subscribe(Filters{Filter{IDs: []string{event.ID}}}) + for { + select { + case receivedEvent := <-sub.Events: + if receivedEvent.ID == event.ID { + statusChan <- PublishStatusSucceeded + close(statusChan) + break + } else { + continue + } + case <-time.After(5 * time.Second): + close(statusChan) + break + } + break + } + }() + + return statusChan +} + +func (r *Relay) Subscribe(filters Filters) *Subscription { + random := make([]byte, 7) + rand.Read(random) + id := hex.EncodeToString(random) + return r.subscribe(id, filters) +} + +func (r *Relay) subscribe(id string, filters Filters) *Subscription { + sub := Subscription{} + sub.id = id + + sub.Events = make(chan Event) + r.subscriptions.Store(sub.id, &sub) + + sub.Sub(filters) + return &sub +} + +func (r *Relay) Close() error { + return r.Connection.Close() +} diff --git a/relaypool.go b/relaypool.go index d248bac..99c3ec9 100644 --- a/relaypool.go +++ b/relaypool.go @@ -3,37 +3,12 @@ package nostr import ( "crypto/rand" "encoding/hex" - "encoding/json" "errors" "fmt" - "log" - "time" s "github.com/SaveTheRbtz/generic-sync-map-go" - "github.com/gorilla/websocket" ) -type Status int - -const ( - PublishStatusSent Status = 0 - PublishStatusFailed Status = -1 - PublishStatusSucceeded Status = 1 -) - -func (s Status) String() string { - switch s { - case PublishStatusSent: - return "sent" - case PublishStatusFailed: - return "failed" - case PublishStatusSucceeded: - return "success" - } - - return "unknown" -} - type PublishStatus struct { Relay string Status Status @@ -42,9 +17,10 @@ type PublishStatus struct { type RelayPool struct { SecretKey *string - Relays s.MapOf[string, RelayPoolPolicy] - websockets s.MapOf[string, *Connection] - subscriptions s.MapOf[string, *Subscription] + Policies s.MapOf[string, RelayPoolPolicy] + Relays s.MapOf[string, *Relay] + subscriptions s.MapOf[string, Filters] + eventStreams s.MapOf[string, chan EventMessage] Notices chan *NoticeMessage } @@ -75,9 +51,8 @@ type NoticeMessage struct { // New creates a new RelayPool with no relays in it func NewRelayPool() *RelayPool { return &RelayPool{ - Relays: s.MapOf[string, RelayPoolPolicy]{}, - websockets: s.MapOf[string, *Connection]{}, - subscriptions: s.MapOf[string, *Subscription]{}, + Policies: s.MapOf[string, RelayPoolPolicy]{}, + Relays: s.MapOf[string, *Relay]{}, Notices: make(chan *NoticeMessage), } @@ -90,101 +65,23 @@ func (r *RelayPool) Add(url string, policy RelayPoolPolicy) error { policy = SimplePolicy{Read: true, Write: true} } - nm := NormalizeURL(url) - if nm == "" { - return fmt.Errorf("invalid relay URL '%s'", url) - } + relay := NewRelay(url) + r.Policies.Store(relay.URL, policy) + r.Relays.Store(relay.URL, relay) - socket, _, err := websocket.DefaultDialer.Dial(NormalizeURL(url), nil) - if err != nil { - return fmt.Errorf("error opening websocket to '%s': %w", nm, err) - } + r.subscriptions.Range(func(id string, filters Filters) bool { + sub := relay.subscribe(id, filters) + eventStream, _ := r.eventStreams.Load(id) - conn := NewConnection(socket) + go func(sub *Subscription) { + for evt := range sub.Events { + eventStream <- EventMessage{Relay: relay.URL, Event: evt} + } + }(sub) - r.Relays.Store(nm, policy) - r.websockets.Store(nm, conn) - - r.subscriptions.Range(func(_ string, sub *Subscription) bool { - sub.addRelay(nm, conn) return true }) - go func() { - for { - typ, message, err := conn.socket.ReadMessage() - if err != nil { - log.Println("read error: ", err) - return - } - if typ == websocket.PingMessage { - conn.WriteMessage(websocket.PongMessage, nil) - continue - } - - if typ != websocket.TextMessage || len(message) == 0 || message[0] != '[' { - continue - } - - var jsonMessage []json.RawMessage - err = json.Unmarshal(message, &jsonMessage) - if err != nil { - continue - } - - if len(jsonMessage) < 2 { - continue - } - - var label string - json.Unmarshal(jsonMessage[0], &label) - - switch label { - case "NOTICE": - var content string - json.Unmarshal(jsonMessage[1], &content) - r.Notices <- &NoticeMessage{ - Relay: nm, - Message: content, - } - case "EVENT": - if len(jsonMessage) < 3 { - continue - } - - var channel string - json.Unmarshal(jsonMessage[1], &channel) - if subscription, ok := r.subscriptions.Load(channel); ok { - var event Event - json.Unmarshal(jsonMessage[2], &event) - - // check signature of all received events, ignore invalid - ok, err := event.CheckSignature() - if !ok { - errmsg := "" - if err != nil { - errmsg = err.Error() - } - log.Printf("bad signature: %s", errmsg) - continue - } - - // check if the event matches the desired filter, ignore otherwise - if !subscription.filters.Match(&event) { - continue - } - - if !subscription.stopped { - subscription.Events <- EventMessage{ - Relay: nm, - Event: event, - } - } - } - } - } - }() - return nil } @@ -192,41 +89,36 @@ func (r *RelayPool) Add(url string, policy RelayPoolPolicy) error { func (r *RelayPool) Remove(url string) { nm := NormalizeURL(url) - r.subscriptions.Range(func(_ string, sub *Subscription) bool { - sub.removeRelay(nm) - return true - }) - - if conn, ok := r.websockets.Load(nm); ok { - conn.Close() - } - r.Relays.Delete(nm) - r.websockets.Delete(nm) + r.Policies.Delete(nm) + + if relay, ok := r.Relays.Load(nm); ok { + relay.Close() + } } -func (r *RelayPool) Sub(filters Filters) *Subscription { +func (r *RelayPool) Sub(filters Filters) (string, chan EventMessage) { random := make([]byte, 7) rand.Read(random) + id := hex.EncodeToString(random) - subscription := Subscription{} - subscription.channel = hex.EncodeToString(random) - subscription.relays = s.MapOf[string, *Connection]{} + r.subscriptions.Store(id, filters) + eventStream := make(chan EventMessage) + r.eventStreams.Store(id, eventStream) - r.Relays.Range(func(relay string, policy RelayPoolPolicy) bool { - if policy.ShouldRead(filters) { - if ws, ok := r.websockets.Load(relay); ok { - subscription.relays.Store(relay, ws) + r.Relays.Range(func(_ string, relay *Relay) bool { + sub := relay.subscribe(id, filters) + + go func(sub *Subscription) { + for evt := range sub.Events { + eventStream <- EventMessage{Relay: relay.URL, Event: evt} } - } + }(sub) + return true }) - subscription.Events = make(chan EventMessage) - subscription.UniqueEvents = make(chan Event) - r.subscriptions.Store(subscription.channel, &subscription) - subscription.Sub(filters) - return &subscription + return id, eventStream } func (r *RelayPool) PublishEvent(evt *Event) (*Event, chan PublishStatus, error) { @@ -251,35 +143,16 @@ func (r *RelayPool) PublishEvent(evt *Event) (*Event, chan PublishStatus, error) } } - r.websockets.Range(func(relay string, conn *Connection) bool { - if r, ok := r.Relays.Load(relay); !ok || !r.ShouldWrite(evt) { + r.Relays.Range(func(url string, relay *Relay) bool { + if r, ok := r.Policies.Load(url); !ok || !r.ShouldWrite(evt) { return true } - go func(relay string, conn *Connection) { - err := conn.WriteJSON([]interface{}{"EVENT", evt}) - if err != nil { - log.Printf("error sending event to '%s': %s", relay, err.Error()) - status <- PublishStatus{relay, PublishStatusFailed} + go func(relay *Relay) { + for resultStatus := range relay.Publish(*evt) { + status <- PublishStatus{relay.URL, resultStatus} } - status <- PublishStatus{relay, PublishStatusSent} - - subscription := r.Sub(Filters{Filter{IDs: []string{evt.ID}}}) - for { - select { - case event := <-subscription.UniqueEvents: - if event.ID == evt.ID { - status <- PublishStatus{relay, PublishStatusSucceeded} - break - } else { - continue - } - case <-time.After(5 * time.Second): - break - } - break - } - }(relay, conn) + }(relay) return true }) diff --git a/subscription.go b/subscription.go index 07000c8..bd0eee5 100644 --- a/subscription.go +++ b/subscription.go @@ -1,18 +1,11 @@ package nostr -import ( - s "github.com/SaveTheRbtz/generic-sync-map-go" -) - type Subscription struct { - channel string - relays s.MapOf[string, *Connection] + id string + conn *Connection filters Filters - Events chan EventMessage - - started bool - UniqueEvents chan Event + Events chan Event stopped bool } @@ -22,78 +15,22 @@ type EventMessage struct { Relay string } -func (subscription Subscription) Unsub() { - subscription.relays.Range(func(_ string, conn *Connection) bool { - conn.WriteJSON([]interface{}{ - "CLOSE", - subscription.channel, - }) - return true - }) +func (sub Subscription) Unsub() { + sub.conn.WriteJSON([]interface{}{"CLOSE", sub.id}) - subscription.stopped = true - if subscription.Events != nil { - close(subscription.Events) - } - if subscription.UniqueEvents != nil { - close(subscription.UniqueEvents) + sub.stopped = true + if sub.Events != nil { + close(sub.Events) } } -func (subscription *Subscription) Sub(filters Filters) { - subscription.filters = filters +func (sub *Subscription) Sub(filters Filters) { + sub.filters = filters - subscription.relays.Range(func(_ string, conn *Connection) bool { - message := []interface{}{ - "REQ", - subscription.channel, - } - for _, filter := range subscription.filters { - message = append(message, filter) - } - - conn.WriteJSON(message) - return true - }) - - if !subscription.started { - go subscription.startHandlingUnique() - } -} - -func (subscription Subscription) startHandlingUnique() { - seen := make(map[string]struct{}) - for em := range subscription.Events { - if _, ok := seen[em.Event.ID]; ok { - continue - } - seen[em.Event.ID] = struct{}{} - if !subscription.stopped { - subscription.UniqueEvents <- em.Event - } - } -} - -func (subscription Subscription) removeRelay(relay string) { - if conn, ok := subscription.relays.Load(relay); ok { - subscription.relays.Delete(relay) - conn.WriteJSON([]interface{}{ - "CLOSE", - subscription.channel, - }) - } -} - -func (subscription Subscription) addRelay(relay string, conn *Connection) { - subscription.relays.Store(relay, conn) - - message := []interface{}{ - "REQ", - subscription.channel, - } - for _, filter := range subscription.filters { + message := []interface{}{"REQ", sub.id} + for _, filter := range sub.filters { message = append(message, filter) } - conn.WriteJSON(message) + sub.conn.WriteJSON(message) }