autopilot: start threading contexts through

The `GraphSource` interface in the `autopilot` package is directly
implemented by the `graphdb.KVStore` and so we will eventually thread
contexts through to this interface. So in this commit, we start updating
the autopilot system to thread contexts through in preparation for
passing the context through to any calls made to the GraphSource.

Two context.TODOs are added here which will be addressed in follow up
commits.
This commit is contained in:
Elle Mouton
2025-04-09 06:25:29 +02:00
committed by ziggie
parent ca52834795
commit cd4a59071d
13 changed files with 146 additions and 71 deletions

View File

@@ -1,6 +1,7 @@
package autopilot
import (
"context"
"encoding/hex"
"net"
"sort"
@@ -80,7 +81,9 @@ func (d *dbNode) Addrs() []net.Addr {
// describes the active channel.
//
// NOTE: Part of the autopilot.Node interface.
func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error {
func (d *dbNode) ForEachChannel(ctx context.Context,
cb func(context.Context, ChannelEdge) error) error {
return d.tx.ForEachChannel(func(ei *models.ChannelEdgeInfo, ep,
_ *models.ChannelEdgePolicy) error {
@@ -108,7 +111,7 @@ func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error {
},
}
return cb(edge)
return cb(ctx, edge)
})
}
@@ -117,7 +120,9 @@ func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error {
// error, then execution should be terminated.
//
// NOTE: Part of the autopilot.ChannelGraph interface.
func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error {
func (d *databaseChannelGraph) ForEachNode(ctx context.Context,
cb func(context.Context, Node) error) error {
return d.db.ForEachNode(func(nodeTx graphdb.NodeRTx) error {
// We'll skip over any node that doesn't have any advertised
// addresses. As we won't be able to reach them to actually
@@ -129,7 +134,8 @@ func (d *databaseChannelGraph) ForEachNode(cb func(Node) error) error {
node := &dbNode{
tx: nodeTx,
}
return cb(node)
return cb(ctx, node)
})
}
@@ -185,7 +191,9 @@ func (nc dbNodeCached) Addrs() []net.Addr {
// describes the active channel.
//
// NOTE: Part of the autopilot.Node interface.
func (nc dbNodeCached) ForEachChannel(cb func(ChannelEdge) error) error {
func (nc dbNodeCached) ForEachChannel(ctx context.Context,
cb func(context.Context, ChannelEdge) error) error {
for cid, channel := range nc.channels {
edge := ChannelEdge{
ChanID: lnwire.NewShortChanIDFromInt(cid),
@@ -195,7 +203,7 @@ func (nc dbNodeCached) ForEachChannel(cb func(ChannelEdge) error) error {
},
}
if err := cb(edge); err != nil {
if err := cb(ctx, edge); err != nil {
return err
}
}
@@ -208,7 +216,9 @@ func (nc dbNodeCached) ForEachChannel(cb func(ChannelEdge) error) error {
// error, then execution should be terminated.
//
// NOTE: Part of the autopilot.ChannelGraph interface.
func (dc *databaseChannelGraphCached) ForEachNode(cb func(Node) error) error {
func (dc *databaseChannelGraphCached) ForEachNode(ctx context.Context,
cb func(context.Context, Node) error) error {
return dc.db.ForEachNodeCached(func(n route.Vertex,
channels map[uint64]*graphdb.DirectedChannel) error {
@@ -217,7 +227,8 @@ func (dc *databaseChannelGraphCached) ForEachNode(cb func(Node) error) error {
node: n,
channels: channels,
}
return cb(node)
return cb(ctx, node)
}
return nil
})
@@ -262,9 +273,11 @@ func (m memNode) Addrs() []net.Addr {
// describes the active channel.
//
// NOTE: Part of the autopilot.Node interface.
func (m memNode) ForEachChannel(cb func(ChannelEdge) error) error {
func (m memNode) ForEachChannel(ctx context.Context,
cb func(context.Context, ChannelEdge) error) error {
for _, channel := range m.chans {
if err := cb(channel); err != nil {
if err := cb(ctx, channel); err != nil {
return err
}
}