discovery: convert UpdatesInHorizon to return iter.Seq2[lnwire.Message, error]

In this commit, we complete the iterator conversion work started in PR
10128 by threading the iterator pattern through to the higher-level
UpdatesInHorizon method. This change converts the method from returning
a fully materialized slice of messages to returning a lazy iterator that
yields messages on demand.

The new signature uses iter.Seq2 to allow error propagation during
iteration, eliminating the need for a separate error return value. This
approach enables callers to handle errors as they occur during iteration
rather than failing upfront.

The implementation now lazily processes channel and node updates,
yielding them as they're generated rather than accumulating them in
memory. This maintains the same ordering guarantees (channels before
nodes) while significantly reducing memory pressure when dealing with
large update sets during gossip synchronization.
This commit is contained in:
Olaoluwa Osuntokun
2025-09-10 18:23:00 -07:00
parent fda989da9c
commit d8f6fd29f7
3 changed files with 95 additions and 110 deletions

View File

@@ -2,6 +2,7 @@ package discovery
import ( import (
"context" "context"
"iter"
"time" "time"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
@@ -30,8 +31,8 @@ type ChannelGraphTimeSeries interface {
// update timestamp between the start time and end time. We'll use this // update timestamp between the start time and end time. We'll use this
// to catch up a remote node to the set of channel updates that they // to catch up a remote node to the set of channel updates that they
// may have missed out on within the target chain. // may have missed out on within the target chain.
UpdatesInHorizon(chain chainhash.Hash, UpdatesInHorizon(chain chainhash.Hash, startTime time.Time,
startTime time.Time, endTime time.Time) ([]lnwire.Message, error) endTime time.Time) iter.Seq2[lnwire.Message, error]
// FilterKnownChanIDs takes a target chain, and a set of channel ID's, // FilterKnownChanIDs takes a target chain, and a set of channel ID's,
// and returns a filtered set of chan ID's. This filtered set of chan // and returns a filtered set of chan ID's. This filtered set of chan
@@ -108,109 +109,96 @@ func (c *ChanSeries) HighestChanID(ctx context.Context,
// //
// NOTE: This is part of the ChannelGraphTimeSeries interface. // NOTE: This is part of the ChannelGraphTimeSeries interface.
func (c *ChanSeries) UpdatesInHorizon(chain chainhash.Hash, func (c *ChanSeries) UpdatesInHorizon(chain chainhash.Hash,
startTime time.Time, endTime time.Time) ([]lnwire.Message, error) { startTime, endTime time.Time) iter.Seq2[lnwire.Message, error] {
var updates []lnwire.Message return func(yield func(lnwire.Message, error) bool) {
// First, we'll query for all the set of channels that have an
// First, we'll query for all the set of channels that have an update // update that falls within the specified horizon.
// that falls within the specified horizon. chansInHorizon, err := c.graph.ChanUpdatesInHorizon(
chansInHorizonIter, err := c.graph.ChanUpdatesInHorizon( startTime, endTime,
startTime, endTime,
)
if err != nil {
return nil, err
}
for channel := range chansInHorizonIter {
// If the channel hasn't been fully advertised yet, or is a
// private channel, then we'll skip it as we can't construct a
// full authentication proof if one is requested.
if channel.Info.AuthProof == nil {
continue
}
chanAnn, edge1, edge2, err := netann.CreateChanAnnouncement(
channel.Info.AuthProof, channel.Info, channel.Policy1,
channel.Policy2,
) )
if err != nil { if err != nil {
return nil, err yield(nil, err)
return
} }
// Create a slice to hold the `channel_announcement` and for channel := range chansInHorizon {
// potentially two `channel_update` msgs. // If the channel hasn't been fully advertised yet, or
// // is a private channel, then we'll skip it as we can't
// NOTE: Based on BOLT7, if a channel_announcement has no // construct a full authentication proof if one is
// corresponding channel_updates, we must not send the // requested.
// channel_announcement. Thus we use this slice to decide we if channel.Info.AuthProof == nil {
// want to send this `channel_announcement` or not. By the end continue
// of the operation, if the len of the slice is 1, we will not }
// send the `channel_announcement`. Otherwise, when sending the
// msgs, the `channel_announcement` must be sent prior to any //nolint:ll
// corresponding `channel_update` or `node_annoucement`, that's chanAnn, edge1, edge2, err := netann.CreateChanAnnouncement(
// why we create a slice here to maintain the order. channel.Info.AuthProof, channel.Info,
chanUpdates := make([]lnwire.Message, 0, 3) channel.Policy1, channel.Policy2,
chanUpdates = append(chanUpdates, chanAnn) )
if err != nil {
if !yield(nil, err) {
return
}
continue
}
if !yield(chanAnn, nil) {
return
}
if edge1 != nil {
// We don't want to send channel updates that don't // We don't want to send channel updates that don't
// conform to the spec (anymore). // conform to the spec (anymore), so check to make sure
err := netann.ValidateChannelUpdateFields(0, edge1) // that these channel updates are valid before yielding
if err != nil { // them.
log.Errorf("not sending invalid channel "+ if edge1 != nil {
"update %v: %v", edge1, err) err := netann.ValidateChannelUpdateFields(
} else { 0, edge1,
chanUpdates = append(chanUpdates, edge1) )
if err != nil {
log.Errorf("not sending invalid "+
"channel update %v: %v",
edge1, err)
} else if !yield(edge1, nil) {
return
}
}
if edge2 != nil {
err := netann.ValidateChannelUpdateFields(
0, edge2,
)
if err != nil {
log.Errorf("not sending invalid "+
"channel update %v: %v", edge2,
err)
} else if !yield(edge2, nil) {
return
}
} }
} }
if edge2 != nil { // Next, we'll send out all the node announcements that have an
err := netann.ValidateChannelUpdateFields(0, edge2) // update within the horizon as well. We send these second to
// ensure that they follow any active channels they have.
nodeAnnsInHorizon, err := c.graph.NodeUpdatesInHorizon(
startTime, endTime, graphdb.WithIterPublicNodesOnly(),
)
for nodeAnn := range nodeAnnsInHorizon {
nodeUpdate, err := nodeAnn.NodeAnnouncement(true)
if err != nil { if err != nil {
log.Errorf("not sending invalid channel "+ if !yield(nil, err) {
"update %v: %v", edge2, err) return
} else { }
chanUpdates = append(chanUpdates, edge2)
continue
}
if !yield(nodeUpdate, nil) {
return
} }
} }
// If there's no corresponding `channel_update` to send, skip
// sending this `channel_announcement`.
if len(chanUpdates) < 2 {
continue
}
// Append the all the msgs to the slice.
updates = append(updates, chanUpdates...)
} }
// Next, we'll send out all the node announcements that have an update
// within the horizon as well. We send these second to ensure that they
// follow any active channels they have.
nodeAnnsInHorizon, err := c.graph.NodeUpdatesInHorizon(
startTime, endTime, graphdb.WithIterPublicNodesOnly(),
)
if err != nil {
return nil, err
}
for nodeAnn := range nodeAnnsInHorizon {
nodeUpdate, err := nodeAnn.NodeAnnouncement(true)
if err != nil {
return nil, err
}
if err := netann.ValidateNodeAnnFields(nodeUpdate); err != nil {
log.Debugf("Skipping forwarding invalid node "+
"announcement %x: %v", nodeAnn.PubKeyBytes, err)
continue
}
updates = append(updates, nodeUpdate)
}
return updates, nil
} }
// FilterKnownChanIDs takes a target chain, and a set of channel ID's, and // FilterKnownChanIDs takes a target chain, and a set of channel ID's, and

View File

@@ -1442,23 +1442,12 @@ func (g *GossipSyncer) ApplyGossipFilter(ctx context.Context,
// Now that the remote peer has applied their filter, we'll query the // Now that the remote peer has applied their filter, we'll query the
// database for all the messages that are beyond this filter. // database for all the messages that are beyond this filter.
newUpdatestoSend, err := g.cfg.channelSeries.UpdatesInHorizon( newUpdatestoSend := g.cfg.channelSeries.UpdatesInHorizon(
g.cfg.chainHash, startTime, endTime, g.cfg.chainHash, startTime, endTime,
) )
if err != nil {
returnSema()
return err
}
log.Infof("GossipSyncer(%x): applying new remote update horizon: "+ log.Infof("GossipSyncer(%x): applying new remote update horizon: "+
"start=%v, end=%v, backlog_size=%v", g.cfg.peerPub[:], "start=%v, end=%v", g.cfg.peerPub[:], startTime, endTime)
startTime, endTime, len(newUpdatestoSend))
// If we don't have any to send, then we can return early.
if len(newUpdatestoSend) == 0 {
returnSema()
return nil
}
// Set the atomic flag to indicate we're starting to send the backlog. // Set the atomic flag to indicate we're starting to send the backlog.
// If the swap fails, it means another goroutine is already active, so // If the swap fails, it means another goroutine is already active, so
@@ -1478,7 +1467,7 @@ func (g *GossipSyncer) ApplyGossipFilter(ctx context.Context,
defer returnSema() defer returnSema()
defer g.isSendingBacklog.Store(false) defer g.isSendingBacklog.Store(false)
for _, msg := range newUpdatestoSend { for msg := range newUpdatestoSend {
err := g.sendToPeerSync(ctx, msg) err := g.sendToPeerSync(ctx, msg)
switch { switch {
case err == ErrGossipSyncerExiting: case err == ErrGossipSyncerExiting:

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"iter"
"math" "math"
"reflect" "reflect"
"sort" "sort"
@@ -86,13 +87,20 @@ func (m *mockChannelGraphTimeSeries) HighestChanID(_ context.Context,
} }
func (m *mockChannelGraphTimeSeries) UpdatesInHorizon(chain chainhash.Hash, func (m *mockChannelGraphTimeSeries) UpdatesInHorizon(chain chainhash.Hash,
startTime time.Time, endTime time.Time) ([]lnwire.Message, error) { startTime, endTime time.Time) iter.Seq2[lnwire.Message, error] {
m.horizonReq <- horizonQuery{ return func(yield func(lnwire.Message, error) bool) {
chain, startTime, endTime, m.horizonReq <- horizonQuery{
chain, startTime, endTime,
}
msgs := <-m.horizonResp
for _, msg := range msgs {
if !yield(msg, nil) {
return
}
}
} }
return <-m.horizonResp, nil
} }
func (m *mockChannelGraphTimeSeries) FilterKnownChanIDs(chain chainhash.Hash, func (m *mockChannelGraphTimeSeries) FilterKnownChanIDs(chain chainhash.Hash,