diff --git a/pool.go b/pool.go index 4abf935..d899809 100644 --- a/pool.go +++ b/pool.go @@ -23,6 +23,8 @@ type SimplePool struct { authHandler func(*Event) error cancel context.CancelFunc + eventMiddleware []func(IncomingEvent) + // custom things not often used SignatureChecker func(Event) bool } @@ -42,8 +44,7 @@ func (ie IncomingEvent) String() string { } type PoolOption interface { - IsPoolOption() - Apply(*SimplePool) + ApplyPoolOption(*SimplePool) } func NewSimplePool(ctx context.Context, opts ...PoolOption) *SimplePool { @@ -57,7 +58,7 @@ func NewSimplePool(ctx context.Context, opts ...PoolOption) *SimplePool { } for _, opt := range opts { - opt.Apply(pool) + opt.ApplyPoolOption(pool) } return pool @@ -68,12 +69,22 @@ func NewSimplePool(ctx context.Context, opts ...PoolOption) *SimplePool { // with the "auth-required:" prefix, only once for each relay type WithAuthHandler func(authEvent *Event) error -func (_ WithAuthHandler) IsPoolOption() {} -func (h WithAuthHandler) Apply(pool *SimplePool) { +func (h WithAuthHandler) ApplyPoolOption(pool *SimplePool) { pool.authHandler = h } -var _ PoolOption = (WithAuthHandler)(nil) +// WithEventMiddleware is a function that will be called with all events received. +// more than one can be passed at a time. +type WithEventMiddleware func(IncomingEvent) + +func (h WithEventMiddleware) ApplyPoolOption(pool *SimplePool) { + pool.eventMiddleware = append(pool.eventMiddleware, h) +} + +var ( + _ PoolOption = (WithAuthHandler)(nil) + _ PoolOption = (WithEventMiddleware)(nil) +) func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) { nm := NormalizeURL(url) @@ -89,7 +100,7 @@ func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) { ctx, cancel := context.WithTimeout(pool.Context, time.Second*15) defer cancel() - opts := make([]RelayOption, 0, 1) + opts := make([]RelayOption, 0, 1+len(pool.eventMiddleware)) if pool.SignatureChecker != nil { opts = append(opts, WithSignatureChecker(pool.SignatureChecker)) } @@ -186,13 +197,20 @@ func (pool *SimplePool) subMany(ctx context.Context, urls []string, filters Filt } goto reconnect } + + ie := IncomingEvent{Event: evt, Relay: relay} + for _, mh := range pool.eventMiddleware { + mh(ie) + } + if unique { if _, seen := seenAlready.LoadOrStore(evt.ID, evt.CreatedAt); seen { continue } } + select { - case events <- IncomingEvent{Event: evt, Relay: relay}: + case events <- ie: case <-ctx.Done(): } case <-ticker.C: @@ -298,6 +316,11 @@ func (pool *SimplePool) subManyEose(ctx context.Context, urls []string, filters return } + ie := IncomingEvent{Event: evt, Relay: relay} + for _, mh := range pool.eventMiddleware { + mh(ie) + } + if unique { if _, seen := seenAlready.LoadOrStore(evt.ID, true); seen { continue @@ -305,7 +328,7 @@ func (pool *SimplePool) subManyEose(ctx context.Context, urls []string, filters } select { - case events <- IncomingEvent{Event: evt, Relay: relay}: + case events <- ie: case <-ctx.Done(): return }