support performing AUTH in the middle of SimplePool's subMany*

This commit is contained in:
fiatjaf
2023-12-07 21:37:41 -03:00
parent f8fa490293
commit b2170efb5a

57
pool.go
View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"strings"
"sync" "sync"
"time" "time"
@@ -18,7 +19,8 @@ type SimplePool struct {
Relays *xsync.MapOf[string, *Relay] Relays *xsync.MapOf[string, *Relay]
Context context.Context Context context.Context
cancel context.CancelFunc authHandler func(*Event) error
cancel context.CancelFunc
} }
type IncomingEvent struct { type IncomingEvent struct {
@@ -26,17 +28,40 @@ type IncomingEvent struct {
Relay *Relay Relay *Relay
} }
func NewSimplePool(ctx context.Context) *SimplePool { type PoolOption interface {
IsPoolOption()
Apply(*SimplePool)
}
func NewSimplePool(ctx context.Context, opts ...PoolOption) *SimplePool {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
return &SimplePool{ pool := &SimplePool{
Relays: xsync.NewMapOf[*Relay](), Relays: xsync.NewMapOf[*Relay](),
Context: ctx, Context: ctx,
cancel: cancel, cancel: cancel,
} }
for _, opt := range opts {
opt.Apply(pool)
}
return pool
} }
// WithAuthHandler must be a function that signs the auth event when called.
// it will be called whenever any relay in the pool returns a `CLOSED` message
// with the "auth-required:" prefix, only once for each relay
type WithAuthHandler func(authEvent *Event) error
func (_ WithAuthHandler) IsPoolOption() {}
func (h WithAuthHandler) Apply(pool *SimplePool) {
pool.authHandler = h
}
var _ PoolOption = (WithAuthHandler)(nil)
func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) { func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) {
nm := NormalizeURL(url) nm := NormalizeURL(url)
@@ -91,6 +116,7 @@ func (pool *SimplePool) subMany(ctx context.Context, urls []string, filters Filt
cancel() cancel()
}() }()
hasAuthed := false
interval := 3 * time.Second interval := 3 * time.Second
for { for {
select { select {
@@ -105,7 +131,9 @@ func (pool *SimplePool) subMany(ctx context.Context, urls []string, filters Filt
if err != nil { if err != nil {
goto reconnect goto reconnect
} }
hasAuthed = false
subscribe:
sub, err = relay.Subscribe(ctx, filters) sub, err = relay.Subscribe(ctx, filters)
if err != nil { if err != nil {
goto reconnect goto reconnect
@@ -149,7 +177,15 @@ func (pool *SimplePool) subMany(ctx context.Context, urls []string, filters Filt
}) })
} }
case reason := <-sub.ClosedReason: case reason := <-sub.ClosedReason:
log.Printf("CLOSED from %s: '%s'\n", nm, reason) if strings.HasPrefix(reason, "auth-required:") && pool.authHandler != nil && !hasAuthed {
// relay is requesting auth. if we can we will perform auth and try again
if err := relay.Auth(ctx, pool.authHandler); err == nil {
hasAuthed = true // so we don't keep doing AUTH again and again
goto subscribe
}
} else {
log.Printf("CLOSED from %s: '%s'\n", nm, reason)
}
return return
case <-ctx.Done(): case <-ctx.Done():
return return
@@ -202,6 +238,9 @@ func (pool *SimplePool) subManyEose(ctx context.Context, urls []string, filters
return return
} }
hasAuthed := false
subscribe:
sub, err := relay.Subscribe(ctx, filters) sub, err := relay.Subscribe(ctx, filters)
if sub == nil { if sub == nil {
debugLogf("error subscribing to %s with %v: %s", relay, filters, err) debugLogf("error subscribing to %s with %v: %s", relay, filters, err)
@@ -215,7 +254,15 @@ func (pool *SimplePool) subManyEose(ctx context.Context, urls []string, filters
case <-sub.EndOfStoredEvents: case <-sub.EndOfStoredEvents:
return return
case reason := <-sub.ClosedReason: case reason := <-sub.ClosedReason:
log.Printf("CLOSED from %s: '%s'\n", nm, reason) if strings.HasPrefix(reason, "auth-required:") && pool.authHandler != nil && !hasAuthed {
// relay is requesting auth. if we can we will perform auth and try again
if err := relay.Auth(ctx, pool.authHandler); err == nil {
hasAuthed = true // so we don't keep doing AUTH again and again
goto subscribe
}
} else {
log.Printf("CLOSED from %s: '%s'\n", nm, reason)
}
return return
case evt, more := <-sub.Events: case evt, more := <-sub.Events:
if !more { if !more {