diff --git a/pool.go b/pool.go index 8de7142..1007a43 100644 --- a/pool.go +++ b/pool.go @@ -8,14 +8,21 @@ import ( ) type SimplePool struct { - Relays map[string]*Relay + Relays map[string]*Relay + Context context.Context - mutex sync.Mutex + mutex sync.Mutex + cancel context.CancelFunc } -func NewSimplePool() *SimplePool { +func NewSimplePool(ctx context.Context) *SimplePool { + ctx, cancel := context.WithCancel(ctx) + return &SimplePool{ Relays: make(map[string]*Relay), + + Context: ctx, + cancel: cancel, } } @@ -26,13 +33,13 @@ func (pool *SimplePool) EnsureRelay(url string) *Relay { defer pool.mutex.Unlock() relay, ok := pool.Relays[nm] - if ok { + if ok && relay.ConnectionContext.Err() == nil { // already connected, unlock and return return relay } else { var err error - // when connecting to a relay we want the connection to persist forever if possible, so use a new context - relay, err = RelayConnect(context.Background(), nm) + // we use this ctx here so when the pool dies everything dies + relay, err = RelayConnect(pool.Context, nm) if err != nil { return nil }