multi: pass reset to ForEachNodeCached

This commit is contained in:
Elle Mouton
2025-07-11 10:44:06 +02:00
parent e5fbca8299
commit c32bf642d2
14 changed files with 61 additions and 25 deletions

View File

@@ -640,6 +640,9 @@ func (a *Agent) openChans(ctx context.Context, availableFunds btcutil.Amount,
nodes[nID] = struct{}{}
return nil
}, func() {
nodes = nil
addresses = nil
}); err != nil {
return fmt.Errorf("unable to get graph nodes: %w", err)
}

View File

@@ -121,7 +121,7 @@ func (d *dbNode) ForEachChannel(ctx context.Context,
//
// NOTE: Part of the autopilot.ChannelGraph interface.
func (d *databaseChannelGraph) ForEachNode(ctx context.Context,
cb func(context.Context, Node) error) error {
cb func(context.Context, Node) error, _ func()) error {
return d.db.ForEachNode(ctx, func(nodeTx graphdb.NodeRTx) error {
// We'll skip over any node that doesn't have any advertised
@@ -217,7 +217,7 @@ func (nc dbNodeCached) ForEachChannel(ctx context.Context,
//
// NOTE: Part of the autopilot.ChannelGraph interface.
func (dc *databaseChannelGraphCached) ForEachNode(ctx context.Context,
cb func(context.Context, Node) error) error {
cb func(context.Context, Node) error, reset func()) error {
return dc.db.ForEachNodeCached(ctx, func(n route.Vertex,
channels map[uint64]*graphdb.DirectedChannel) error {
@@ -231,7 +231,7 @@ func (dc *databaseChannelGraphCached) ForEachNode(ctx context.Context,
return cb(ctx, node)
}
return nil
})
}, reset)
}
// memNode is a purely in-memory implementation of the autopilot.Node

View File

@@ -85,7 +85,8 @@ type ChannelGraph interface {
// ForEachNode is a higher-order function that should be called once
// for each connected node within the channel graph. If the passed
// callback returns an error, then execution should be terminated.
ForEachNode(context.Context, func(context.Context, Node) error) error
ForEachNode(context.Context, func(context.Context, Node) error,
func()) error
}
// NodeScore is a tuple mapping a NodeID to a score indicating the preference
@@ -235,5 +236,6 @@ type GraphSource interface {
// ForEachNode since any further calls are made across multiple
// transactions.
ForEachNodeCached(ctx context.Context, cb func(node route.Vertex,
chans map[uint64]*graphdb.DirectedChannel) error) error
chans map[uint64]*graphdb.DirectedChannel) error,
reset func()) error
}

View File

@@ -105,6 +105,9 @@ func (p *PrefAttachment) NodeScores(ctx context.Context, g ChannelGraph,
}
return nil
}, func() {
allChans = nil
clear(seenChans)
}); err != nil {
return nil, err
}
@@ -162,6 +165,9 @@ func (p *PrefAttachment) NodeScores(ctx context.Context, g ChannelGraph,
log.Tracef("Counted %v channels for node %x", nodeChans, nID[:])
return nil
}, func() {
maxChans = 0
clear(nodeChanNum)
}); err != nil {
return nil, err
}

View File

@@ -147,10 +147,13 @@ func TestPrefAttachmentSelectTwoVertexes(t *testing.T) {
// Get the score for all nodes found in the graph at
// this point.
nodes := make(map[NodeID]struct{})
err = graph.ForEachNode(ctx,
err = graph.ForEachNode(
ctx,
func(_ context.Context, n Node) error {
nodes[n.PubKey()] = struct{}{}
return nil
}, func() {
clear(nodes)
},
)
require.NoError(t1, err)
@@ -257,7 +260,12 @@ func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) {
twoChans = twoChans || (numChans == 2)
return nil
})
}, func() {
numNodes = 0
twoChans = false
clear(nodes)
},
)
require.NoError(t1, err)
require.EqualValues(t1, 3, numNodes)
@@ -338,7 +346,7 @@ func TestPrefAttachmentSelectSkipNodes(t *testing.T) {
nodes[n.PubKey()] = struct{}{}
return nil
},
}, func() {},
)
require.NoError(t1, err)
@@ -593,7 +601,7 @@ func newMemChannelGraph() *memChannelGraph {
//
// NOTE: Part of the autopilot.ChannelGraph interface.
func (m *memChannelGraph) ForEachNode(ctx context.Context,
cb func(context.Context, Node) error) error {
cb func(context.Context, Node) error, _ func()) error {
for _, node := range m.graph {
if err := cb(ctx, node); err != nil {

View File

@@ -60,6 +60,10 @@ func NewSimpleGraph(ctx context.Context, g ChannelGraph) (*SimpleGraph, error) {
return nil
},
)
}, func() {
clear(adj)
clear(nodes)
nextIndex = 0
})
if err != nil {
return nil, err