diff --git a/relaypool/publishstatus.go b/relaypool/publishstatus.go new file mode 100644 index 0000000..fe66436 --- /dev/null +++ b/relaypool/publishstatus.go @@ -0,0 +1,12 @@ +package relaypool + +const ( + PublishStatusSent = 0 + PublishStatusFailed = -1 + PublishStatusSucceeded = 1 +) + +type PublishStatus struct { + Relay string + Status int +} diff --git a/relaypool/relaypool.go b/relaypool/relaypool.go index 0f1afe8..a850afd 100644 --- a/relaypool/relaypool.go +++ b/relaypool/relaypool.go @@ -1,12 +1,16 @@ package relaypool import ( + "crypto/rand" + "encoding/hex" "encoding/json" "errors" "fmt" "log" + "time" "github.com/fiatjaf/go-nostr/event" + "github.com/fiatjaf/go-nostr/filter" nostrutils "github.com/fiatjaf/go-nostr/utils" "github.com/gorilla/websocket" ) @@ -14,14 +18,11 @@ import ( type RelayPool struct { SecretKey *string - Relays map[string]Policy - websockets map[string]*websocket.Conn + Relays map[string]Policy + websockets map[string]*websocket.Conn + subscriptions map[string]*Subscription - Events chan *EventMessage Notices chan *NoticeMessage - - SubscribedKeys []string - SubscribedEvents []string } type Policy struct { @@ -34,31 +35,6 @@ type SimplePolicy struct { Write bool } -type EventMessage struct { - Event event.Event - Context byte - Relay string -} - -func (em *EventMessage) UnmarshalJSON(b []byte) error { - var temp []json.RawMessage - if err := json.Unmarshal(b, &temp); err != nil { - return err - } - if len(temp) < 2 { - return errors.New("message is not an array of 2 or more") - } - if err := json.Unmarshal(temp[0], &em.Event); err != nil { - return err - } - var context string - if err := json.Unmarshal(temp[1], &context); err != nil { - return err - } - em.Context = context[0] - return nil -} - type NoticeMessage struct { Message string Relay string @@ -92,7 +68,6 @@ func New() *RelayPool { Relays: make(map[string]Policy), websockets: make(map[string]*websocket.Conn), - Events: make(chan *EventMessage), Notices: make(chan *NoticeMessage), } } @@ -109,7 +84,7 @@ func (r *RelayPool) Add(url string, policy *Policy) error { return fmt.Errorf("invalid relay URL '%s'", url) } - conn, _, err := websocket.DefaultDialer.Dial(nostrutils.WebsocketURL(url), nil) + conn, _, err := websocket.DefaultDialer.Dial(nostrutils.NormalizeURL(url), nil) if err != nil { return fmt.Errorf("error opening websocket to '%s': %w", nm, err) } @@ -117,6 +92,10 @@ func (r *RelayPool) Add(url string, policy *Policy) error { r.Relays[nm] = *policy r.websockets[nm] = conn + for _, sub := range r.subscriptions { + sub.addRelay(nm, conn) + } + go func() { for { typ, message, err := conn.ReadMessage() @@ -132,21 +111,49 @@ func (r *RelayPool) Add(url string, policy *Policy) error { continue } - var noticeMessage NoticeMessage - var eventMessage EventMessage - err = json.Unmarshal(message, &eventMessage) - if err == nil { - eventMessage.Relay = nm - r.Events <- &eventMessage - } else { - err = json.Unmarshal(message, ¬iceMessage) - if err == nil { - noticeMessage.Relay = nm - r.Notices <- ¬iceMessage - } else { + var jsonMessage []json.RawMessage + err = json.Unmarshal(message, &jsonMessage) + if err != nil { + continue + } + + if len(jsonMessage) < 2 { + continue + } + + var label string + json.Unmarshal(jsonMessage[0], &label) + + switch label { + case "NOTICE": + var content string + json.Unmarshal(jsonMessage[1], &content) + r.Notices <- &NoticeMessage{ + Relay: nm, + Message: content, + } + case "EVENT": + if len(jsonMessage) < 3 { continue } + + var channel string + json.Unmarshal(jsonMessage[1], &channel) + if subscription, ok := r.subscriptions[channel]; ok { + var event event.Event + json.Unmarshal(jsonMessage[2], &event) + ok, _ := event.CheckSignature() + if !ok { + continue + } + + subscription.Events <- EventMessage{ + Relay: nm, + Event: event, + } + } } + } }() @@ -156,101 +163,68 @@ func (r *RelayPool) Add(url string, policy *Policy) error { // Remove removes a relay from the pool. func (r *RelayPool) Remove(url string) { nm := nostrutils.NormalizeURL(url) + + for _, sub := range r.subscriptions { + sub.removeRelay(nm) + } if conn, ok := r.websockets[nm]; ok { conn.Close() } + delete(r.Relays, nm) delete(r.websockets, nm) } -func (r *RelayPool) SubKey(key string) { - for _, conn := range r.websockets { - conn.WriteMessage(websocket.TextMessage, []byte("sub-key:"+key)) - } -} +func (r *RelayPool) Sub(filter filter.EventFilter) *Subscription { + random := make([]byte, 7) + rand.Read(random) -func (r *RelayPool) UnsubKey(key string) { - for _, conn := range r.websockets { - conn.WriteMessage(websocket.TextMessage, []byte("unsub-key:"+key)) - } -} - -func (r *RelayPool) SubEvent(id string) { - for _, conn := range r.websockets { - conn.WriteMessage(websocket.TextMessage, []byte("sub-event:"+id)) - } -} - -func (r *RelayPool) ReqFeed(opts map[string]interface{}) { - var sopts string - if opts == nil { - sopts = "{}" - } else { - jopts, _ := json.Marshal(opts) - sopts = string(jopts) - } - - for r, conn := range r.websockets { - err := conn.WriteMessage(websocket.TextMessage, []byte("req-feed:"+sopts)) - if err != nil { - log.Printf("error sending req-feed to '%s': %s", r, err.Error()) + subscription := Subscription{} + subscription.channel = hex.EncodeToString(random) + subscription.relays = make(map[string]*websocket.Conn) + for relay, policy := range r.Relays { + if policy.Read { + ws := r.websockets[relay] + subscription.relays[relay] = ws } } + subscription.Events = make(chan EventMessage) + r.subscriptions[subscription.channel] = &subscription + + subscription.Sub(&filter) + return &subscription } -func (r *RelayPool) ReqEvent(id string, opts map[string]interface{}) { - if opts == nil { - opts = make(map[string]interface{}) - } - opts["id"] = id +func (r *RelayPool) PublishEvent(evt *event.Event) (*event.Event, chan PublishStatus, error) { + status := make(chan PublishStatus) - jopts, _ := json.Marshal(opts) - sopts := string(jopts) - - for r, conn := range r.websockets { - err := conn.WriteMessage(websocket.TextMessage, []byte("req-event:"+sopts)) - if err != nil { - log.Printf("error sending req-event to '%s': %s", r, err.Error()) - } - } -} - -func (r *RelayPool) ReqKey(key string, opts map[string]interface{}) { - if opts == nil { - opts = make(map[string]interface{}) - } - opts["key"] = key - - jopts, _ := json.Marshal(opts) - sopts := string(jopts) - - for r, conn := range r.websockets { - err := conn.WriteMessage(websocket.TextMessage, []byte("req-key:"+sopts)) - if err != nil { - log.Printf("error sending req-key to '%s': %s", r, err.Error()) - } - } -} - -func (r *RelayPool) PublishEvent(evt *event.Event) (*event.Event, error) { if r.SecretKey == nil && evt.Sig == "" { - return nil, errors.New("PublishEvent needs either a signed event to publish or to have been configured with a .SecretKey.") + return nil, status, errors.New("PublishEvent needs either a signed event to publish or to have been configured with a .SecretKey.") } if evt.Sig == "" { err := evt.Sign(*r.SecretKey) if err != nil { - return nil, fmt.Errorf("Error signing event: %w", err) + return nil, status, fmt.Errorf("Error signing event: %w", err) } } jevt, _ := json.Marshal(evt) - for r, conn := range r.websockets { - err := conn.WriteMessage(websocket.TextMessage, jevt) - if err != nil { - log.Printf("error sending event to '%s': %s", r, err.Error()) - } + for relay, conn := range r.websockets { + go func(relay string, conn *websocket.Conn) { + err := conn.WriteJSON([]interface{}{"EVENT", jevt}) + if err != nil { + log.Printf("error sending event to '%s': %s", relay, err.Error()) + status <- PublishStatus{relay, PublishStatusFailed} + } + status <- PublishStatus{relay, PublishStatusSent} + + subscription := r.Sub(filter.EventFilter{ID: evt.ID}) + + time.Sleep(5 * time.Second) + subscription.Unsub() + }(relay, conn) } - return evt, nil + return evt, status, nil } diff --git a/relaypool/subscription.go b/relaypool/subscription.go new file mode 100644 index 0000000..aa06487 --- /dev/null +++ b/relaypool/subscription.go @@ -0,0 +1,62 @@ +package relaypool + +import ( + "github.com/fiatjaf/go-nostr/event" + "github.com/fiatjaf/go-nostr/filter" + "github.com/gorilla/websocket" +) + +type Subscription struct { + channel string + relays map[string]*websocket.Conn + + filter *filter.EventFilter + Events chan EventMessage +} + +func (subscription Subscription) Unsub() { + for _, ws := range subscription.relays { + ws.WriteJSON([]interface{}{ + "CLOSE", + subscription.channel, + }) + } +} + +func (subscription Subscription) Sub(filter *filter.EventFilter) { + if filter != nil { + subscription.filter = filter + } + + for _, ws := range subscription.relays { + ws.WriteJSON([]interface{}{ + "REQ", + subscription.channel, + subscription.filter, + }) + } +} + +func (subscription Subscription) removeRelay(relay string) { + if ws, ok := subscription.relays[relay]; ok { + delete(subscription.relays, relay) + ws.WriteJSON([]interface{}{ + "CLOSE", + subscription.channel, + }) + } +} + +func (subscription Subscription) addRelay(relay string, ws *websocket.Conn) { + subscription.relays[relay] = ws + ws.WriteJSON([]interface{}{ + "REQ", + subscription.channel, + subscription.filter, + }) +} + +type EventMessage struct { + Event event.Event + Relay string +} diff --git a/utils/utils.go b/utils/utils.go index cace5a3..47f63d3 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -7,26 +7,22 @@ import ( func NormalizeURL(u string) string { if !strings.HasPrefix(u, "http") && !strings.HasPrefix(u, "ws") { - u = "ws://" + u + u = "wss://" + u } p, err := url.Parse(u) if err != nil { return "" } + if p.Scheme == "http" { + p.Scheme = "ws" + } else if p.Scheme == "https" { + p.Scheme = "wss" + } + if strings.HasSuffix(p.RawPath, "/") { p.RawPath = p.RawPath[0 : len(p.RawPath)-1] } - if strings.HasSuffix(p.RawPath, "/ws") { - p.RawPath = p.RawPath[0 : len(p.RawPath)-3] - } - - return p.String() -} - -func WebsocketURL(u string) string { - p, _ := url.Parse(NormalizeURL(u)) - p.RawPath += "/ws" return p.String() }