From cdcf0ac16bea41961f1ae9c17a332337c63615fb Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 7 Nov 2023 12:24:18 +0200 Subject: [PATCH] multi: use ChannelUpdate interface in various places --- discovery/gossiper.go | 59 +++++++++++++++++++++------------ discovery/gossiper_test.go | 7 ++-- discovery/message_store.go | 19 +++++++---- discovery/message_store_test.go | 8 ++--- funding/manager.go | 2 +- funding/manager_test.go | 14 ++++---- graph/validation_barrier.go | 12 +++---- peer/brontide.go | 7 ++++ 8 files changed, 79 insertions(+), 49 deletions(-) diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 785f66fb3..5339db426 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -939,10 +939,9 @@ type channelUpdateID struct { // retrieve all necessary data to validate the channel existence. channelID lnwire.ShortChannelID - // Flags least-significant bit must be set to 0 if the creating node - // corresponds to the first node in the previously sent channel - // announcement and 1 otherwise. - flags lnwire.ChanUpdateChanFlags + disabled bool + + direction bool } // msgWithSenders is a wrapper struct around a message, and the set of peers @@ -1051,32 +1050,49 @@ func (d *deDupedAnnouncements) addMsg(message networkMsg) { // Channel updates are identified by the (short channel id, // channelflags) tuple. - case *lnwire.ChannelUpdate1: + case lnwire.ChannelUpdate: sender := route.NewVertex(message.source) deDupKey := channelUpdateID{ - msg.ShortChannelID, - msg.ChannelFlags, + msg.SCID(), + msg.IsDisabled(), + msg.IsNode1(), } - oldTimestamp := uint32(0) + var ( + older = false + newer = true + ) mws, ok := d.channelUpdates[deDupKey] if ok { // If we already have seen this message, record its // timestamp. - update, ok := mws.msg.(*lnwire.ChannelUpdate1) + oldMsg, ok := mws.msg.(lnwire.ChannelUpdate) if !ok { - log.Errorf("Expected *lnwire.ChannelUpdate1, "+ - "got: %T", mws.msg) + log.Errorf("expected type "+ + "lnwire.ChannelUpdate, got: %T", + mws.msg) return } - oldTimestamp = update.Timestamp + cmp, err := msg.CmpAge(oldMsg) + if err != nil { + return + } + + newer = false + switch cmp { + case lnwire.LessThan: + older = true + case lnwire.GreaterThan: + newer = true + default: + } } // If we already had this message with a strictly newer // timestamp, then we'll just discard the message we got. - if oldTimestamp > msg.Timestamp { + if older { log.Debugf("Ignored outdated network message: "+ "peer=%v, msg=%s", message.peer, msg.MsgType()) return @@ -1085,7 +1101,7 @@ func (d *deDupedAnnouncements) addMsg(message networkMsg) { // If the message we just got is newer than what we previously // have seen, or this is the first time we see it, then we'll // add it to our map of announcements. - if oldTimestamp < msg.Timestamp { + if newer { mws = msgWithSenders{ msg: msg, isLocal: !message.isRemote, @@ -1606,8 +1622,8 @@ func (d *AuthenticatedGossiper) isRecentlyRejectedMsg(msg lnwire.Message, var scid uint64 switch m := msg.(type) { - case *lnwire.ChannelUpdate1: - scid = m.ShortChannelID.ToUint64() + case lnwire.ChannelUpdate: + scid = m.SCID().ToUint64() case lnwire.ChannelAnnouncement: scid = m.SCID().ToUint64() @@ -2105,7 +2121,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement( // ChannelEdgeInfo1 should be inspected. func (d *AuthenticatedGossiper) processZombieUpdate( chanInfo models.ChannelEdgeInfo, scid lnwire.ShortChannelID, - msg *lnwire.ChannelUpdate1) error { + msg lnwire.ChannelUpdate) error { // Since we've deemed the update as not stale above, before marking it // live, we'll make sure it has been signed by the correct party. If we @@ -2121,7 +2137,7 @@ func (d *AuthenticatedGossiper) processZombieUpdate( } if pubKey == nil { return fmt.Errorf("incorrect pubkey to resurrect zombie "+ - "with chan_id=%v", msg.ShortChannelID) + "with chan_id=%v", msg.SCID()) } err := msg.VerifySig(pubKey) @@ -2129,7 +2145,6 @@ func (d *AuthenticatedGossiper) processZombieUpdate( return fmt.Errorf("unable to verify channel "+ "update signature: %v", err) } - // With the signature valid, we'll proceed to mark the // edge as live and wait for the channel announcement to // come through again. @@ -2144,13 +2159,13 @@ func (d *AuthenticatedGossiper) processZombieUpdate( case err != nil: return fmt.Errorf("unable to remove edge with "+ "chan_id=%v from zombie index: %v", - msg.ShortChannelID, err) + msg.SCID(), err) default: } log.Debugf("Removed edge with chan_id=%v from zombie "+ - "index", msg.ShortChannelID) + "index", msg.SCID()) return nil } @@ -2849,7 +2864,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg, // Reprocess the message, making sure we return an // error to the original caller in case the gossiper // shuts down. - case *lnwire.ChannelUpdate1: + case lnwire.ChannelUpdate: log.Debugf("Reprocessing ChannelUpdate for "+ "shortChanID=%v", scid.ToUint64()) diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index 84a77263d..7b29dc509 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -1884,7 +1884,8 @@ func TestDeDuplicatedAnnouncements(t *testing.T) { assertChannelUpdate := func(channelUpdate *lnwire.ChannelUpdate1) { channelKey := channelUpdateID{ ua3.ShortChannelID, - ua3.ChannelFlags, + ua3.IsDisabled(), + ua3.IsNode1(), } mws, ok := announcements.channelUpdates[channelKey] @@ -2827,7 +2828,7 @@ func TestRetransmit(t *testing.T) { switch msg.(type) { case lnwire.ChannelAnnouncement: chanAnn++ - case *lnwire.ChannelUpdate1: + case lnwire.ChannelUpdate: chanUpd++ case *lnwire.NodeAnnouncement: nodeAnn++ @@ -3314,7 +3315,7 @@ func TestSendChannelUpdateReliably(t *testing.T) { } switch msg := msg.(type) { - case *lnwire.ChannelUpdate1: + case lnwire.ChannelUpdate: assertMessage(t, staleChannelUpdate, msg) case *lnwire.AnnounceSignatures1: assertMessage(t, batch.localProofAnn, msg) diff --git a/discovery/message_store.go b/discovery/message_store.go index e336c1281..156d56caa 100644 --- a/discovery/message_store.go +++ b/discovery/message_store.go @@ -85,8 +85,8 @@ func msgShortChanID(msg lnwire.Message) (lnwire.ShortChannelID, error) { switch msg := msg.(type) { case lnwire.AnnounceSignatures: shortChanID = msg.SCID() - case *lnwire.ChannelUpdate1: - shortChanID = msg.ShortChannelID + case lnwire.ChannelUpdate: + shortChanID = msg.SCID() default: return shortChanID, ErrUnsupportedMessage } @@ -160,7 +160,7 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message, // In the event that we're attempting to delete a ChannelUpdate // from the store, we'll make sure that we're actually deleting // the correct one as it can be overwritten. - if msg, ok := msg.(*lnwire.ChannelUpdate1); ok { + if msg, ok := msg.(lnwire.ChannelUpdate); ok { // Deleting a value from a bucket that doesn't exist // acts as a NOP, so we'll return if a message doesn't // exist under this key. @@ -176,13 +176,18 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message, // If the timestamps don't match, then the update stored // should be the latest one, so we'll avoid deleting it. - m, ok := dbMsg.(*lnwire.ChannelUpdate1) + m, ok := dbMsg.(lnwire.ChannelUpdate) if !ok { return fmt.Errorf("expected "+ - "*lnwire.ChannelUpdate1, got: %T", - dbMsg) + "lnwire.ChannelUpdate, got: %T", dbMsg) } - if msg.Timestamp != m.Timestamp { + + diff, err := msg.CmpAge(m) + if err != nil { + return err + } + + if diff != lnwire.EqualTo { return nil } } diff --git a/discovery/message_store_test.go b/discovery/message_store_test.go index 36c082e36..10189b902 100644 --- a/discovery/message_store_test.go +++ b/discovery/message_store_test.go @@ -116,10 +116,10 @@ func TestMessageStoreMessages(t *testing.T) { for _, msg := range peerMsgs { var shortChanID uint64 switch msg := msg.(type) { - case *lnwire.AnnounceSignatures1: - shortChanID = msg.ShortChannelID.ToUint64() - case *lnwire.ChannelUpdate1: - shortChanID = msg.ShortChannelID.ToUint64() + case lnwire.AnnounceSignatures: + shortChanID = msg.SCID().ToUint64() + case lnwire.ChannelUpdate: + shortChanID = msg.SCID().ToUint64() default: t.Fatalf("found unexpected message type %T", msg) } diff --git a/funding/manager.go b/funding/manager.go index 4f4334dd1..94a163dfe 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -4144,7 +4144,7 @@ func (f *Manager) ensureInitialForwardingPolicy(chanID lnwire.ChannelID, // send out to the network after a new channel has been created locally. type chanAnnouncement struct { chanAnn lnwire.ChannelAnnouncement - chanUpdateAnn *lnwire.ChannelUpdate1 + chanUpdateAnn lnwire.ChannelUpdate chanProof lnwire.AnnounceSignatures } diff --git a/funding/manager_test.go b/funding/manager_test.go index 26fd0ca3c..78788ffe6 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -1210,7 +1210,7 @@ func assertChannelAnnouncements(t *testing.T, alice, bob *testNode, switch m := msg.(type) { case lnwire.ChannelAnnouncement: gotChannelAnnouncement = true - case *lnwire.ChannelUpdate1: + case lnwire.ChannelUpdate: // The channel update sent by the node should // advertise the MinHTLC value required by the @@ -1225,31 +1225,33 @@ func assertChannelAnnouncements(t *testing.T, alice, bob *testNode, baseFee := aliceCfg.DefaultRoutingPolicy.BaseFee feeRate := aliceCfg.DefaultRoutingPolicy.FeeRate - require.EqualValues(t, 1, m.MessageFlags) + pol := m.ForwardingPolicy() + + require.True(t, pol.HasMaxHTLC) // We might expect a custom MinHTLC value. if len(customMinHtlc) > 0 { minHtlc = customMinHtlc[j] } - require.Equal(t, minHtlc, m.HtlcMinimumMsat) + require.Equal(t, minHtlc, pol.MinHTLC) // We might expect a custom MaxHltc value. if len(customMaxHtlc) > 0 { maxHtlc = customMaxHtlc[j] } - require.Equal(t, maxHtlc, m.HtlcMaximumMsat) + require.Equal(t, maxHtlc, pol.MaxHTLC) // We might expect a custom baseFee value. if len(baseFees) > 0 { baseFee = baseFees[j] } - require.EqualValues(t, baseFee, m.BaseFee) + require.EqualValues(t, baseFee, pol.BaseFee) // We might expect a custom feeRate value. if len(feeRates) > 0 { feeRate = feeRates[j] } - require.EqualValues(t, feeRate, m.FeeRate) + require.EqualValues(t, feeRate, pol.FeeRate) gotChannelUpdate = true } diff --git a/graph/validation_barrier.go b/graph/validation_barrier.go index c1de127ba..74ca89620 100644 --- a/graph/validation_barrier.go +++ b/graph/validation_barrier.go @@ -146,7 +146,7 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) { // initialization needs to be done beyond just occupying a job slot. case models.ChannelEdgePolicy: return - case *lnwire.ChannelUpdate1: + case lnwire.ChannelUpdate: return case *lnwire.NodeAnnouncement: // TODO(roasbeef): node ann needs to wait on existing channel updates @@ -201,11 +201,11 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error { jobDesc = fmt.Sprintf("job=channeldb.LightningNode, pub=%s", vertex) - case *lnwire.ChannelUpdate1: - signals, ok = v.chanEdgeDependencies[msg.ShortChannelID] + case lnwire.ChannelUpdate: + signals, ok = v.chanEdgeDependencies[msg.SCID()] jobDesc = fmt.Sprintf("job=lnwire.ChannelUpdate, scid=%v", - msg.ShortChannelID.ToUint64()) + msg.SCID().ToUint64()) case *lnwire.NodeAnnouncement: vertex := route.Vertex(msg.NodeID) @@ -296,8 +296,8 @@ func (v *ValidationBarrier) SignalDependants(job interface{}, allow bool) { delete(v.nodeAnnDependencies, route.Vertex(msg.PubKeyBytes)) case *lnwire.NodeAnnouncement: delete(v.nodeAnnDependencies, route.Vertex(msg.NodeID)) - case *lnwire.ChannelUpdate1: - delete(v.chanEdgeDependencies, msg.ShortChannelID) + case lnwire.ChannelUpdate: + delete(v.chanEdgeDependencies, msg.SCID()) case models.ChannelEdgePolicy: delete(v.chanEdgeDependencies, msg.SCID()) diff --git a/peer/brontide.go b/peer/brontide.go index b40f40b88..e1289959b 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -1966,6 +1966,7 @@ out: } case *lnwire.ChannelUpdate1, + *lnwire.ChannelUpdate2, *lnwire.ChannelAnnouncement1, *lnwire.ChannelAnnouncement2, *lnwire.NodeAnnouncement, @@ -2242,6 +2243,12 @@ func messageSummary(msg lnwire.Message) string { msg.ShortChannelID.ToUint64(), msg.MessageFlags, msg.ChannelFlags, time.Unix(int64(msg.Timestamp), 0)) + case *lnwire.ChannelUpdate2: + return fmt.Sprintf("chain_hash=%v, short_chan_id=%v, "+ + "is_disabled=%v, is_node_1=%v, block_height=%v", + msg.ChainHash, msg.ShortChannelID.Val.ToUint64(), + msg.IsDisabled(), msg.IsNode1(), msg.BlockHeight) + case *lnwire.NodeAnnouncement: return fmt.Sprintf("node=%x, update_time=%v", msg.NodeID, time.Unix(int64(msg.Timestamp), 0))