diff --git a/discovery/chan_series.go b/discovery/chan_series.go index 081133dfc..22d6397e0 100644 --- a/discovery/chan_series.go +++ b/discovery/chan_series.go @@ -2,6 +2,7 @@ package discovery import ( "context" + "iter" "time" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -30,8 +31,8 @@ type ChannelGraphTimeSeries interface { // update timestamp between the start time and end time. We'll use this // to catch up a remote node to the set of channel updates that they // may have missed out on within the target chain. - UpdatesInHorizon(chain chainhash.Hash, - startTime time.Time, endTime time.Time) ([]lnwire.Message, error) + UpdatesInHorizon(chain chainhash.Hash, startTime time.Time, + endTime time.Time) iter.Seq2[lnwire.Message, error] // FilterKnownChanIDs takes a target chain, and a set of channel ID's, // and returns a filtered set of chan ID's. This filtered set of chan @@ -108,109 +109,96 @@ func (c *ChanSeries) HighestChanID(ctx context.Context, // // NOTE: This is part of the ChannelGraphTimeSeries interface. func (c *ChanSeries) UpdatesInHorizon(chain chainhash.Hash, - startTime time.Time, endTime time.Time) ([]lnwire.Message, error) { + startTime, endTime time.Time) iter.Seq2[lnwire.Message, error] { - var updates []lnwire.Message - - // First, we'll query for all the set of channels that have an update - // that falls within the specified horizon. - chansInHorizonIter, err := c.graph.ChanUpdatesInHorizon( - startTime, endTime, - ) - if err != nil { - return nil, err - } - - for channel := range chansInHorizonIter { - // If the channel hasn't been fully advertised yet, or is a - // private channel, then we'll skip it as we can't construct a - // full authentication proof if one is requested. - if channel.Info.AuthProof == nil { - continue - } - - chanAnn, edge1, edge2, err := netann.CreateChanAnnouncement( - channel.Info.AuthProof, channel.Info, channel.Policy1, - channel.Policy2, + return func(yield func(lnwire.Message, error) bool) { + // First, we'll query for all the set of channels that have an + // update that falls within the specified horizon. + chansInHorizon, err := c.graph.ChanUpdatesInHorizon( + startTime, endTime, ) if err != nil { - return nil, err + yield(nil, err) + return } - // Create a slice to hold the `channel_announcement` and - // potentially two `channel_update` msgs. - // - // NOTE: Based on BOLT7, if a channel_announcement has no - // corresponding channel_updates, we must not send the - // channel_announcement. Thus we use this slice to decide we - // want to send this `channel_announcement` or not. By the end - // of the operation, if the len of the slice is 1, we will not - // send the `channel_announcement`. Otherwise, when sending the - // msgs, the `channel_announcement` must be sent prior to any - // corresponding `channel_update` or `node_annoucement`, that's - // why we create a slice here to maintain the order. - chanUpdates := make([]lnwire.Message, 0, 3) - chanUpdates = append(chanUpdates, chanAnn) + for channel := range chansInHorizon { + // If the channel hasn't been fully advertised yet, or + // is a private channel, then we'll skip it as we can't + // construct a full authentication proof if one is + // requested. + if channel.Info.AuthProof == nil { + continue + } + + //nolint:ll + chanAnn, edge1, edge2, err := netann.CreateChanAnnouncement( + channel.Info.AuthProof, channel.Info, + channel.Policy1, channel.Policy2, + ) + if err != nil { + if !yield(nil, err) { + return + } + + continue + } + + if !yield(chanAnn, nil) { + return + } - if edge1 != nil { // We don't want to send channel updates that don't - // conform to the spec (anymore). - err := netann.ValidateChannelUpdateFields(0, edge1) - if err != nil { - log.Errorf("not sending invalid channel "+ - "update %v: %v", edge1, err) - } else { - chanUpdates = append(chanUpdates, edge1) + // conform to the spec (anymore), so check to make sure + // that these channel updates are valid before yielding + // them. + if edge1 != nil { + err := netann.ValidateChannelUpdateFields( + 0, edge1, + ) + if err != nil { + log.Errorf("not sending invalid "+ + "channel update %v: %v", + edge1, err) + } else if !yield(edge1, nil) { + return + } + } + if edge2 != nil { + err := netann.ValidateChannelUpdateFields( + 0, edge2, + ) + if err != nil { + log.Errorf("not sending invalid "+ + "channel update %v: %v", edge2, + err) + } else if !yield(edge2, nil) { + return + } } } - if edge2 != nil { - err := netann.ValidateChannelUpdateFields(0, edge2) + // Next, we'll send out all the node announcements that have an + // update within the horizon as well. We send these second to + // ensure that they follow any active channels they have. + nodeAnnsInHorizon, err := c.graph.NodeUpdatesInHorizon( + startTime, endTime, graphdb.WithIterPublicNodesOnly(), + ) + for nodeAnn := range nodeAnnsInHorizon { + nodeUpdate, err := nodeAnn.NodeAnnouncement(true) if err != nil { - log.Errorf("not sending invalid channel "+ - "update %v: %v", edge2, err) - } else { - chanUpdates = append(chanUpdates, edge2) + if !yield(nil, err) { + return + } + + continue + } + + if !yield(nodeUpdate, nil) { + return } } - - // If there's no corresponding `channel_update` to send, skip - // sending this `channel_announcement`. - if len(chanUpdates) < 2 { - continue - } - - // Append the all the msgs to the slice. - updates = append(updates, chanUpdates...) } - - // Next, we'll send out all the node announcements that have an update - // within the horizon as well. We send these second to ensure that they - // follow any active channels they have. - nodeAnnsInHorizon, err := c.graph.NodeUpdatesInHorizon( - startTime, endTime, graphdb.WithIterPublicNodesOnly(), - ) - if err != nil { - return nil, err - } - - for nodeAnn := range nodeAnnsInHorizon { - nodeUpdate, err := nodeAnn.NodeAnnouncement(true) - if err != nil { - return nil, err - } - - if err := netann.ValidateNodeAnnFields(nodeUpdate); err != nil { - log.Debugf("Skipping forwarding invalid node "+ - "announcement %x: %v", nodeAnn.PubKeyBytes, err) - - continue - } - - updates = append(updates, nodeUpdate) - } - - return updates, nil } // FilterKnownChanIDs takes a target chain, and a set of channel ID's, and diff --git a/discovery/syncer.go b/discovery/syncer.go index 12e7dbf16..478e7acd3 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -1442,23 +1442,12 @@ func (g *GossipSyncer) ApplyGossipFilter(ctx context.Context, // Now that the remote peer has applied their filter, we'll query the // database for all the messages that are beyond this filter. - newUpdatestoSend, err := g.cfg.channelSeries.UpdatesInHorizon( + newUpdatestoSend := g.cfg.channelSeries.UpdatesInHorizon( g.cfg.chainHash, startTime, endTime, ) - if err != nil { - returnSema() - return err - } log.Infof("GossipSyncer(%x): applying new remote update horizon: "+ - "start=%v, end=%v, backlog_size=%v", g.cfg.peerPub[:], - startTime, endTime, len(newUpdatestoSend)) - - // If we don't have any to send, then we can return early. - if len(newUpdatestoSend) == 0 { - returnSema() - return nil - } + "start=%v, end=%v", g.cfg.peerPub[:], startTime, endTime) // Set the atomic flag to indicate we're starting to send the backlog. // If the swap fails, it means another goroutine is already active, so @@ -1478,7 +1467,7 @@ func (g *GossipSyncer) ApplyGossipFilter(ctx context.Context, defer returnSema() defer g.isSendingBacklog.Store(false) - for _, msg := range newUpdatestoSend { + for msg := range newUpdatestoSend { err := g.sendToPeerSync(ctx, msg) switch { case err == ErrGossipSyncerExiting: diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index edf91cb69..72ad2ce47 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "iter" "math" "reflect" "sort" @@ -86,13 +87,20 @@ func (m *mockChannelGraphTimeSeries) HighestChanID(_ context.Context, } func (m *mockChannelGraphTimeSeries) UpdatesInHorizon(chain chainhash.Hash, - startTime time.Time, endTime time.Time) ([]lnwire.Message, error) { + startTime, endTime time.Time) iter.Seq2[lnwire.Message, error] { - m.horizonReq <- horizonQuery{ - chain, startTime, endTime, + return func(yield func(lnwire.Message, error) bool) { + m.horizonReq <- horizonQuery{ + chain, startTime, endTime, + } + + msgs := <-m.horizonResp + for _, msg := range msgs { + if !yield(msg, nil) { + return + } + } } - - return <-m.horizonResp, nil } func (m *mockChannelGraphTimeSeries) FilterKnownChanIDs(chain chainhash.Hash,