multi: let chan and graph db implement AddrSource

Then use both to construct a multiAddrSource AddrSource and use that
around the code-base.
This commit is contained in:
Elle Mouton
2024-10-22 14:35:37 +02:00
parent 51c2f709e1
commit 2c083bc017
4 changed files with 55 additions and 70 deletions

View File

@@ -34,7 +34,6 @@ import (
"github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
const (
@@ -1343,64 +1342,24 @@ func (c *ChannelStateDB) RestoreChannelShells(channelShells ...*ChannelShell) er
return nil
}
// AddrsForNode consults the graph and channel database for all addresses known
// to the passed node public key.
// AddrsForNode consults the channel database for all addresses known to the
// passed node public key. The returned boolean indicates if the given node is
// unknown to the channel DB or not.
//
// NOTE: this is part of the AddrSource interface.
func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr, error) {
var (
// addrs holds the collection of deduplicated addresses we know
// of for the node.
addrs = make(map[string]net.Addr)
// known keeps track of if any of the backing sources know of
// this node.
known bool
)
// First, query the channel DB for its known addresses.
linkNode, err := d.channelStateDB.linkNodeDB.FetchLinkNode(nodePub)
// Only if the error is something other than ErrNodeNotFound do we
// return it.
switch {
// If we get back a ErrNodeNotFound error, then this just means that the
// channel DB does not know of the error, but we don't error out here
// because we still want to check the graph db.
case err != nil && !errors.Is(err, ErrNodeNotFound):
return false, nil, err
// A nil error means the node is known.
case err == nil:
known = true
for _, addr := range linkNode.Addresses {
addrs[addr.String()] = addr
}
case errors.Is(err, ErrNodeNotFound):
return false, nil, nil
}
// We'll also query the graph for this peer to see if they have any
// addresses that we don't currently have stored within the link node
// database.
pubKey, err := route.NewVertexFromBytes(nodePub.SerializeCompressed())
if err != nil {
return false, nil, err
}
graphNode, err := d.graph.FetchLightningNode(pubKey)
switch {
// We don't consider it an error if the graph is unaware of the node.
case err != nil && !errors.Is(err, graphdb.ErrGraphNodeNotFound):
return false, nil, err
// If we do find the node, we add its addresses to our deduplicated set.
case err == nil:
known = true
for _, addr := range graphNode.Addresses {
addrs[addr.String()] = addr
}
}
// Convert the deduplicated set into a list.
dedupedAddrs := make([]net.Addr, 0, len(addrs))
for _, addr := range addrs {
dedupedAddrs = append(dedupedAddrs, addr)
}
return known, dedupedAddrs, nil
return true, linkNode.Addresses, nil
}
// AbandonChannel attempts to remove the target channel from the open channel

View File

@@ -185,9 +185,10 @@ func TestFetchClosedChannelForID(t *testing.T) {
}
}
// TestAddrsForNode tests the we're able to properly obtain all the addresses
// for a target node.
func TestAddrsForNode(t *testing.T) {
// TestMultiSourceAddrsForNode tests the we're able to properly obtain all the
// addresses for a target node from multiple backends - in this case, the
// channel db and graph db.
func TestMultiSourceAddrsForNode(t *testing.T) {
t.Parallel()
fullDB, err := MakeTestDB(t)
@@ -201,9 +202,7 @@ func TestAddrsForNode(t *testing.T) {
testNode := createTestVertex(t)
require.NoError(t, err, "unable to create test node")
testNode.Addresses = []net.Addr{testAddr}
if err := graph.SetSourceNode(testNode); err != nil {
t.Fatalf("unable to set source node: %v", err)
}
require.NoError(t, graph.SetSourceNode(testNode))
// Next, we'll make a link node with the same pubkey, but with an
// additional address.
@@ -213,28 +212,27 @@ func TestAddrsForNode(t *testing.T) {
fullDB.channelStateDB.linkNodeDB, wire.MainNet, nodePub,
anotherAddr,
)
if err := linkNode.Sync(); err != nil {
t.Fatalf("unable to sync link node: %v", err)
}
require.NoError(t, linkNode.Sync())
// Create a multi-backend address source from the channel db and graph
// db.
addrSource := NewMultiAddrSource(fullDB, graph)
// Now that we've created a link node, as well as a vertex for the
// node, we'll query for all its addresses.
_, nodeAddrs, err := fullDB.AddrsForNode(nodePub)
known, nodeAddrs, err := addrSource.AddrsForNode(nodePub)
require.NoError(t, err, "unable to obtain node addrs")
require.True(t, known)
expectedAddrs := make(map[string]struct{})
expectedAddrs[testAddr.String()] = struct{}{}
expectedAddrs[anotherAddr.String()] = struct{}{}
// Finally, ensure that all the expected addresses are found.
if len(nodeAddrs) != len(expectedAddrs) {
t.Fatalf("expected %v addrs, got %v",
len(expectedAddrs), len(nodeAddrs))
}
require.Len(t, nodeAddrs, len(expectedAddrs))
for _, addr := range nodeAddrs {
if _, ok := expectedAddrs[addr.String()]; !ok {
t.Fatalf("unexpected addr: %v", addr)
}
require.Contains(t, expectedAddrs, addr.String())
}
}