diff --git a/autopilot/agent.go b/autopilot/agent.go index 1b70580a3..30e76d3d9 100644 --- a/autopilot/agent.go +++ b/autopilot/agent.go @@ -640,6 +640,9 @@ func (a *Agent) openChans(ctx context.Context, availableFunds btcutil.Amount, nodes[nID] = struct{}{} return nil + }, func() { + nodes = nil + addresses = nil }); err != nil { return fmt.Errorf("unable to get graph nodes: %w", err) } diff --git a/autopilot/graph.go b/autopilot/graph.go index 3aa3ad055..fe74ad1fe 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -121,7 +121,7 @@ func (d *dbNode) ForEachChannel(ctx context.Context, // // NOTE: Part of the autopilot.ChannelGraph interface. func (d *databaseChannelGraph) ForEachNode(ctx context.Context, - cb func(context.Context, Node) error) error { + cb func(context.Context, Node) error, _ func()) error { return d.db.ForEachNode(ctx, func(nodeTx graphdb.NodeRTx) error { // We'll skip over any node that doesn't have any advertised @@ -217,7 +217,7 @@ func (nc dbNodeCached) ForEachChannel(ctx context.Context, // // NOTE: Part of the autopilot.ChannelGraph interface. func (dc *databaseChannelGraphCached) ForEachNode(ctx context.Context, - cb func(context.Context, Node) error) error { + cb func(context.Context, Node) error, reset func()) error { return dc.db.ForEachNodeCached(ctx, func(n route.Vertex, channels map[uint64]*graphdb.DirectedChannel) error { @@ -231,7 +231,7 @@ func (dc *databaseChannelGraphCached) ForEachNode(ctx context.Context, return cb(ctx, node) } return nil - }) + }, reset) } // memNode is a purely in-memory implementation of the autopilot.Node diff --git a/autopilot/interface.go b/autopilot/interface.go index 703f49020..69eb4e6f1 100644 --- a/autopilot/interface.go +++ b/autopilot/interface.go @@ -85,7 +85,8 @@ type ChannelGraph interface { // ForEachNode is a higher-order function that should be called once // for each connected node within the channel graph. If the passed // callback returns an error, then execution should be terminated. - ForEachNode(context.Context, func(context.Context, Node) error) error + ForEachNode(context.Context, func(context.Context, Node) error, + func()) error } // NodeScore is a tuple mapping a NodeID to a score indicating the preference @@ -235,5 +236,6 @@ type GraphSource interface { // ForEachNode since any further calls are made across multiple // transactions. ForEachNodeCached(ctx context.Context, cb func(node route.Vertex, - chans map[uint64]*graphdb.DirectedChannel) error) error + chans map[uint64]*graphdb.DirectedChannel) error, + reset func()) error } diff --git a/autopilot/prefattach.go b/autopilot/prefattach.go index 76d814a46..e9a9c53b4 100644 --- a/autopilot/prefattach.go +++ b/autopilot/prefattach.go @@ -105,6 +105,9 @@ func (p *PrefAttachment) NodeScores(ctx context.Context, g ChannelGraph, } return nil + }, func() { + allChans = nil + clear(seenChans) }); err != nil { return nil, err } @@ -162,6 +165,9 @@ func (p *PrefAttachment) NodeScores(ctx context.Context, g ChannelGraph, log.Tracef("Counted %v channels for node %x", nodeChans, nID[:]) return nil + }, func() { + maxChans = 0 + clear(nodeChanNum) }); err != nil { return nil, err } diff --git a/autopilot/prefattach_test.go b/autopilot/prefattach_test.go index 282e32625..b5c0b1d26 100644 --- a/autopilot/prefattach_test.go +++ b/autopilot/prefattach_test.go @@ -147,10 +147,13 @@ func TestPrefAttachmentSelectTwoVertexes(t *testing.T) { // Get the score for all nodes found in the graph at // this point. nodes := make(map[NodeID]struct{}) - err = graph.ForEachNode(ctx, + err = graph.ForEachNode( + ctx, func(_ context.Context, n Node) error { nodes[n.PubKey()] = struct{}{} return nil + }, func() { + clear(nodes) }, ) require.NoError(t1, err) @@ -257,7 +260,12 @@ func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) { twoChans = twoChans || (numChans == 2) return nil - }) + }, func() { + numNodes = 0 + twoChans = false + clear(nodes) + }, + ) require.NoError(t1, err) require.EqualValues(t1, 3, numNodes) @@ -338,7 +346,7 @@ func TestPrefAttachmentSelectSkipNodes(t *testing.T) { nodes[n.PubKey()] = struct{}{} return nil - }, + }, func() {}, ) require.NoError(t1, err) @@ -593,7 +601,7 @@ func newMemChannelGraph() *memChannelGraph { // // NOTE: Part of the autopilot.ChannelGraph interface. func (m *memChannelGraph) ForEachNode(ctx context.Context, - cb func(context.Context, Node) error) error { + cb func(context.Context, Node) error, _ func()) error { for _, node := range m.graph { if err := cb(ctx, node); err != nil { diff --git a/autopilot/simple_graph.go b/autopilot/simple_graph.go index f028db3c7..2f80c490a 100644 --- a/autopilot/simple_graph.go +++ b/autopilot/simple_graph.go @@ -60,6 +60,10 @@ func NewSimpleGraph(ctx context.Context, g ChannelGraph) (*SimpleGraph, error) { return nil }, ) + }, func() { + clear(adj) + clear(nodes) + nextIndex = 0 }) if err != nil { return nil, err diff --git a/discovery/bootstrapper.go b/discovery/bootstrapper.go index dbfdbc7d2..1b6c981d4 100644 --- a/discovery/bootstrapper.go +++ b/discovery/bootstrapper.go @@ -247,9 +247,9 @@ func (c *ChannelGraphBootstrapper) SampleNodeAddrs(_ context.Context, }) } - c.tried[nID] = struct{}{} - return errFound + }, func() { + a = nil }) if err != nil && !errors.Is(err, errFound) { return nil, err @@ -282,6 +282,14 @@ func (c *ChannelGraphBootstrapper) SampleNodeAddrs(_ context.Context, continue } + for _, addr := range sampleAddrs { + nID := autopilot.NodeID( + addr.IdentityKey.SerializeCompressed(), + ) + + c.tried[nID] = struct{}{} + } + addrs = append(addrs, sampleAddrs...) } diff --git a/graph/db/graph.go b/graph/db/graph.go index 3168a6618..2ae944e9e 100644 --- a/graph/db/graph.go +++ b/graph/db/graph.go @@ -246,14 +246,14 @@ func (c *ChannelGraph) GraphSession(cb func(graph NodeTraverser) error) error { // // NOTE: The callback contents MUST not be modified. func (c *ChannelGraph) ForEachNodeCached(ctx context.Context, - cb func(node route.Vertex, - chans map[uint64]*DirectedChannel) error) error { + cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error, + reset func()) error { if c.graphCache != nil { return c.graphCache.ForEachNode(cb) } - return c.V1Store.ForEachNodeCached(ctx, cb) + return c.V1Store.ForEachNodeCached(ctx, cb, reset) } // AddLightningNode adds a vertex/node to the graph database. If the node is not diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 8d62f65d4..20b097541 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -1374,7 +1374,7 @@ func TestGraphTraversal(t *testing.T) { } return nil - }) + }, func() {}) require.NoError(t, err) // Iterate through all the known channels within the graph DB, once diff --git a/graph/db/interfaces.go b/graph/db/interfaces.go index 8c4b0c855..e47935c11 100644 --- a/graph/db/interfaces.go +++ b/graph/db/interfaces.go @@ -93,7 +93,7 @@ type V1Store interface { //nolint:interfacebloat // // NOTE: The callback contents MUST not be modified. ForEachNodeCached(ctx context.Context, cb func(node route.Vertex, - chans map[uint64]*DirectedChannel) error) error + chans map[uint64]*DirectedChannel) error, reset func()) error // ForEachNode iterates through all the stored vertices/nodes in the // graph, executing the passed callback with each node encountered. If diff --git a/graph/db/kv_store.go b/graph/db/kv_store.go index 19d412d83..45625437c 100644 --- a/graph/db/kv_store.go +++ b/graph/db/kv_store.go @@ -678,10 +678,8 @@ func (c *KVStore) FetchNodeFeatures(nodePub route.Vertex) ( // // NOTE: The callback contents MUST not be modified. func (c *KVStore) ForEachNodeCached(_ context.Context, - cb func(node route.Vertex, - chans map[uint64]*DirectedChannel) error) error { - - reset := func() {} + cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error, + reset func()) error { // Otherwise call back to a version that uses the database directly. // We'll iterate over each node, then the set of channels for each diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index dc9365a17..ad7bfc0fe 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -1109,10 +1109,8 @@ func (s *SQLStore) ChanUpdatesInHorizon(startTime, // // NOTE: part of the V1Store interface. func (s *SQLStore) ForEachNodeCached(ctx context.Context, - cb func(node route.Vertex, - chans map[uint64]*DirectedChannel) error) error { - - reset := func() {} + cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error, + reset func()) error { return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { return forEachNodeCacheable(ctx, db, func(nodeID int64, diff --git a/rpcserver.go b/rpcserver.go index 0d003739f..39169a05b 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -7215,6 +7215,15 @@ func (r *rpcServer) GetNetworkInfo(ctx context.Context, } return nil + }, func() { + numChannels = 0 + numNodes = 0 + maxChanOut = 0 + totalNetworkCapacity = 0 + minChannelSize = math.MaxInt64 + maxChannelSize = 0 + allChans = nil + clear(seenChans) }) if err != nil { return nil, err diff --git a/server.go b/server.go index bfc1850d7..7eb8c583d 100644 --- a/server.go +++ b/server.go @@ -1254,7 +1254,7 @@ func newServer(ctx context.Context, cfg *Config, listenAddrs []net.Addr, DefaultRoutingPolicy: cc.RoutingPolicy, ForAllOutgoingChannels: func(ctx context.Context, cb func(*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy) error, + *models.ChannelEdgePolicy) error, reset func()) error { return s.graphDB.ForEachNodeChannel(ctx, selfVertex,