relay policies as interfaces, SimplePolicy as the most basic implementation.

This commit is contained in:
fiatjaf
2022-01-02 08:50:53 -03:00
parent d131e8460e
commit 140edc693c

View File

@@ -34,9 +34,9 @@ type RelayPool struct {
Notices chan *NoticeMessage Notices chan *NoticeMessage
} }
type RelayPoolPolicy struct { type RelayPoolPolicy interface {
SimplePolicy ShouldRead(EventFilters) bool
ReadSpecific map[string]SimplePolicy ShouldWrite(*Event) bool
} }
type SimplePolicy struct { type SimplePolicy struct {
@@ -44,6 +44,14 @@ type SimplePolicy struct {
Write bool Write bool
} }
func (s SimplePolicy) ShouldRead(_ EventFilters) bool {
return s.Read
}
func (s SimplePolicy) ShouldWrite(_ *Event) bool {
return s.Write
}
type NoticeMessage struct { type NoticeMessage struct {
Message string Message string
Relay string Relay string
@@ -62,9 +70,9 @@ func NewRelayPool() *RelayPool {
// Add adds a new relay to the pool, if policy is nil, it will be a simple // Add adds a new relay to the pool, if policy is nil, it will be a simple
// read+write policy. // read+write policy.
func (r *RelayPool) Add(url string, policy *RelayPoolPolicy) error { func (r *RelayPool) Add(url string, policy RelayPoolPolicy) error {
if policy == nil { if policy == nil {
policy = &RelayPoolPolicy{SimplePolicy: SimplePolicy{Read: true, Write: true}} policy = SimplePolicy{Read: true, Write: true}
} }
nm := NormalizeURL(url) nm := NormalizeURL(url)
@@ -77,7 +85,7 @@ func (r *RelayPool) Add(url string, policy *RelayPoolPolicy) error {
return fmt.Errorf("error opening websocket to '%s': %w", nm, err) return fmt.Errorf("error opening websocket to '%s': %w", nm, err)
} }
r.Relays[nm] = *policy r.Relays[nm] = policy
r.websockets[nm] = conn r.websockets[nm] = conn
for _, sub := range r.subscriptions { for _, sub := range r.subscriptions {
@@ -177,7 +185,7 @@ func (r *RelayPool) Sub(filters EventFilters) *Subscription {
subscription.channel = hex.EncodeToString(random) subscription.channel = hex.EncodeToString(random)
subscription.relays = make(map[string]*websocket.Conn) subscription.relays = make(map[string]*websocket.Conn)
for relay, policy := range r.Relays { for relay, policy := range r.Relays {
if policy.Read { if policy.ShouldRead(filters) {
ws := r.websockets[relay] ws := r.websockets[relay]
subscription.relays[relay] = ws subscription.relays[relay] = ws
} }
@@ -213,6 +221,10 @@ func (r *RelayPool) PublishEvent(evt *Event) (*Event, chan PublishStatus, error)
} }
for relay, conn := range r.websockets { for relay, conn := range r.websockets {
if !r.Relays[relay].ShouldWrite(evt) {
continue
}
go func(relay string, conn *websocket.Conn) { go func(relay string, conn *websocket.Conn) {
err := conn.WriteJSON([]interface{}{"EVENT", evt}) err := conn.WriteJSON([]interface{}{"EVENT", evt})
if err != nil { if err != nil {