diff --git a/relaypool.go b/relaypool.go index 540c34f..631cbd3 100644 --- a/relaypool.go +++ b/relaypool.go @@ -34,9 +34,9 @@ type RelayPool struct { Notices chan *NoticeMessage } -type RelayPoolPolicy struct { - SimplePolicy - ReadSpecific map[string]SimplePolicy +type RelayPoolPolicy interface { + ShouldRead(EventFilters) bool + ShouldWrite(*Event) bool } type SimplePolicy struct { @@ -44,6 +44,14 @@ type SimplePolicy struct { Write bool } +func (s SimplePolicy) ShouldRead(_ EventFilters) bool { + return s.Read +} + +func (s SimplePolicy) ShouldWrite(_ *Event) bool { + return s.Write +} + type NoticeMessage struct { Message string Relay string @@ -62,9 +70,9 @@ func NewRelayPool() *RelayPool { // Add adds a new relay to the pool, if policy is nil, it will be a simple // read+write policy. -func (r *RelayPool) Add(url string, policy *RelayPoolPolicy) error { +func (r *RelayPool) Add(url string, policy RelayPoolPolicy) error { if policy == nil { - policy = &RelayPoolPolicy{SimplePolicy: SimplePolicy{Read: true, Write: true}} + policy = SimplePolicy{Read: true, Write: true} } nm := NormalizeURL(url) @@ -77,7 +85,7 @@ func (r *RelayPool) Add(url string, policy *RelayPoolPolicy) error { return fmt.Errorf("error opening websocket to '%s': %w", nm, err) } - r.Relays[nm] = *policy + r.Relays[nm] = policy r.websockets[nm] = conn for _, sub := range r.subscriptions { @@ -177,7 +185,7 @@ func (r *RelayPool) Sub(filters EventFilters) *Subscription { subscription.channel = hex.EncodeToString(random) subscription.relays = make(map[string]*websocket.Conn) for relay, policy := range r.Relays { - if policy.Read { + if policy.ShouldRead(filters) { ws := r.websockets[relay] subscription.relays[relay] = ws } @@ -213,6 +221,10 @@ func (r *RelayPool) PublishEvent(evt *Event) (*Event, chan PublishStatus, error) } for relay, conn := range r.websockets { + if !r.Relays[relay].ShouldWrite(evt) { + continue + } + go func(relay string, conn *websocket.Conn) { err := conn.WriteJSON([]interface{}{"EVENT", evt}) if err != nil {