diff --git a/autopilot/graph.go b/autopilot/graph.go index b4e415077..c8b54082a 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -1,20 +1,15 @@ package autopilot import ( - "bytes" "encoding/hex" - "errors" "net" "sort" - "sync/atomic" - "time" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" - "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -36,7 +31,7 @@ var ( // // TODO(roasbeef): move inmpl to main package? type databaseChannelGraph struct { - db *graphdb.ChannelGraph + db GraphSource } // A compile time assertion to ensure databaseChannelGraph meets the @@ -44,8 +39,8 @@ type databaseChannelGraph struct { var _ ChannelGraph = (*databaseChannelGraph)(nil) // ChannelGraphFromDatabase returns an instance of the autopilot.ChannelGraph -// backed by a live, open channeldb instance. -func ChannelGraphFromDatabase(db *graphdb.ChannelGraph) ChannelGraph { +// backed by a GraphSource. +func ChannelGraphFromDatabase(db GraphSource) ChannelGraph { return &databaseChannelGraph{ db: db, } @@ -55,11 +50,7 @@ func ChannelGraphFromDatabase(db *graphdb.ChannelGraph) ChannelGraph { // channeldb.LightningNode. The wrapper method implement the autopilot.Node // interface. type dbNode struct { - db *graphdb.ChannelGraph - - tx kvdb.RTx - - node *models.LightningNode + tx graphdb.NodeRTx } // A compile time assertion to ensure dbNode meets the autopilot.Node @@ -72,7 +63,7 @@ var _ Node = (*dbNode)(nil) // // NOTE: Part of the autopilot.Node interface. func (d *dbNode) PubKey() [33]byte { - return d.node.PubKeyBytes + return d.tx.Node().PubKeyBytes } // Addrs returns a slice of publicly reachable public TCP addresses that the @@ -80,7 +71,7 @@ func (d *dbNode) PubKey() [33]byte { // // NOTE: Part of the autopilot.Node interface. func (d *dbNode) Addrs() []net.Addr { - return d.node.Addresses + return d.tx.Node().Addresses } // ForEachChannel is a higher-order function that will be used to iterate @@ -90,43 +81,35 @@ func (d *dbNode) Addrs() []net.Addr { // // NOTE: Part of the autopilot.Node interface. func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { - return d.db.ForEachNodeChannelTx(d.tx, d.node.PubKeyBytes, - func(tx kvdb.RTx, ei *models.ChannelEdgeInfo, ep, - _ *models.ChannelEdgePolicy) error { + return d.tx.ForEachChannel(func(ei *models.ChannelEdgeInfo, ep, + _ *models.ChannelEdgePolicy) error { - // Skip channels for which no outgoing edge policy is - // available. - // - // TODO(joostjager): Ideally the case where channels - // have a nil policy should be supported, as autopilot - // is not looking at the policies. For now, it is not - // easily possible to get a reference to the other end - // LightningNode object without retrieving the policy. - if ep == nil { - return nil - } + // Skip channels for which no outgoing edge policy is available. + // + // TODO(joostjager): Ideally the case where channels have a nil + // policy should be supported, as autopilot is not looking at + // the policies. For now, it is not easily possible to get a + // reference to the other end LightningNode object without + // retrieving the policy. + if ep == nil { + return nil + } - node, err := d.db.FetchLightningNodeTx( - tx, ep.ToNode, - ) - if err != nil { - return err - } + node, err := d.tx.FetchNode(ep.ToNode) + if err != nil { + return err + } - edge := ChannelEdge{ - ChanID: lnwire.NewShortChanIDFromInt( - ep.ChannelID, - ), - Capacity: ei.Capacity, - Peer: &dbNode{ - tx: tx, - db: d.db, - node: node, - }, - } + edge := ChannelEdge{ + ChanID: lnwire.NewShortChanIDFromInt(ep.ChannelID), + Capacity: ei.Capacity, + Peer: &dbNode{ + tx: node, + }, + } - return cb(edge) - }) + return cb(edge) + }) } // ForEachNode is a higher-order function that should be called once for each @@ -135,353 +118,25 @@ func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { // // NOTE: Part of the autopilot.ChannelGraph interface. func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error { - return d.db.ForEachNode(func(tx kvdb.RTx, - n *models.LightningNode) error { - + return d.db.ForEachNode(func(nodeTx graphdb.NodeRTx) error { // We'll skip over any node that doesn't have any advertised // addresses. As we won't be able to reach them to actually // open any channels. - if len(n.Addresses) == 0 { + if len(nodeTx.Node().Addresses) == 0 { return nil } node := &dbNode{ - db: d.db, - tx: tx, - node: n, + tx: nodeTx, } return cb(node) }) } -// addRandChannel creates a new channel two target nodes. This function is -// meant to aide in the generation of random graphs for use within test cases -// the exercise the autopilot package. -func (d *databaseChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, - capacity btcutil.Amount) (*ChannelEdge, *ChannelEdge, error) { - - fetchNode := func(pub *btcec.PublicKey) (*models.LightningNode, error) { - if pub != nil { - vertex, err := route.NewVertexFromBytes( - pub.SerializeCompressed(), - ) - if err != nil { - return nil, err - } - - dbNode, err := d.db.FetchLightningNode(vertex) - switch { - case errors.Is(err, graphdb.ErrGraphNodeNotFound): - fallthrough - case errors.Is(err, graphdb.ErrGraphNotFound): - graphNode := &models.LightningNode{ - HaveNodeAnnouncement: true, - Addresses: []net.Addr{ - &net.TCPAddr{ - IP: bytes.Repeat([]byte("a"), 16), - }, - }, - Features: lnwire.NewFeatureVector( - nil, lnwire.Features, - ), - AuthSigBytes: testSig.Serialize(), - } - graphNode.AddPubKey(pub) - if err := d.db.AddLightningNode(graphNode); err != nil { - return nil, err - } - case err != nil: - return nil, err - } - - return dbNode, nil - } - - nodeKey, err := randKey() - if err != nil { - return nil, err - } - dbNode := &models.LightningNode{ - HaveNodeAnnouncement: true, - Addresses: []net.Addr{ - &net.TCPAddr{ - IP: bytes.Repeat([]byte("a"), 16), - }, - }, - Features: lnwire.NewFeatureVector( - nil, lnwire.Features, - ), - AuthSigBytes: testSig.Serialize(), - } - dbNode.AddPubKey(nodeKey) - if err := d.db.AddLightningNode(dbNode); err != nil { - return nil, err - } - - return dbNode, nil - } - - vertex1, err := fetchNode(node1) - if err != nil { - return nil, nil, err - } - - vertex2, err := fetchNode(node2) - if err != nil { - return nil, nil, err - } - - var lnNode1, lnNode2 *btcec.PublicKey - if bytes.Compare(vertex1.PubKeyBytes[:], vertex2.PubKeyBytes[:]) == -1 { - lnNode1, _ = vertex1.PubKey() - lnNode2, _ = vertex2.PubKey() - } else { - lnNode1, _ = vertex2.PubKey() - lnNode2, _ = vertex1.PubKey() - } - - chanID := randChanID() - edge := &models.ChannelEdgeInfo{ - ChannelID: chanID.ToUint64(), - Capacity: capacity, - } - edge.AddNodeKeys(lnNode1, lnNode2, lnNode1, lnNode2) - if err := d.db.AddChannelEdge(edge); err != nil { - return nil, nil, err - } - edgePolicy := &models.ChannelEdgePolicy{ - SigBytes: testSig.Serialize(), - ChannelID: chanID.ToUint64(), - LastUpdate: time.Now(), - TimeLockDelta: 10, - MinHTLC: 1, - MaxHTLC: lnwire.NewMSatFromSatoshis(capacity), - FeeBaseMSat: 10, - FeeProportionalMillionths: 10000, - MessageFlags: 1, - ChannelFlags: 0, - } - - if err := d.db.UpdateEdgePolicy(edgePolicy); err != nil { - return nil, nil, err - } - edgePolicy = &models.ChannelEdgePolicy{ - SigBytes: testSig.Serialize(), - ChannelID: chanID.ToUint64(), - LastUpdate: time.Now(), - TimeLockDelta: 10, - MinHTLC: 1, - MaxHTLC: lnwire.NewMSatFromSatoshis(capacity), - FeeBaseMSat: 10, - FeeProportionalMillionths: 10000, - MessageFlags: 1, - ChannelFlags: 1, - } - if err := d.db.UpdateEdgePolicy(edgePolicy); err != nil { - return nil, nil, err - } - - return &ChannelEdge{ - ChanID: chanID, - Capacity: capacity, - Peer: &dbNode{ - db: d.db, - node: vertex1, - }, - }, - &ChannelEdge{ - ChanID: chanID, - Capacity: capacity, - Peer: &dbNode{ - db: d.db, - node: vertex2, - }, - }, - nil -} - -func (d *databaseChannelGraph) addRandNode() (*btcec.PublicKey, error) { - nodeKey, err := randKey() - if err != nil { - return nil, err - } - dbNode := &models.LightningNode{ - HaveNodeAnnouncement: true, - Addresses: []net.Addr{ - &net.TCPAddr{ - IP: bytes.Repeat([]byte("a"), 16), - }, - }, - Features: lnwire.NewFeatureVector( - nil, lnwire.Features, - ), - AuthSigBytes: testSig.Serialize(), - } - dbNode.AddPubKey(nodeKey) - if err := d.db.AddLightningNode(dbNode); err != nil { - return nil, err - } - - return nodeKey, nil - -} - -// memChannelGraph is an implementation of the autopilot.ChannelGraph backed by -// an in-memory graph. -type memChannelGraph struct { - graph map[NodeID]*memNode -} - -// A compile time assertion to ensure memChannelGraph meets the -// autopilot.ChannelGraph interface. -var _ ChannelGraph = (*memChannelGraph)(nil) - -// newMemChannelGraph creates a new blank in-memory channel graph -// implementation. -func newMemChannelGraph() *memChannelGraph { - return &memChannelGraph{ - graph: make(map[NodeID]*memNode), - } -} - -// ForEachNode is a higher-order function that should be called once for each -// connected node within the channel graph. If the passed callback returns an -// error, then execution should be terminated. -// -// NOTE: Part of the autopilot.ChannelGraph interface. -func (m memChannelGraph) ForEachNode(cb func(Node) error) error { - for _, node := range m.graph { - if err := cb(node); err != nil { - return err - } - } - - return nil -} - -// randChanID generates a new random channel ID. -func randChanID() lnwire.ShortChannelID { - id := atomic.AddUint64(&chanIDCounter, 1) - return lnwire.NewShortChanIDFromInt(id) -} - -// randKey returns a random public key. -func randKey() (*btcec.PublicKey, error) { - priv, err := btcec.NewPrivateKey() - if err != nil { - return nil, err - } - - return priv.PubKey(), nil -} - -// addRandChannel creates a new channel two target nodes. This function is -// meant to aide in the generation of random graphs for use within test cases -// the exercise the autopilot package. -func (m *memChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, - capacity btcutil.Amount) (*ChannelEdge, *ChannelEdge, error) { - - var ( - vertex1, vertex2 *memNode - ok bool - ) - - if node1 != nil { - vertex1, ok = m.graph[NewNodeID(node1)] - if !ok { - vertex1 = &memNode{ - pub: node1, - addrs: []net.Addr{ - &net.TCPAddr{ - IP: bytes.Repeat([]byte("a"), 16), - }, - }, - } - } - } else { - newPub, err := randKey() - if err != nil { - return nil, nil, err - } - vertex1 = &memNode{ - pub: newPub, - addrs: []net.Addr{ - &net.TCPAddr{ - IP: bytes.Repeat([]byte("a"), 16), - }, - }, - } - } - - if node2 != nil { - vertex2, ok = m.graph[NewNodeID(node2)] - if !ok { - vertex2 = &memNode{ - pub: node2, - addrs: []net.Addr{ - &net.TCPAddr{ - IP: bytes.Repeat([]byte("a"), 16), - }, - }, - } - } - } else { - newPub, err := randKey() - if err != nil { - return nil, nil, err - } - vertex2 = &memNode{ - pub: newPub, - addrs: []net.Addr{ - &net.TCPAddr{ - IP: bytes.Repeat([]byte("a"), 16), - }, - }, - } - } - - edge1 := ChannelEdge{ - ChanID: randChanID(), - Capacity: capacity, - Peer: vertex2, - } - vertex1.chans = append(vertex1.chans, edge1) - - edge2 := ChannelEdge{ - ChanID: randChanID(), - Capacity: capacity, - Peer: vertex1, - } - vertex2.chans = append(vertex2.chans, edge2) - - m.graph[NewNodeID(vertex1.pub)] = vertex1 - m.graph[NewNodeID(vertex2.pub)] = vertex2 - - return &edge1, &edge2, nil -} - -func (m *memChannelGraph) addRandNode() (*btcec.PublicKey, error) { - newPub, err := randKey() - if err != nil { - return nil, err - } - vertex := &memNode{ - pub: newPub, - addrs: []net.Addr{ - &net.TCPAddr{ - IP: bytes.Repeat([]byte("a"), 16), - }, - }, - } - m.graph[NewNodeID(newPub)] = vertex - - return newPub, nil -} - // databaseChannelGraphCached wraps a channeldb.ChannelGraph instance with the // necessary API to properly implement the autopilot.ChannelGraph interface. type databaseChannelGraphCached struct { - db *graphdb.ChannelGraph + db GraphSource } // A compile time assertion to ensure databaseChannelGraphCached meets the @@ -490,7 +145,7 @@ var _ ChannelGraph = (*databaseChannelGraphCached)(nil) // ChannelGraphFromCachedDatabase returns an instance of the // autopilot.ChannelGraph backed by a live, open channeldb instance. -func ChannelGraphFromCachedDatabase(db *graphdb.ChannelGraph) ChannelGraph { +func ChannelGraphFromCachedDatabase(db GraphSource) ChannelGraph { return &databaseChannelGraphCached{ db: db, } diff --git a/autopilot/interface.go b/autopilot/interface.go index 671d34332..35182a760 100644 --- a/autopilot/interface.go +++ b/autopilot/interface.go @@ -6,7 +6,9 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" + graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" ) // DefaultConfTarget is the default confirmation target for autopilot channels. @@ -216,3 +218,20 @@ type ChannelController interface { // TODO(roasbeef): add force option? CloseChannel(chanPoint *wire.OutPoint) error } + +// GraphSource represents read access to the channel graph. +type GraphSource interface { + // ForEachNode iterates through all the stored vertices/nodes in the + // graph, executing the passed callback with each node encountered. If + // the callback returns an error, then the transaction is aborted and + // the iteration stops early. Any operations performed on the NodeTx + // passed to the call-back are executed under the same read transaction. + ForEachNode(func(graphdb.NodeRTx) error) error + + // ForEachNodeCached is similar to ForEachNode, but it utilizes the + // channel graph cache if one is available. It is less consistent than + // ForEachNode since any further calls are made across multiple + // transactions. + ForEachNodeCached(cb func(node route.Vertex, + chans map[uint64]*graphdb.DirectedChannel) error) error +} diff --git a/autopilot/prefattach_test.go b/autopilot/prefattach_test.go index ab52c55f6..d30174031 100644 --- a/autopilot/prefattach_test.go +++ b/autopilot/prefattach_test.go @@ -2,14 +2,20 @@ package autopilot import ( "bytes" + "errors" prand "math/rand" + "net" + "sync/atomic" "testing" "time" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/require" ) @@ -24,6 +30,11 @@ type testGraph interface { addRandNode() (*btcec.PublicKey, error) } +type testDBGraph struct { + db *graphdb.ChannelGraph + databaseChannelGraph +} + func newDiskChanGraph(t *testing.T) (testGraph, error) { backend, err := kvdb.GetBoltBackend(&kvdb.BoltBackendConfig{ DBPath: t.TempDir(), @@ -38,12 +49,15 @@ func newDiskChanGraph(t *testing.T) (testGraph, error) { graphDB, err := graphdb.NewChannelGraph(backend) require.NoError(t, err) - return &databaseChannelGraph{ + return &testDBGraph{ db: graphDB, + databaseChannelGraph: databaseChannelGraph{ + db: graphDB, + }, }, nil } -var _ testGraph = (*databaseChannelGraph)(nil) +var _ testGraph = (*testDBGraph)(nil) func newMemChanGraph(_ *testing.T) (testGraph, error) { return newMemChannelGraph(), nil @@ -368,3 +382,357 @@ func TestPrefAttachmentSelectSkipNodes(t *testing.T) { } } } + +// addRandChannel creates a new channel two target nodes. This function is +// meant to aide in the generation of random graphs for use within test cases +// the exercise the autopilot package. +func (d *testDBGraph) addRandChannel(node1, node2 *btcec.PublicKey, + capacity btcutil.Amount) (*ChannelEdge, *ChannelEdge, error) { + + fetchNode := func(pub *btcec.PublicKey) (*models.LightningNode, error) { + if pub != nil { + vertex, err := route.NewVertexFromBytes( + pub.SerializeCompressed(), + ) + if err != nil { + return nil, err + } + + dbNode, err := d.db.FetchLightningNode(vertex) + switch { + case errors.Is(err, graphdb.ErrGraphNodeNotFound): + fallthrough + case errors.Is(err, graphdb.ErrGraphNotFound): + graphNode := &models.LightningNode{ + HaveNodeAnnouncement: true, + Addresses: []net.Addr{&net.TCPAddr{ + IP: bytes.Repeat( + []byte("a"), 16, + ), + }}, + Features: lnwire.NewFeatureVector( + nil, lnwire.Features, + ), + AuthSigBytes: testSig.Serialize(), + } + graphNode.AddPubKey(pub) + err := d.db.AddLightningNode(graphNode) + if err != nil { + return nil, err + } + case err != nil: + return nil, err + } + + return dbNode, nil + } + + nodeKey, err := randKey() + if err != nil { + return nil, err + } + dbNode := &models.LightningNode{ + HaveNodeAnnouncement: true, + Addresses: []net.Addr{ + &net.TCPAddr{ + IP: bytes.Repeat([]byte("a"), 16), + }, + }, + Features: lnwire.NewFeatureVector( + nil, lnwire.Features, + ), + AuthSigBytes: testSig.Serialize(), + } + dbNode.AddPubKey(nodeKey) + if err := d.db.AddLightningNode(dbNode); err != nil { + return nil, err + } + + return dbNode, nil + } + + vertex1, err := fetchNode(node1) + if err != nil { + return nil, nil, err + } + + vertex2, err := fetchNode(node2) + if err != nil { + return nil, nil, err + } + + var lnNode1, lnNode2 *btcec.PublicKey + if bytes.Compare(vertex1.PubKeyBytes[:], vertex2.PubKeyBytes[:]) == -1 { + lnNode1, _ = vertex1.PubKey() + lnNode2, _ = vertex2.PubKey() + } else { + lnNode1, _ = vertex2.PubKey() + lnNode2, _ = vertex1.PubKey() + } + + chanID := randChanID() + edge := &models.ChannelEdgeInfo{ + ChannelID: chanID.ToUint64(), + Capacity: capacity, + } + edge.AddNodeKeys(lnNode1, lnNode2, lnNode1, lnNode2) + if err := d.db.AddChannelEdge(edge); err != nil { + return nil, nil, err + } + edgePolicy := &models.ChannelEdgePolicy{ + SigBytes: testSig.Serialize(), + ChannelID: chanID.ToUint64(), + LastUpdate: time.Now(), + TimeLockDelta: 10, + MinHTLC: 1, + MaxHTLC: lnwire.NewMSatFromSatoshis(capacity), + FeeBaseMSat: 10, + FeeProportionalMillionths: 10000, + MessageFlags: 1, + ChannelFlags: 0, + } + + if err := d.db.UpdateEdgePolicy(edgePolicy); err != nil { + return nil, nil, err + } + edgePolicy = &models.ChannelEdgePolicy{ + SigBytes: testSig.Serialize(), + ChannelID: chanID.ToUint64(), + LastUpdate: time.Now(), + TimeLockDelta: 10, + MinHTLC: 1, + MaxHTLC: lnwire.NewMSatFromSatoshis(capacity), + FeeBaseMSat: 10, + FeeProportionalMillionths: 10000, + MessageFlags: 1, + ChannelFlags: 1, + } + if err := d.db.UpdateEdgePolicy(edgePolicy); err != nil { + return nil, nil, err + } + + return &ChannelEdge{ + ChanID: chanID, + Capacity: capacity, + Peer: &dbNode{tx: &testNodeTx{ + db: d, + node: vertex1, + }}, + }, + &ChannelEdge{ + ChanID: chanID, + Capacity: capacity, + Peer: &dbNode{tx: &testNodeTx{ + db: d, + node: vertex2, + }}, + }, + nil +} + +func (d *testDBGraph) addRandNode() (*btcec.PublicKey, error) { + nodeKey, err := randKey() + if err != nil { + return nil, err + } + dbNode := &models.LightningNode{ + HaveNodeAnnouncement: true, + Addresses: []net.Addr{ + &net.TCPAddr{ + IP: bytes.Repeat([]byte("a"), 16), + }, + }, + Features: lnwire.NewFeatureVector( + nil, lnwire.Features, + ), + AuthSigBytes: testSig.Serialize(), + } + dbNode.AddPubKey(nodeKey) + if err := d.db.AddLightningNode(dbNode); err != nil { + return nil, err + } + + return nodeKey, nil +} + +// memChannelGraph is an implementation of the autopilot.ChannelGraph backed by +// an in-memory graph. +type memChannelGraph struct { + graph map[NodeID]*memNode +} + +// A compile time assertion to ensure memChannelGraph meets the +// autopilot.ChannelGraph interface. +var _ ChannelGraph = (*memChannelGraph)(nil) + +// newMemChannelGraph creates a new blank in-memory channel graph +// implementation. +func newMemChannelGraph() *memChannelGraph { + return &memChannelGraph{ + graph: make(map[NodeID]*memNode), + } +} + +// ForEachNode is a higher-order function that should be called once for each +// connected node within the channel graph. If the passed callback returns an +// error, then execution should be terminated. +// +// NOTE: Part of the autopilot.ChannelGraph interface. +func (m memChannelGraph) ForEachNode(cb func(Node) error) error { + for _, node := range m.graph { + if err := cb(node); err != nil { + return err + } + } + + return nil +} + +// randChanID generates a new random channel ID. +func randChanID() lnwire.ShortChannelID { + id := atomic.AddUint64(&chanIDCounter, 1) + return lnwire.NewShortChanIDFromInt(id) +} + +// randKey returns a random public key. +func randKey() (*btcec.PublicKey, error) { + priv, err := btcec.NewPrivateKey() + if err != nil { + return nil, err + } + + return priv.PubKey(), nil +} + +// addRandChannel creates a new channel two target nodes. This function is +// meant to aide in the generation of random graphs for use within test cases +// the exercise the autopilot package. +func (m *memChannelGraph) addRandChannel(node1, node2 *btcec.PublicKey, + capacity btcutil.Amount) (*ChannelEdge, *ChannelEdge, error) { + + var ( + vertex1, vertex2 *memNode + ok bool + ) + + if node1 != nil { + vertex1, ok = m.graph[NewNodeID(node1)] + if !ok { + vertex1 = &memNode{ + pub: node1, + addrs: []net.Addr{&net.TCPAddr{ + IP: bytes.Repeat([]byte("a"), 16), + }}, + } + } + } else { + newPub, err := randKey() + if err != nil { + return nil, nil, err + } + vertex1 = &memNode{ + pub: newPub, + addrs: []net.Addr{ + &net.TCPAddr{ + IP: bytes.Repeat([]byte("a"), 16), + }, + }, + } + } + + if node2 != nil { + vertex2, ok = m.graph[NewNodeID(node2)] + if !ok { + vertex2 = &memNode{ + pub: node2, + addrs: []net.Addr{&net.TCPAddr{ + IP: bytes.Repeat([]byte("a"), 16), + }}, + } + } + } else { + newPub, err := randKey() + if err != nil { + return nil, nil, err + } + vertex2 = &memNode{ + pub: newPub, + addrs: []net.Addr{ + &net.TCPAddr{ + IP: bytes.Repeat([]byte("a"), 16), + }, + }, + } + } + + edge1 := ChannelEdge{ + ChanID: randChanID(), + Capacity: capacity, + Peer: vertex2, + } + vertex1.chans = append(vertex1.chans, edge1) + + edge2 := ChannelEdge{ + ChanID: randChanID(), + Capacity: capacity, + Peer: vertex1, + } + vertex2.chans = append(vertex2.chans, edge2) + + m.graph[NewNodeID(vertex1.pub)] = vertex1 + m.graph[NewNodeID(vertex2.pub)] = vertex2 + + return &edge1, &edge2, nil +} + +func (m *memChannelGraph) addRandNode() (*btcec.PublicKey, error) { + newPub, err := randKey() + if err != nil { + return nil, err + } + vertex := &memNode{ + pub: newPub, + addrs: []net.Addr{ + &net.TCPAddr{ + IP: bytes.Repeat([]byte("a"), 16), + }, + }, + } + m.graph[NewNodeID(newPub)] = vertex + + return newPub, nil +} + +type testNodeTx struct { + db *testDBGraph + node *models.LightningNode +} + +func (t *testNodeTx) Node() *models.LightningNode { + return t.node +} + +func (t *testNodeTx) ForEachChannel(f func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { + + return t.db.db.ForEachNodeChannel(t.node.PubKeyBytes, func(_ kvdb.RTx, + edge *models.ChannelEdgeInfo, policy1, + policy2 *models.ChannelEdgePolicy) error { + + return f(edge, policy1, policy2) + }) +} + +func (t *testNodeTx) FetchNode(pub route.Vertex) (graphdb.NodeRTx, error) { + node, err := t.db.db.FetchLightningNode(pub) + if err != nil { + return nil, err + } + + return &testNodeTx{ + db: t.db, + node: node, + }, nil +} + +var _ graphdb.NodeRTx = (*testNodeTx)(nil) diff --git a/docs/release-notes/release-notes-0.19.0.md b/docs/release-notes/release-notes-0.19.0.md index c4cf4685e..e70492d6f 100644 --- a/docs/release-notes/release-notes-0.19.0.md +++ b/docs/release-notes/release-notes-0.19.0.md @@ -240,6 +240,9 @@ config option](https://github.com/lightningnetwork/lnd/pull/9182) and introduce a new option `channel-max-fee-exposure` which is unambiguous in its description. The underlying functionality between those two options remain the same. +* [Abstraction of graph](https://github.com/lightningnetwork/lnd/pull/9480) + access for autopilot. + * [Golang was updated to `v1.22.11`](https://github.com/lightningnetwork/lnd/pull/9462). diff --git a/graph/builder.go b/graph/builder.go index 3b18a30a3..db73c9272 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -1553,18 +1553,6 @@ func (b *Builder) FetchLightningNode( return b.cfg.Graph.FetchLightningNode(node) } -// ForEachNode is used to iterate over every node in router topology. -// -// NOTE: This method is part of the ChannelGraphSource interface. -func (b *Builder) ForEachNode( - cb func(*models.LightningNode) error) error { - - return b.cfg.Graph.ForEachNode( - func(_ kvdb.RTx, n *models.LightningNode) error { - return cb(n) - }) -} - // ForAllOutgoingChannels is used to iterate over all outgoing channels owned by // the router. // diff --git a/graph/db/graph.go b/graph/db/graph.go index d5a876a79..06bd0257d 100644 --- a/graph/db/graph.go +++ b/graph/db/graph.go @@ -588,9 +588,9 @@ func (c *ChannelGraph) FetchNodeFeatures( } } -// ForEachNodeCached is similar to ForEachNode, but it utilizes the channel +// ForEachNodeCached is similar to forEachNode, but it utilizes the channel // graph cache instead. Note that this doesn't return all the information the -// regular ForEachNode method does. +// regular forEachNode method does. // // NOTE: The callback contents MUST not be modified. func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, @@ -604,7 +604,7 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, // We'll iterate over each node, then the set of channels for each // node, and construct a similar callback functiopn signature as the // main funcotin expects. - return c.ForEachNode(func(tx kvdb.RTx, + return c.forEachNode(func(tx kvdb.RTx, node *models.LightningNode) error { channels := make(map[uint64]*DirectedChannel) @@ -716,11 +716,25 @@ func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) { // ForEachNode iterates through all the stored vertices/nodes in the graph, // executing the passed callback with each node encountered. If the callback // returns an error, then the transaction is aborted and the iteration stops +// early. Any operations performed on the NodeTx passed to the call-back are +// executed under the same read transaction and so, methods on the NodeTx object +// _MUST_ only be called from within the call-back. +func (c *ChannelGraph) ForEachNode(cb func(tx NodeRTx) error) error { + return c.forEachNode(func(tx kvdb.RTx, + node *models.LightningNode) error { + + return cb(newChanGraphNodeTx(tx, c, node)) + }) +} + +// forEachNode iterates through all the stored vertices/nodes in the graph, +// executing the passed callback with each node encountered. If the callback +// returns an error, then the transaction is aborted and the iteration stops // early. // // TODO(roasbeef): add iterator interface to allow for memory efficient graph // traversal when graph gets mega -func (c *ChannelGraph) ForEachNode( +func (c *ChannelGraph) forEachNode( cb func(kvdb.RTx, *models.LightningNode) error) error { traversal := func(tx kvdb.RTx) error { @@ -4717,6 +4731,65 @@ func deserializeChanEdgePolicyRaw(r io.Reader) (*models.ChannelEdgePolicy, return edge, nil } +// chanGraphNodeTx is an implementation of the NodeRTx interface backed by the +// ChannelGraph and a kvdb.RTx. +type chanGraphNodeTx struct { + tx kvdb.RTx + db *ChannelGraph + node *models.LightningNode +} + +// A compile-time constraint to ensure chanGraphNodeTx implements the NodeRTx +// interface. +var _ NodeRTx = (*chanGraphNodeTx)(nil) + +func newChanGraphNodeTx(tx kvdb.RTx, db *ChannelGraph, + node *models.LightningNode) *chanGraphNodeTx { + + return &chanGraphNodeTx{ + tx: tx, + db: db, + node: node, + } +} + +// Node returns the raw information of the node. +// +// NOTE: This is a part of the NodeRTx interface. +func (c *chanGraphNodeTx) Node() *models.LightningNode { + return c.node +} + +// FetchNode fetches the node with the given pub key under the same transaction +// used to fetch the current node. The returned node is also a NodeRTx and any +// operations on that NodeRTx will also be done under the same transaction. +// +// NOTE: This is a part of the NodeRTx interface. +func (c *chanGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) { + node, err := c.db.FetchLightningNodeTx(c.tx, nodePub) + if err != nil { + return nil, err + } + + return newChanGraphNodeTx(c.tx, c.db, node), nil +} + +// ForEachChannel can be used to iterate over the node's channels under +// the same transaction used to fetch the node. +// +// NOTE: This is a part of the NodeRTx interface. +func (c *chanGraphNodeTx) ForEachChannel(f func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { + + return c.db.ForEachNodeChannelTx(c.tx, c.node.PubKeyBytes, + func(_ kvdb.RTx, info *models.ChannelEdgeInfo, policy1, + policy2 *models.ChannelEdgePolicy) error { + + return f(info, policy1, policy2) + }, + ) +} + // MakeTestGraph creates a new instance of the ChannelGraph for testing // purposes. func MakeTestGraph(t testing.TB, modifiers ...OptionModifier) (*ChannelGraph, diff --git a/graph/db/graph_cache_test.go b/graph/db/graph_cache_test.go index 3f140c4c5..69abb2597 100644 --- a/graph/db/graph_cache_test.go +++ b/graph/db/graph_cache_test.go @@ -121,7 +121,7 @@ func TestGraphCacheAddNode(t *testing.T) { assertCachedPolicyEqual(t, outPolicy1, toChannels[0].InPolicy) // Now that we've inserted two nodes into the graph, check that - // we'll recover the same set of channels during ForEachNode. + // we'll recover the same set of channels during forEachNode. nodes := make(map[route.Vertex]struct{}) chans := make(map[uint64]struct{}) _ = cache.ForEachNode(func(node route.Vertex, diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index d048cdafd..3b3454b6b 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -1092,7 +1092,7 @@ func TestGraphTraversalCacheable(t *testing.T) { // Create a map of all nodes with the iteration we know works (because // it is tested in another test). nodeMap := make(map[route.Vertex]struct{}) - err = graph.ForEachNode( + err = graph.forEachNode( func(tx kvdb.RTx, n *models.LightningNode) error { nodeMap[n.PubKeyBytes] = struct{}{} @@ -1217,7 +1217,7 @@ func fillTestGraph(t require.TestingT, graph *ChannelGraph, numNodes, // Iterate over each node as returned by the graph, if all nodes are // reached, then the map created above should be empty. - err := graph.ForEachNode( + err := graph.forEachNode( func(_ kvdb.RTx, node *models.LightningNode) error { delete(nodeIndex, node.Alias) return nil @@ -1329,7 +1329,7 @@ func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { func assertNumNodes(t *testing.T, graph *ChannelGraph, n int) { numNodes := 0 - err := graph.ForEachNode( + err := graph.forEachNode( func(_ kvdb.RTx, _ *models.LightningNode) error { numNodes++ return nil diff --git a/graph/db/interfaces.go b/graph/db/interfaces.go new file mode 100644 index 000000000..f44a9ff8b --- /dev/null +++ b/graph/db/interfaces.go @@ -0,0 +1,25 @@ +package graphdb + +import ( + "github.com/lightningnetwork/lnd/graph/db/models" + "github.com/lightningnetwork/lnd/routing/route" +) + +// NodeRTx represents transaction object with an underlying node associated that +// can be used to make further queries to the graph under the same transaction. +// This is useful for consistency during graph traversal and queries. +type NodeRTx interface { + // Node returns the raw information of the node. + Node() *models.LightningNode + + // ForEachChannel can be used to iterate over the node's channels under + // the same transaction used to fetch the node. + ForEachChannel(func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error) error + + // FetchNode fetches the node with the given pub key under the same + // transaction used to fetch the current node. The returned node is also + // a NodeRTx and any operations on that NodeRTx will also be done under + // the same transaction. + FetchNode(node route.Vertex) (NodeRTx, error) +} diff --git a/graph/interfaces.go b/graph/interfaces.go index eb7f56603..10ca200f3 100644 --- a/graph/interfaces.go +++ b/graph/interfaces.go @@ -85,9 +85,6 @@ type ChannelGraphSource interface { // public key. channeldb.ErrGraphNodeNotFound is returned if the node // doesn't exist within the graph. FetchLightningNode(route.Vertex) (*models.LightningNode, error) - - // ForEachNode is used to iterate over every node in the known graph. - ForEachNode(func(node *models.LightningNode) error) error } // DB is an interface describing a persisted Lightning Network graph. @@ -241,12 +238,6 @@ type DB interface { FetchLightningNode(nodePub route.Vertex) (*models.LightningNode, error) - // ForEachNode iterates through all the stored vertices/nodes in the - // graph, executing the passed callback with each node encountered. If - // the callback returns an error, then the transaction is aborted and - // the iteration stops early. - ForEachNode(cb func(kvdb.RTx, *models.LightningNode) error) error - // ForEachNodeChannel iterates through all channels of the given node, // executing the passed callback with an edge info structure and the // policies of each end of the channel. The first edge policy is the diff --git a/rpcserver.go b/rpcserver.go index 72e2fa4af..2e17eb75c 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -6533,10 +6533,8 @@ func (r *rpcServer) DescribeGraph(ctx context.Context, // First iterate through all the known nodes (connected or unconnected // within the graph), collating their current state into the RPC // response. - err := graph.ForEachNode(func(_ kvdb.RTx, - node *models.LightningNode) error { - - lnNode := marshalNode(node) + err := graph.ForEachNode(func(nodeTx graphdb.NodeRTx) error { + lnNode := marshalNode(nodeTx.Node()) resp.Nodes = append(resp.Nodes, lnNode)