multi: extract address source into interface

As a preparation to have the method for querying the addresses of a node
separate from the channel state, we extract that method out into its own
interface.
This commit is contained in:
Oliver Gugger 2021-09-21 19:18:16 +02:00
parent c1f686f860
commit ddea833d31
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
4 changed files with 29 additions and 16 deletions

View File

@ -21,7 +21,11 @@ type LiveChannelSource interface {
// passed chanPoint. Optionally an existing db tx can be supplied. // passed chanPoint. Optionally an existing db tx can be supplied.
FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) ( FetchChannel(tx kvdb.RTx, chanPoint wire.OutPoint) (
*channeldb.OpenChannel, error) *channeldb.OpenChannel, error)
}
// AddressSource is an interface that allows us to query for the set of
// addresses a node can be connected to.
type AddressSource interface {
// AddrsForNode returns all known addresses for the target node public // AddrsForNode returns all known addresses for the target node public
// key. // key.
AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error)
@ -31,15 +35,15 @@ type LiveChannelSource interface {
// passed open channel. The backup includes all information required to restore // passed open channel. The backup includes all information required to restore
// the channel, as well as addressing information so we can find the peer and // the channel, as well as addressing information so we can find the peer and
// reconnect to them to initiate the protocol. // reconnect to them to initiate the protocol.
func assembleChanBackup(chanSource LiveChannelSource, func assembleChanBackup(addrSource AddressSource,
openChan *channeldb.OpenChannel) (*Single, error) { openChan *channeldb.OpenChannel) (*Single, error) {
log.Debugf("Crafting backup for ChannelPoint(%v)", log.Debugf("Crafting backup for ChannelPoint(%v)",
openChan.FundingOutpoint) openChan.FundingOutpoint)
// First, we'll query the channel source to obtain all the addresses // First, we'll query the channel source to obtain all the addresses
// that are are associated with the peer for this channel. // that are associated with the peer for this channel.
nodeAddrs, err := chanSource.AddrsForNode(openChan.IdentityPub) nodeAddrs, err := addrSource.AddrsForNode(openChan.IdentityPub)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -52,8 +56,8 @@ func assembleChanBackup(chanSource LiveChannelSource,
// FetchBackupForChan attempts to create a plaintext static channel backup for // FetchBackupForChan attempts to create a plaintext static channel backup for
// the target channel identified by its channel point. If we're unable to find // the target channel identified by its channel point. If we're unable to find
// the target channel, then an error will be returned. // the target channel, then an error will be returned.
func FetchBackupForChan(chanPoint wire.OutPoint, func FetchBackupForChan(chanPoint wire.OutPoint, chanSource LiveChannelSource,
chanSource LiveChannelSource) (*Single, error) { addrSource AddressSource) (*Single, error) {
// First, we'll query the channel source to see if the channel is known // First, we'll query the channel source to see if the channel is known
// and open within the database. // and open within the database.
@ -66,7 +70,7 @@ func FetchBackupForChan(chanPoint wire.OutPoint,
// Once we have the target channel, we can assemble the backup using // Once we have the target channel, we can assemble the backup using
// the source to obtain any extra information that we may need. // the source to obtain any extra information that we may need.
staticChanBackup, err := assembleChanBackup(chanSource, targetChan) staticChanBackup, err := assembleChanBackup(addrSource, targetChan)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to create chan backup: %v", err) return nil, fmt.Errorf("unable to create chan backup: %v", err)
} }
@ -76,7 +80,9 @@ func FetchBackupForChan(chanPoint wire.OutPoint,
// FetchStaticChanBackups will return a plaintext static channel back up for // FetchStaticChanBackups will return a plaintext static channel back up for
// all known active/open channels within the passed channel source. // all known active/open channels within the passed channel source.
func FetchStaticChanBackups(chanSource LiveChannelSource) ([]Single, error) { func FetchStaticChanBackups(chanSource LiveChannelSource,
addrSource AddressSource) ([]Single, error) {
// First, we'll query the backup source for information concerning all // First, we'll query the backup source for information concerning all
// currently open and available channels. // currently open and available channels.
openChans, err := chanSource.FetchAllChannels() openChans, err := chanSource.FetchAllChannels()
@ -89,7 +95,7 @@ func FetchStaticChanBackups(chanSource LiveChannelSource) ([]Single, error) {
// channel. // channel.
staticChanBackups := make([]Single, 0, len(openChans)) staticChanBackups := make([]Single, 0, len(openChans))
for _, openChan := range openChans { for _, openChan := range openChans {
chanBackup, err := assembleChanBackup(chanSource, openChan) chanBackup, err := assembleChanBackup(addrSource, openChan)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -124,7 +124,9 @@ func TestFetchBackupForChan(t *testing.T) {
}, },
} }
for i, testCase := range testCases { for i, testCase := range testCases {
_, err := FetchBackupForChan(testCase.chanPoint, chanSource) _, err := FetchBackupForChan(
testCase.chanPoint, chanSource, chanSource,
)
switch { switch {
// If this is a valid test case, and we failed, then we'll // If this is a valid test case, and we failed, then we'll
// return an error. // return an error.
@ -167,7 +169,7 @@ func TestFetchStaticChanBackups(t *testing.T) {
// With the channel source populated, we'll now attempt to create a set // With the channel source populated, we'll now attempt to create a set
// of backups for all the channels. This should succeed, as all items // of backups for all the channels. This should succeed, as all items
// are populated within the channel source. // are populated within the channel source.
backups, err := FetchStaticChanBackups(chanSource) backups, err := FetchStaticChanBackups(chanSource, chanSource)
if err != nil { if err != nil {
t.Fatalf("unable to create chan back ups: %v", err) t.Fatalf("unable to create chan back ups: %v", err)
} }
@ -184,7 +186,7 @@ func TestFetchStaticChanBackups(t *testing.T) {
copy(n[:], randomChan2.IdentityPub.SerializeCompressed()) copy(n[:], randomChan2.IdentityPub.SerializeCompressed())
delete(chanSource.addrs, n) delete(chanSource.addrs, n)
_, err = FetchStaticChanBackups(chanSource) _, err = FetchStaticChanBackups(chanSource, chanSource)
if err == nil { if err == nil {
t.Fatalf("query with incomplete information should fail") t.Fatalf("query with incomplete information should fail")
} }
@ -193,7 +195,7 @@ func TestFetchStaticChanBackups(t *testing.T) {
// source at all, then we'll fail as well. // source at all, then we'll fail as well.
chanSource = newMockChannelSource() chanSource = newMockChannelSource()
chanSource.failQuery = true chanSource.failQuery = true
_, err = FetchStaticChanBackups(chanSource) _, err = FetchStaticChanBackups(chanSource, chanSource)
if err == nil { if err == nil {
t.Fatalf("query should fail") t.Fatalf("query should fail")
} }

View File

@ -6469,7 +6469,7 @@ func (r *rpcServer) ExportChannelBackup(ctx context.Context,
// the database. If this channel has been closed, or the outpoint is // the database. If this channel has been closed, or the outpoint is
// unknown, then we'll return an error // unknown, then we'll return an error
unpackedBackup, err := chanbackup.FetchBackupForChan( unpackedBackup, err := chanbackup.FetchBackupForChan(
chanPoint, r.server.chanStateDB, chanPoint, r.server.chanStateDB, r.server.addrSource,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -6639,7 +6639,7 @@ func (r *rpcServer) ExportAllChannelBackups(ctx context.Context,
// First, we'll attempt to read back ups for ALL currently opened // First, we'll attempt to read back ups for ALL currently opened
// channels from disk. // channels from disk.
allUnpackedBackups, err := chanbackup.FetchStaticChanBackups( allUnpackedBackups, err := chanbackup.FetchStaticChanBackups(
r.server.chanStateDB, r.server.chanStateDB, r.server.addrSource,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to fetch all static chan "+ return nil, fmt.Errorf("unable to fetch all static chan "+
@ -6766,7 +6766,7 @@ func (r *rpcServer) SubscribeChannelBackups(req *lnrpc.ChannelBackupSubscription
// we'll obtains the current set of single channel // we'll obtains the current set of single channel
// backups from disk. // backups from disk.
chanBackups, err := chanbackup.FetchStaticChanBackups( chanBackups, err := chanbackup.FetchStaticChanBackups(
r.server.chanStateDB, r.server.chanStateDB, r.server.addrSource,
) )
if err != nil { if err != nil {
return fmt.Errorf("unable to fetch all "+ return fmt.Errorf("unable to fetch all "+

View File

@ -224,6 +224,8 @@ type server struct {
chanStateDB *channeldb.DB chanStateDB *channeldb.DB
addrSource chanbackup.AddressSource
htlcSwitch *htlcswitch.Switch htlcSwitch *htlcswitch.Switch
interceptableSwitch *htlcswitch.InterceptableSwitch interceptableSwitch *htlcswitch.InterceptableSwitch
@ -433,6 +435,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
cfg: cfg, cfg: cfg,
graphDB: dbs.graphDB.ChannelGraph(), graphDB: dbs.graphDB.ChannelGraph(),
chanStateDB: dbs.chanStateDB, chanStateDB: dbs.chanStateDB,
addrSource: dbs.chanStateDB,
cc: cc, cc: cc,
sigPool: lnwallet.NewSigPool(cfg.Workers.Sig, cc.Signer), sigPool: lnwallet.NewSigPool(cfg.Workers.Sig, cc.Signer),
writePool: writePool, writePool: writePool,
@ -1246,7 +1249,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
addrs: s.chanStateDB, addrs: s.chanStateDB,
} }
backupFile := chanbackup.NewMultiFile(cfg.BackupFilePath) backupFile := chanbackup.NewMultiFile(cfg.BackupFilePath)
startingChans, err := chanbackup.FetchStaticChanBackups(s.chanStateDB) startingChans, err := chanbackup.FetchStaticChanBackups(
s.chanStateDB, s.addrSource,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }