mirror of
https://github.com/nbd-wtf/go-nostr.git
synced 2025-09-18 11:32:25 +02:00
change relaypool and subscription such that a Relay can have an independent existence.
This commit is contained in:
184
relay.go
Normal file
184
relay.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package nostr
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
s "github.com/SaveTheRbtz/generic-sync-map-go"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type Status int
|
||||
|
||||
const (
|
||||
PublishStatusSent Status = 0
|
||||
PublishStatusFailed Status = -1
|
||||
PublishStatusSucceeded Status = 1
|
||||
)
|
||||
|
||||
func (s Status) String() string {
|
||||
switch s {
|
||||
case PublishStatusSent:
|
||||
return "sent"
|
||||
case PublishStatusFailed:
|
||||
return "failed"
|
||||
case PublishStatusSucceeded:
|
||||
return "success"
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
type Relay struct {
|
||||
URL string
|
||||
|
||||
Connection *Connection
|
||||
subscriptions s.MapOf[string, *Subscription]
|
||||
|
||||
Notices chan string
|
||||
}
|
||||
|
||||
func NewRelay(url string) *Relay {
|
||||
return &Relay{
|
||||
URL: NormalizeURL(url),
|
||||
subscriptions: s.MapOf[string, *Subscription]{},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Relay) Connect() error {
|
||||
if r.URL == "" {
|
||||
return fmt.Errorf("invalid relay URL '%s'", r.URL)
|
||||
}
|
||||
|
||||
socket, _, err := websocket.DefaultDialer.Dial(r.URL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err)
|
||||
}
|
||||
|
||||
conn := NewConnection(socket)
|
||||
|
||||
for {
|
||||
typ, message, err := conn.socket.ReadMessage()
|
||||
if err != nil {
|
||||
return fmt.Errorf("read error: %w", err)
|
||||
}
|
||||
if typ == websocket.PingMessage {
|
||||
conn.WriteMessage(websocket.PongMessage, nil)
|
||||
continue
|
||||
}
|
||||
|
||||
if typ != websocket.TextMessage || len(message) == 0 || message[0] != '[' {
|
||||
continue
|
||||
}
|
||||
|
||||
var jsonMessage []json.RawMessage
|
||||
err = json.Unmarshal(message, &jsonMessage)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(jsonMessage) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
var label string
|
||||
json.Unmarshal(jsonMessage[0], &label)
|
||||
|
||||
switch label {
|
||||
case "NOTICE":
|
||||
var content string
|
||||
json.Unmarshal(jsonMessage[1], &content)
|
||||
r.Notices <- content
|
||||
case "EVENT":
|
||||
if len(jsonMessage) < 3 {
|
||||
continue
|
||||
}
|
||||
|
||||
var channel string
|
||||
json.Unmarshal(jsonMessage[1], &channel)
|
||||
if subscription, ok := r.subscriptions.Load(channel); ok {
|
||||
var event Event
|
||||
json.Unmarshal(jsonMessage[2], &event)
|
||||
|
||||
// check signature of all received events, ignore invalid
|
||||
ok, err := event.CheckSignature()
|
||||
if !ok {
|
||||
errmsg := ""
|
||||
if err != nil {
|
||||
errmsg = err.Error()
|
||||
}
|
||||
log.Printf("bad signature: %s", errmsg)
|
||||
continue
|
||||
}
|
||||
|
||||
// check if the event matches the desired filter, ignore otherwise
|
||||
if !subscription.filters.Match(&event) {
|
||||
continue
|
||||
}
|
||||
|
||||
if !subscription.stopped {
|
||||
subscription.Events <- event
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r Relay) Publish(event Event) chan Status {
|
||||
statusChan := make(chan Status)
|
||||
|
||||
go func() {
|
||||
err := r.Connection.WriteJSON([]interface{}{"EVENT", event})
|
||||
if err != nil {
|
||||
statusChan <- PublishStatusFailed
|
||||
close(statusChan)
|
||||
}
|
||||
statusChan <- PublishStatusSent
|
||||
|
||||
sub := r.Subscribe(Filters{Filter{IDs: []string{event.ID}}})
|
||||
for {
|
||||
select {
|
||||
case receivedEvent := <-sub.Events:
|
||||
if receivedEvent.ID == event.ID {
|
||||
statusChan <- PublishStatusSucceeded
|
||||
close(statusChan)
|
||||
break
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
close(statusChan)
|
||||
break
|
||||
}
|
||||
break
|
||||
}
|
||||
}()
|
||||
|
||||
return statusChan
|
||||
}
|
||||
|
||||
func (r *Relay) Subscribe(filters Filters) *Subscription {
|
||||
random := make([]byte, 7)
|
||||
rand.Read(random)
|
||||
id := hex.EncodeToString(random)
|
||||
return r.subscribe(id, filters)
|
||||
}
|
||||
|
||||
func (r *Relay) subscribe(id string, filters Filters) *Subscription {
|
||||
sub := Subscription{}
|
||||
sub.id = id
|
||||
|
||||
sub.Events = make(chan Event)
|
||||
r.subscriptions.Store(sub.id, &sub)
|
||||
|
||||
sub.Sub(filters)
|
||||
return &sub
|
||||
}
|
||||
|
||||
func (r *Relay) Close() error {
|
||||
return r.Connection.Close()
|
||||
}
|
211
relaypool.go
211
relaypool.go
@@ -3,37 +3,12 @@ package nostr
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
s "github.com/SaveTheRbtz/generic-sync-map-go"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type Status int
|
||||
|
||||
const (
|
||||
PublishStatusSent Status = 0
|
||||
PublishStatusFailed Status = -1
|
||||
PublishStatusSucceeded Status = 1
|
||||
)
|
||||
|
||||
func (s Status) String() string {
|
||||
switch s {
|
||||
case PublishStatusSent:
|
||||
return "sent"
|
||||
case PublishStatusFailed:
|
||||
return "failed"
|
||||
case PublishStatusSucceeded:
|
||||
return "success"
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
type PublishStatus struct {
|
||||
Relay string
|
||||
Status Status
|
||||
@@ -42,9 +17,10 @@ type PublishStatus struct {
|
||||
type RelayPool struct {
|
||||
SecretKey *string
|
||||
|
||||
Relays s.MapOf[string, RelayPoolPolicy]
|
||||
websockets s.MapOf[string, *Connection]
|
||||
subscriptions s.MapOf[string, *Subscription]
|
||||
Policies s.MapOf[string, RelayPoolPolicy]
|
||||
Relays s.MapOf[string, *Relay]
|
||||
subscriptions s.MapOf[string, Filters]
|
||||
eventStreams s.MapOf[string, chan EventMessage]
|
||||
|
||||
Notices chan *NoticeMessage
|
||||
}
|
||||
@@ -75,9 +51,8 @@ type NoticeMessage struct {
|
||||
// New creates a new RelayPool with no relays in it
|
||||
func NewRelayPool() *RelayPool {
|
||||
return &RelayPool{
|
||||
Relays: s.MapOf[string, RelayPoolPolicy]{},
|
||||
websockets: s.MapOf[string, *Connection]{},
|
||||
subscriptions: s.MapOf[string, *Subscription]{},
|
||||
Policies: s.MapOf[string, RelayPoolPolicy]{},
|
||||
Relays: s.MapOf[string, *Relay]{},
|
||||
|
||||
Notices: make(chan *NoticeMessage),
|
||||
}
|
||||
@@ -90,101 +65,23 @@ func (r *RelayPool) Add(url string, policy RelayPoolPolicy) error {
|
||||
policy = SimplePolicy{Read: true, Write: true}
|
||||
}
|
||||
|
||||
nm := NormalizeURL(url)
|
||||
if nm == "" {
|
||||
return fmt.Errorf("invalid relay URL '%s'", url)
|
||||
}
|
||||
relay := NewRelay(url)
|
||||
r.Policies.Store(relay.URL, policy)
|
||||
r.Relays.Store(relay.URL, relay)
|
||||
|
||||
socket, _, err := websocket.DefaultDialer.Dial(NormalizeURL(url), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening websocket to '%s': %w", nm, err)
|
||||
}
|
||||
r.subscriptions.Range(func(id string, filters Filters) bool {
|
||||
sub := relay.subscribe(id, filters)
|
||||
eventStream, _ := r.eventStreams.Load(id)
|
||||
|
||||
conn := NewConnection(socket)
|
||||
go func(sub *Subscription) {
|
||||
for evt := range sub.Events {
|
||||
eventStream <- EventMessage{Relay: relay.URL, Event: evt}
|
||||
}
|
||||
}(sub)
|
||||
|
||||
r.Relays.Store(nm, policy)
|
||||
r.websockets.Store(nm, conn)
|
||||
|
||||
r.subscriptions.Range(func(_ string, sub *Subscription) bool {
|
||||
sub.addRelay(nm, conn)
|
||||
return true
|
||||
})
|
||||
|
||||
go func() {
|
||||
for {
|
||||
typ, message, err := conn.socket.ReadMessage()
|
||||
if err != nil {
|
||||
log.Println("read error: ", err)
|
||||
return
|
||||
}
|
||||
if typ == websocket.PingMessage {
|
||||
conn.WriteMessage(websocket.PongMessage, nil)
|
||||
continue
|
||||
}
|
||||
|
||||
if typ != websocket.TextMessage || len(message) == 0 || message[0] != '[' {
|
||||
continue
|
||||
}
|
||||
|
||||
var jsonMessage []json.RawMessage
|
||||
err = json.Unmarshal(message, &jsonMessage)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(jsonMessage) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
var label string
|
||||
json.Unmarshal(jsonMessage[0], &label)
|
||||
|
||||
switch label {
|
||||
case "NOTICE":
|
||||
var content string
|
||||
json.Unmarshal(jsonMessage[1], &content)
|
||||
r.Notices <- &NoticeMessage{
|
||||
Relay: nm,
|
||||
Message: content,
|
||||
}
|
||||
case "EVENT":
|
||||
if len(jsonMessage) < 3 {
|
||||
continue
|
||||
}
|
||||
|
||||
var channel string
|
||||
json.Unmarshal(jsonMessage[1], &channel)
|
||||
if subscription, ok := r.subscriptions.Load(channel); ok {
|
||||
var event Event
|
||||
json.Unmarshal(jsonMessage[2], &event)
|
||||
|
||||
// check signature of all received events, ignore invalid
|
||||
ok, err := event.CheckSignature()
|
||||
if !ok {
|
||||
errmsg := ""
|
||||
if err != nil {
|
||||
errmsg = err.Error()
|
||||
}
|
||||
log.Printf("bad signature: %s", errmsg)
|
||||
continue
|
||||
}
|
||||
|
||||
// check if the event matches the desired filter, ignore otherwise
|
||||
if !subscription.filters.Match(&event) {
|
||||
continue
|
||||
}
|
||||
|
||||
if !subscription.stopped {
|
||||
subscription.Events <- EventMessage{
|
||||
Relay: nm,
|
||||
Event: event,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -192,41 +89,36 @@ func (r *RelayPool) Add(url string, policy RelayPoolPolicy) error {
|
||||
func (r *RelayPool) Remove(url string) {
|
||||
nm := NormalizeURL(url)
|
||||
|
||||
r.subscriptions.Range(func(_ string, sub *Subscription) bool {
|
||||
sub.removeRelay(nm)
|
||||
return true
|
||||
})
|
||||
|
||||
if conn, ok := r.websockets.Load(nm); ok {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
r.Relays.Delete(nm)
|
||||
r.websockets.Delete(nm)
|
||||
r.Policies.Delete(nm)
|
||||
|
||||
if relay, ok := r.Relays.Load(nm); ok {
|
||||
relay.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RelayPool) Sub(filters Filters) *Subscription {
|
||||
func (r *RelayPool) Sub(filters Filters) (string, chan EventMessage) {
|
||||
random := make([]byte, 7)
|
||||
rand.Read(random)
|
||||
id := hex.EncodeToString(random)
|
||||
|
||||
subscription := Subscription{}
|
||||
subscription.channel = hex.EncodeToString(random)
|
||||
subscription.relays = s.MapOf[string, *Connection]{}
|
||||
r.subscriptions.Store(id, filters)
|
||||
eventStream := make(chan EventMessage)
|
||||
r.eventStreams.Store(id, eventStream)
|
||||
|
||||
r.Relays.Range(func(relay string, policy RelayPoolPolicy) bool {
|
||||
if policy.ShouldRead(filters) {
|
||||
if ws, ok := r.websockets.Load(relay); ok {
|
||||
subscription.relays.Store(relay, ws)
|
||||
r.Relays.Range(func(_ string, relay *Relay) bool {
|
||||
sub := relay.subscribe(id, filters)
|
||||
|
||||
go func(sub *Subscription) {
|
||||
for evt := range sub.Events {
|
||||
eventStream <- EventMessage{Relay: relay.URL, Event: evt}
|
||||
}
|
||||
}
|
||||
}(sub)
|
||||
|
||||
return true
|
||||
})
|
||||
subscription.Events = make(chan EventMessage)
|
||||
subscription.UniqueEvents = make(chan Event)
|
||||
r.subscriptions.Store(subscription.channel, &subscription)
|
||||
|
||||
subscription.Sub(filters)
|
||||
return &subscription
|
||||
return id, eventStream
|
||||
}
|
||||
|
||||
func (r *RelayPool) PublishEvent(evt *Event) (*Event, chan PublishStatus, error) {
|
||||
@@ -251,35 +143,16 @@ func (r *RelayPool) PublishEvent(evt *Event) (*Event, chan PublishStatus, error)
|
||||
}
|
||||
}
|
||||
|
||||
r.websockets.Range(func(relay string, conn *Connection) bool {
|
||||
if r, ok := r.Relays.Load(relay); !ok || !r.ShouldWrite(evt) {
|
||||
r.Relays.Range(func(url string, relay *Relay) bool {
|
||||
if r, ok := r.Policies.Load(url); !ok || !r.ShouldWrite(evt) {
|
||||
return true
|
||||
}
|
||||
|
||||
go func(relay string, conn *Connection) {
|
||||
err := conn.WriteJSON([]interface{}{"EVENT", evt})
|
||||
if err != nil {
|
||||
log.Printf("error sending event to '%s': %s", relay, err.Error())
|
||||
status <- PublishStatus{relay, PublishStatusFailed}
|
||||
go func(relay *Relay) {
|
||||
for resultStatus := range relay.Publish(*evt) {
|
||||
status <- PublishStatus{relay.URL, resultStatus}
|
||||
}
|
||||
status <- PublishStatus{relay, PublishStatusSent}
|
||||
|
||||
subscription := r.Sub(Filters{Filter{IDs: []string{evt.ID}}})
|
||||
for {
|
||||
select {
|
||||
case event := <-subscription.UniqueEvents:
|
||||
if event.ID == evt.ID {
|
||||
status <- PublishStatus{relay, PublishStatusSucceeded}
|
||||
break
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
break
|
||||
}
|
||||
break
|
||||
}
|
||||
}(relay, conn)
|
||||
}(relay)
|
||||
|
||||
return true
|
||||
})
|
||||
|
@@ -1,18 +1,11 @@
|
||||
package nostr
|
||||
|
||||
import (
|
||||
s "github.com/SaveTheRbtz/generic-sync-map-go"
|
||||
)
|
||||
|
||||
type Subscription struct {
|
||||
channel string
|
||||
relays s.MapOf[string, *Connection]
|
||||
id string
|
||||
conn *Connection
|
||||
|
||||
filters Filters
|
||||
Events chan EventMessage
|
||||
|
||||
started bool
|
||||
UniqueEvents chan Event
|
||||
Events chan Event
|
||||
|
||||
stopped bool
|
||||
}
|
||||
@@ -22,78 +15,22 @@ type EventMessage struct {
|
||||
Relay string
|
||||
}
|
||||
|
||||
func (subscription Subscription) Unsub() {
|
||||
subscription.relays.Range(func(_ string, conn *Connection) bool {
|
||||
conn.WriteJSON([]interface{}{
|
||||
"CLOSE",
|
||||
subscription.channel,
|
||||
})
|
||||
return true
|
||||
})
|
||||
func (sub Subscription) Unsub() {
|
||||
sub.conn.WriteJSON([]interface{}{"CLOSE", sub.id})
|
||||
|
||||
subscription.stopped = true
|
||||
if subscription.Events != nil {
|
||||
close(subscription.Events)
|
||||
}
|
||||
if subscription.UniqueEvents != nil {
|
||||
close(subscription.UniqueEvents)
|
||||
sub.stopped = true
|
||||
if sub.Events != nil {
|
||||
close(sub.Events)
|
||||
}
|
||||
}
|
||||
|
||||
func (subscription *Subscription) Sub(filters Filters) {
|
||||
subscription.filters = filters
|
||||
func (sub *Subscription) Sub(filters Filters) {
|
||||
sub.filters = filters
|
||||
|
||||
subscription.relays.Range(func(_ string, conn *Connection) bool {
|
||||
message := []interface{}{
|
||||
"REQ",
|
||||
subscription.channel,
|
||||
}
|
||||
for _, filter := range subscription.filters {
|
||||
message = append(message, filter)
|
||||
}
|
||||
|
||||
conn.WriteJSON(message)
|
||||
return true
|
||||
})
|
||||
|
||||
if !subscription.started {
|
||||
go subscription.startHandlingUnique()
|
||||
}
|
||||
}
|
||||
|
||||
func (subscription Subscription) startHandlingUnique() {
|
||||
seen := make(map[string]struct{})
|
||||
for em := range subscription.Events {
|
||||
if _, ok := seen[em.Event.ID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[em.Event.ID] = struct{}{}
|
||||
if !subscription.stopped {
|
||||
subscription.UniqueEvents <- em.Event
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (subscription Subscription) removeRelay(relay string) {
|
||||
if conn, ok := subscription.relays.Load(relay); ok {
|
||||
subscription.relays.Delete(relay)
|
||||
conn.WriteJSON([]interface{}{
|
||||
"CLOSE",
|
||||
subscription.channel,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (subscription Subscription) addRelay(relay string, conn *Connection) {
|
||||
subscription.relays.Store(relay, conn)
|
||||
|
||||
message := []interface{}{
|
||||
"REQ",
|
||||
subscription.channel,
|
||||
}
|
||||
for _, filter := range subscription.filters {
|
||||
message := []interface{}{"REQ", sub.id}
|
||||
for _, filter := range sub.filters {
|
||||
message = append(message, filter)
|
||||
}
|
||||
|
||||
conn.WriteJSON(message)
|
||||
sub.conn.WriteJSON(message)
|
||||
}
|
||||
|
Reference in New Issue
Block a user