graph+autopilot: let autopilot use new graph ForEachNode method

Which passes a NodeRTx to the call-back instead of a `kvdb.RTx`.
This commit is contained in:
Elle Mouton
2025-02-05 12:18:11 +02:00
parent 14cedef58e
commit 9b86ee53db
6 changed files with 95 additions and 66 deletions

View File

@@ -10,7 +10,6 @@ import (
"github.com/btcsuite/btcd/btcutil"
graphdb "github.com/lightningnetwork/lnd/graph/db"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
@@ -51,11 +50,7 @@ func ChannelGraphFromDatabase(db *graphdb.ChannelGraph) ChannelGraph {
// channeldb.LightningNode. The wrapper method implement the autopilot.Node
// interface.
type dbNode struct {
db *graphdb.ChannelGraph
tx kvdb.RTx
node *models.LightningNode
tx graphdb.NodeRTx
}
// A compile time assertion to ensure dbNode meets the autopilot.Node
@@ -68,7 +63,7 @@ var _ Node = (*dbNode)(nil)
//
// NOTE: Part of the autopilot.Node interface.
func (d *dbNode) PubKey() [33]byte {
return d.node.PubKeyBytes
return d.tx.Node().PubKeyBytes
}
// Addrs returns a slice of publicly reachable public TCP addresses that the
@@ -76,7 +71,7 @@ func (d *dbNode) PubKey() [33]byte {
//
// NOTE: Part of the autopilot.Node interface.
func (d *dbNode) Addrs() []net.Addr {
return d.node.Addresses
return d.tx.Node().Addresses
}
// ForEachChannel is a higher-order function that will be used to iterate
@@ -86,43 +81,35 @@ func (d *dbNode) Addrs() []net.Addr {
//
// NOTE: Part of the autopilot.Node interface.
func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error {
return d.db.ForEachNodeChannelTx(d.tx, d.node.PubKeyBytes,
func(tx kvdb.RTx, ei *models.ChannelEdgeInfo, ep,
_ *models.ChannelEdgePolicy) error {
return d.tx.ForEachChannel(func(ei *models.ChannelEdgeInfo, ep,
_ *models.ChannelEdgePolicy) error {
// Skip channels for which no outgoing edge policy is
// available.
//
// TODO(joostjager): Ideally the case where channels
// have a nil policy should be supported, as autopilot
// is not looking at the policies. For now, it is not
// easily possible to get a reference to the other end
// LightningNode object without retrieving the policy.
if ep == nil {
return nil
}
// Skip channels for which no outgoing edge policy is available.
//
// TODO(joostjager): Ideally the case where channels have a nil
// policy should be supported, as autopilot is not looking at
// the policies. For now, it is not easily possible to get a
// reference to the other end LightningNode object without
// retrieving the policy.
if ep == nil {
return nil
}
node, err := d.db.FetchLightningNodeTx(
tx, ep.ToNode,
)
if err != nil {
return err
}
node, err := d.tx.FetchNode(ep.ToNode)
if err != nil {
return err
}
edge := ChannelEdge{
ChanID: lnwire.NewShortChanIDFromInt(
ep.ChannelID,
),
Capacity: ei.Capacity,
Peer: &dbNode{
tx: tx,
db: d.db,
node: node,
},
}
edge := ChannelEdge{
ChanID: lnwire.NewShortChanIDFromInt(ep.ChannelID),
Capacity: ei.Capacity,
Peer: &dbNode{
tx: node,
},
}
return cb(edge)
})
return cb(edge)
})
}
// ForEachNode is a higher-order function that should be called once for each
@@ -131,20 +118,16 @@ func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error {
//
// NOTE: Part of the autopilot.ChannelGraph interface.
func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error {
return d.db.ForEachNode(func(tx kvdb.RTx,
n *models.LightningNode) error {
return d.db.ForEachNode(func(nodeTx graphdb.NodeRTx) error {
// We'll skip over any node that doesn't have any advertised
// addresses. As we won't be able to reach them to actually
// open any channels.
if len(n.Addresses) == 0 {
if len(nodeTx.Node().Addresses) == 0 {
return nil
}
node := &dbNode{
db: d.db,
tx: tx,
node: n,
tx: nodeTx,
}
return cb(node)
})

View File

@@ -514,18 +514,18 @@ func (d *testDBGraph) addRandChannel(node1, node2 *btcec.PublicKey,
return &ChannelEdge{
ChanID: chanID,
Capacity: capacity,
Peer: &dbNode{
db: d.db,
Peer: &dbNode{tx: &testNodeTx{
db: d,
node: vertex1,
},
}},
},
&ChannelEdge{
ChanID: chanID,
Capacity: capacity,
Peer: &dbNode{
db: d.db,
Peer: &dbNode{tx: &testNodeTx{
db: d,
node: vertex2,
},
}},
},
nil
}
@@ -702,3 +702,37 @@ func (m *memChannelGraph) addRandNode() (*btcec.PublicKey, error) {
return newPub, nil
}
type testNodeTx struct {
db *testDBGraph
node *models.LightningNode
}
func (t *testNodeTx) Node() *models.LightningNode {
return t.node
}
func (t *testNodeTx) ForEachChannel(f func(*models.ChannelEdgeInfo,
*models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
return t.db.db.ForEachNodeChannel(t.node.PubKeyBytes, func(_ kvdb.RTx,
edge *models.ChannelEdgeInfo, policy1,
policy2 *models.ChannelEdgePolicy) error {
return f(edge, policy1, policy2)
})
}
func (t *testNodeTx) FetchNode(pub route.Vertex) (graphdb.NodeRTx, error) {
node, err := t.db.db.FetchLightningNode(pub)
if err != nil {
return nil, err
}
return &testNodeTx{
db: t.db,
node: node,
}, nil
}
var _ graphdb.NodeRTx = (*testNodeTx)(nil)