mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-08-26 13:42:49 +02:00
autopilot: start threading contexts through
The `GraphSource` interface in the `autopilot` package is directly implemented by the `graphdb.KVStore` and so we will eventually thread contexts through to this interface. So in this commit, we start updating the autopilot system to thread contexts through in preparation for passing the context through to any calls made to the GraphSource. Two context.TODOs are added here which will be addressed in follow up commits.
This commit is contained in:
@@ -2,6 +2,7 @@ package autopilot
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
prand "math/rand"
|
||||
"net"
|
||||
@@ -126,6 +127,7 @@ func TestPrefAttachmentSelectEmptyGraph(t *testing.T) {
|
||||
// and the funds are appropriately allocated across each peer.
|
||||
func TestPrefAttachmentSelectTwoVertexes(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
prand.Seed(time.Now().Unix())
|
||||
|
||||
@@ -156,10 +158,12 @@ func TestPrefAttachmentSelectTwoVertexes(t *testing.T) {
|
||||
// Get the score for all nodes found in the graph at
|
||||
// this point.
|
||||
nodes := make(map[NodeID]struct{})
|
||||
err = graph.ForEachNode(func(n Node) error {
|
||||
nodes[n.PubKey()] = struct{}{}
|
||||
return nil
|
||||
})
|
||||
err = graph.ForEachNode(ctx,
|
||||
func(_ context.Context, n Node) error {
|
||||
nodes[n.PubKey()] = struct{}{}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
require.NoError(t1, err)
|
||||
|
||||
require.Len(t1, nodes, 3)
|
||||
@@ -207,6 +211,7 @@ func TestPrefAttachmentSelectTwoVertexes(t *testing.T) {
|
||||
// allocate all funds to each vertex (up to the max channel size).
|
||||
func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
prand.Seed(time.Now().Unix())
|
||||
|
||||
@@ -245,22 +250,25 @@ func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) {
|
||||
numNodes := 0
|
||||
twoChans := false
|
||||
nodes := make(map[NodeID]struct{})
|
||||
err = graph.ForEachNode(func(n Node) error {
|
||||
numNodes++
|
||||
nodes[n.PubKey()] = struct{}{}
|
||||
numChans := 0
|
||||
err := n.ForEachChannel(func(c ChannelEdge) error {
|
||||
numChans++
|
||||
err = graph.ForEachNode(
|
||||
ctx, func(ctx context.Context, n Node) error {
|
||||
numNodes++
|
||||
nodes[n.PubKey()] = struct{}{}
|
||||
numChans := 0
|
||||
err := n.ForEachChannel(ctx,
|
||||
func(_ context.Context, c ChannelEdge) error { //nolint:ll
|
||||
numChans++
|
||||
return nil
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
twoChans = twoChans || (numChans == 2)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
twoChans = twoChans || (numChans == 2)
|
||||
|
||||
return nil
|
||||
})
|
||||
require.NoError(t1, err)
|
||||
|
||||
require.EqualValues(t1, 3, numNodes)
|
||||
@@ -313,6 +321,7 @@ func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) {
|
||||
// of zero during scoring.
|
||||
func TestPrefAttachmentSelectSkipNodes(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
prand.Seed(time.Now().Unix())
|
||||
|
||||
@@ -335,10 +344,13 @@ func TestPrefAttachmentSelectSkipNodes(t *testing.T) {
|
||||
require.NoError(t1, err)
|
||||
|
||||
nodes := make(map[NodeID]struct{})
|
||||
err = graph.ForEachNode(func(n Node) error {
|
||||
nodes[n.PubKey()] = struct{}{}
|
||||
return nil
|
||||
})
|
||||
err = graph.ForEachNode(
|
||||
ctx, func(_ context.Context, n Node) error {
|
||||
nodes[n.PubKey()] = struct{}{}
|
||||
|
||||
return nil
|
||||
},
|
||||
)
|
||||
require.NoError(t1, err)
|
||||
|
||||
require.Len(t1, nodes, 2)
|
||||
@@ -583,9 +595,11 @@ func newMemChannelGraph() *memChannelGraph {
|
||||
// error, then execution should be terminated.
|
||||
//
|
||||
// NOTE: Part of the autopilot.ChannelGraph interface.
|
||||
func (m *memChannelGraph) ForEachNode(cb func(Node) error) error {
|
||||
func (m *memChannelGraph) ForEachNode(ctx context.Context,
|
||||
cb func(context.Context, Node) error) error {
|
||||
|
||||
for _, node := range m.graph {
|
||||
if err := cb(node); err != nil {
|
||||
if err := cb(ctx, node); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user