discovery: pass contexts to NetworkPeerBootstrapper methods

We will later implement this interface with a backing RPC connection and
so it makes sense to pass a context through for cancellation.
This commit is contained in:
Elle Mouton
2024-11-12 08:41:24 +02:00
parent 28415f5ef2
commit 0f33d41c55

View File

@@ -2,6 +2,7 @@ package discovery
import ( import (
"bytes" "bytes"
"context"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"errors" "errors"
@@ -36,12 +37,13 @@ type NetworkPeerBootstrapper interface {
// denotes how many valid peer addresses to return. The passed set of // denotes how many valid peer addresses to return. The passed set of
// node nodes allows the caller to ignore a set of nodes perhaps // node nodes allows the caller to ignore a set of nodes perhaps
// because they already have connections established. // because they already have connections established.
SampleNodeAddrs(numAddrs uint32, SampleNodeAddrs(ctx context.Context, numAddrs uint32,
ignore map[autopilot.NodeID]struct{}) ([]*lnwire.NetAddress, error) ignore map[autopilot.NodeID]struct{}) ([]*lnwire.NetAddress,
error)
// Name returns a human readable string which names the concrete // Name returns a human readable string which names the concrete
// implementation of the NetworkPeerBootstrapper. // implementation of the NetworkPeerBootstrapper.
Name() string Name(ctx context.Context) string
} }
// MultiSourceBootstrap attempts to utilize a set of NetworkPeerBootstrapper // MultiSourceBootstrap attempts to utilize a set of NetworkPeerBootstrapper
@@ -53,6 +55,8 @@ type NetworkPeerBootstrapper interface {
func MultiSourceBootstrap(ignore map[autopilot.NodeID]struct{}, numAddrs uint32, func MultiSourceBootstrap(ignore map[autopilot.NodeID]struct{}, numAddrs uint32,
bootstrappers ...NetworkPeerBootstrapper) ([]*lnwire.NetAddress, error) { bootstrappers ...NetworkPeerBootstrapper) ([]*lnwire.NetAddress, error) {
ctx := context.TODO()
// We'll randomly shuffle our bootstrappers before querying them in // We'll randomly shuffle our bootstrappers before querying them in
// order to avoid from querying the same bootstrapper method over and // order to avoid from querying the same bootstrapper method over and
// over, as some of these might tend to provide better/worse results // over, as some of these might tend to provide better/worse results
@@ -67,19 +71,23 @@ func MultiSourceBootstrap(ignore map[autopilot.NodeID]struct{}, numAddrs uint32,
break break
} }
log.Infof("Attempting to bootstrap with: %v", bootstrapper.Name()) name := bootstrapper.Name(ctx)
log.Infof("Attempting to bootstrap with: %v", name)
// If we still need additional addresses, then we'll compute // If we still need additional addresses, then we'll compute
// the number of address remaining that we need to fetch. // the number of address remaining that we need to fetch.
numAddrsLeft := numAddrs - uint32(len(addrs)) numAddrsLeft := numAddrs - uint32(len(addrs))
log.Tracef("Querying for %v addresses", numAddrsLeft) log.Tracef("Querying for %v addresses", numAddrsLeft)
netAddrs, err := bootstrapper.SampleNodeAddrs(numAddrsLeft, ignore) netAddrs, err := bootstrapper.SampleNodeAddrs(
ctx, numAddrsLeft, ignore,
)
if err != nil { if err != nil {
// If we encounter an error with a bootstrapper, then // If we encounter an error with a bootstrapper, then
// we'll continue on to the next available // we'll continue on to the next available
// bootstrapper. // bootstrapper.
log.Errorf("Unable to query bootstrapper %v: %v", log.Errorf("Unable to query bootstrapper %v: %v", name,
bootstrapper.Name(), err) err)
continue continue
} }
@@ -152,8 +160,9 @@ func NewGraphBootstrapper(cg autopilot.ChannelGraph) (NetworkPeerBootstrapper, e
// many valid peer addresses to return. // many valid peer addresses to return.
// //
// NOTE: Part of the NetworkPeerBootstrapper interface. // NOTE: Part of the NetworkPeerBootstrapper interface.
func (c *ChannelGraphBootstrapper) SampleNodeAddrs(numAddrs uint32, func (c *ChannelGraphBootstrapper) SampleNodeAddrs(_ context.Context,
ignore map[autopilot.NodeID]struct{}) ([]*lnwire.NetAddress, error) { numAddrs uint32, ignore map[autopilot.NodeID]struct{}) (
[]*lnwire.NetAddress, error) {
// We'll merge the ignore map with our currently selected map in order // We'll merge the ignore map with our currently selected map in order
// to ensure we don't return any duplicate nodes. // to ensure we don't return any duplicate nodes.
@@ -269,7 +278,7 @@ func (c *ChannelGraphBootstrapper) SampleNodeAddrs(numAddrs uint32,
// of the NetworkPeerBootstrapper. // of the NetworkPeerBootstrapper.
// //
// NOTE: Part of the NetworkPeerBootstrapper interface. // NOTE: Part of the NetworkPeerBootstrapper interface.
func (c *ChannelGraphBootstrapper) Name() string { func (c *ChannelGraphBootstrapper) Name(_ context.Context) string {
return "Authenticated Channel Graph" return "Authenticated Channel Graph"
} }
@@ -382,8 +391,9 @@ func (d *DNSSeedBootstrapper) fallBackSRVLookup(soaShim string,
// network peer bootstrapper source. The num addrs field passed in denotes how // network peer bootstrapper source. The num addrs field passed in denotes how
// many valid peer addresses to return. The set of DNS seeds are used // many valid peer addresses to return. The set of DNS seeds are used
// successively to retrieve eligible target nodes. // successively to retrieve eligible target nodes.
func (d *DNSSeedBootstrapper) SampleNodeAddrs(numAddrs uint32, func (d *DNSSeedBootstrapper) SampleNodeAddrs(_ context.Context,
ignore map[autopilot.NodeID]struct{}) ([]*lnwire.NetAddress, error) { numAddrs uint32, ignore map[autopilot.NodeID]struct{}) (
[]*lnwire.NetAddress, error) {
var netAddrs []*lnwire.NetAddress var netAddrs []*lnwire.NetAddress
@@ -532,6 +542,6 @@ search:
// Name returns a human readable string which names the concrete // Name returns a human readable string which names the concrete
// implementation of the NetworkPeerBootstrapper. // implementation of the NetworkPeerBootstrapper.
func (d *DNSSeedBootstrapper) Name() string { func (d *DNSSeedBootstrapper) Name(_ context.Context) string {
return fmt.Sprintf("BOLT-0010 DNS Seed: %v", d.dnsSeeds) return fmt.Sprintf("BOLT-0010 DNS Seed: %v", d.dnsSeeds)
} }