multi: use ChannelUpdate interface in various places

This commit is contained in:
Elle Mouton 2023-11-07 12:24:18 +02:00
parent ad0de06319
commit cdcf0ac16b
No known key found for this signature in database
GPG Key ID: D7D916376026F177
8 changed files with 79 additions and 49 deletions

View File

@ -939,10 +939,9 @@ type channelUpdateID struct {
// retrieve all necessary data to validate the channel existence.
channelID lnwire.ShortChannelID
// Flags least-significant bit must be set to 0 if the creating node
// corresponds to the first node in the previously sent channel
// announcement and 1 otherwise.
flags lnwire.ChanUpdateChanFlags
disabled bool
direction bool
}
// msgWithSenders is a wrapper struct around a message, and the set of peers
@ -1051,32 +1050,49 @@ func (d *deDupedAnnouncements) addMsg(message networkMsg) {
// Channel updates are identified by the (short channel id,
// channelflags) tuple.
case *lnwire.ChannelUpdate1:
case lnwire.ChannelUpdate:
sender := route.NewVertex(message.source)
deDupKey := channelUpdateID{
msg.ShortChannelID,
msg.ChannelFlags,
msg.SCID(),
msg.IsDisabled(),
msg.IsNode1(),
}
oldTimestamp := uint32(0)
var (
older = false
newer = true
)
mws, ok := d.channelUpdates[deDupKey]
if ok {
// If we already have seen this message, record its
// timestamp.
update, ok := mws.msg.(*lnwire.ChannelUpdate1)
oldMsg, ok := mws.msg.(lnwire.ChannelUpdate)
if !ok {
log.Errorf("Expected *lnwire.ChannelUpdate1, "+
"got: %T", mws.msg)
log.Errorf("expected type "+
"lnwire.ChannelUpdate, got: %T",
mws.msg)
return
}
oldTimestamp = update.Timestamp
cmp, err := msg.CmpAge(oldMsg)
if err != nil {
return
}
newer = false
switch cmp {
case lnwire.LessThan:
older = true
case lnwire.GreaterThan:
newer = true
default:
}
}
// If we already had this message with a strictly newer
// timestamp, then we'll just discard the message we got.
if oldTimestamp > msg.Timestamp {
if older {
log.Debugf("Ignored outdated network message: "+
"peer=%v, msg=%s", message.peer, msg.MsgType())
return
@ -1085,7 +1101,7 @@ func (d *deDupedAnnouncements) addMsg(message networkMsg) {
// If the message we just got is newer than what we previously
// have seen, or this is the first time we see it, then we'll
// add it to our map of announcements.
if oldTimestamp < msg.Timestamp {
if newer {
mws = msgWithSenders{
msg: msg,
isLocal: !message.isRemote,
@ -1606,8 +1622,8 @@ func (d *AuthenticatedGossiper) isRecentlyRejectedMsg(msg lnwire.Message,
var scid uint64
switch m := msg.(type) {
case *lnwire.ChannelUpdate1:
scid = m.ShortChannelID.ToUint64()
case lnwire.ChannelUpdate:
scid = m.SCID().ToUint64()
case lnwire.ChannelAnnouncement:
scid = m.SCID().ToUint64()
@ -2105,7 +2121,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
// ChannelEdgeInfo1 should be inspected.
func (d *AuthenticatedGossiper) processZombieUpdate(
chanInfo models.ChannelEdgeInfo, scid lnwire.ShortChannelID,
msg *lnwire.ChannelUpdate1) error {
msg lnwire.ChannelUpdate) error {
// Since we've deemed the update as not stale above, before marking it
// live, we'll make sure it has been signed by the correct party. If we
@ -2121,7 +2137,7 @@ func (d *AuthenticatedGossiper) processZombieUpdate(
}
if pubKey == nil {
return fmt.Errorf("incorrect pubkey to resurrect zombie "+
"with chan_id=%v", msg.ShortChannelID)
"with chan_id=%v", msg.SCID())
}
err := msg.VerifySig(pubKey)
@ -2129,7 +2145,6 @@ func (d *AuthenticatedGossiper) processZombieUpdate(
return fmt.Errorf("unable to verify channel "+
"update signature: %v", err)
}
// With the signature valid, we'll proceed to mark the
// edge as live and wait for the channel announcement to
// come through again.
@ -2144,13 +2159,13 @@ func (d *AuthenticatedGossiper) processZombieUpdate(
case err != nil:
return fmt.Errorf("unable to remove edge with "+
"chan_id=%v from zombie index: %v",
msg.ShortChannelID, err)
msg.SCID(), err)
default:
}
log.Debugf("Removed edge with chan_id=%v from zombie "+
"index", msg.ShortChannelID)
"index", msg.SCID())
return nil
}
@ -2849,7 +2864,7 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(nMsg *networkMsg,
// Reprocess the message, making sure we return an
// error to the original caller in case the gossiper
// shuts down.
case *lnwire.ChannelUpdate1:
case lnwire.ChannelUpdate:
log.Debugf("Reprocessing ChannelUpdate for "+
"shortChanID=%v", scid.ToUint64())

View File

@ -1884,7 +1884,8 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
assertChannelUpdate := func(channelUpdate *lnwire.ChannelUpdate1) {
channelKey := channelUpdateID{
ua3.ShortChannelID,
ua3.ChannelFlags,
ua3.IsDisabled(),
ua3.IsNode1(),
}
mws, ok := announcements.channelUpdates[channelKey]
@ -2827,7 +2828,7 @@ func TestRetransmit(t *testing.T) {
switch msg.(type) {
case lnwire.ChannelAnnouncement:
chanAnn++
case *lnwire.ChannelUpdate1:
case lnwire.ChannelUpdate:
chanUpd++
case *lnwire.NodeAnnouncement:
nodeAnn++
@ -3314,7 +3315,7 @@ func TestSendChannelUpdateReliably(t *testing.T) {
}
switch msg := msg.(type) {
case *lnwire.ChannelUpdate1:
case lnwire.ChannelUpdate:
assertMessage(t, staleChannelUpdate, msg)
case *lnwire.AnnounceSignatures1:
assertMessage(t, batch.localProofAnn, msg)

View File

@ -85,8 +85,8 @@ func msgShortChanID(msg lnwire.Message) (lnwire.ShortChannelID, error) {
switch msg := msg.(type) {
case lnwire.AnnounceSignatures:
shortChanID = msg.SCID()
case *lnwire.ChannelUpdate1:
shortChanID = msg.ShortChannelID
case lnwire.ChannelUpdate:
shortChanID = msg.SCID()
default:
return shortChanID, ErrUnsupportedMessage
}
@ -160,7 +160,7 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message,
// In the event that we're attempting to delete a ChannelUpdate
// from the store, we'll make sure that we're actually deleting
// the correct one as it can be overwritten.
if msg, ok := msg.(*lnwire.ChannelUpdate1); ok {
if msg, ok := msg.(lnwire.ChannelUpdate); ok {
// Deleting a value from a bucket that doesn't exist
// acts as a NOP, so we'll return if a message doesn't
// exist under this key.
@ -176,13 +176,18 @@ func (s *MessageStore) DeleteMessage(msg lnwire.Message,
// If the timestamps don't match, then the update stored
// should be the latest one, so we'll avoid deleting it.
m, ok := dbMsg.(*lnwire.ChannelUpdate1)
m, ok := dbMsg.(lnwire.ChannelUpdate)
if !ok {
return fmt.Errorf("expected "+
"*lnwire.ChannelUpdate1, got: %T",
dbMsg)
"lnwire.ChannelUpdate, got: %T", dbMsg)
}
if msg.Timestamp != m.Timestamp {
diff, err := msg.CmpAge(m)
if err != nil {
return err
}
if diff != lnwire.EqualTo {
return nil
}
}

View File

@ -116,10 +116,10 @@ func TestMessageStoreMessages(t *testing.T) {
for _, msg := range peerMsgs {
var shortChanID uint64
switch msg := msg.(type) {
case *lnwire.AnnounceSignatures1:
shortChanID = msg.ShortChannelID.ToUint64()
case *lnwire.ChannelUpdate1:
shortChanID = msg.ShortChannelID.ToUint64()
case lnwire.AnnounceSignatures:
shortChanID = msg.SCID().ToUint64()
case lnwire.ChannelUpdate:
shortChanID = msg.SCID().ToUint64()
default:
t.Fatalf("found unexpected message type %T", msg)
}

View File

@ -4144,7 +4144,7 @@ func (f *Manager) ensureInitialForwardingPolicy(chanID lnwire.ChannelID,
// send out to the network after a new channel has been created locally.
type chanAnnouncement struct {
chanAnn lnwire.ChannelAnnouncement
chanUpdateAnn *lnwire.ChannelUpdate1
chanUpdateAnn lnwire.ChannelUpdate
chanProof lnwire.AnnounceSignatures
}

View File

@ -1210,7 +1210,7 @@ func assertChannelAnnouncements(t *testing.T, alice, bob *testNode,
switch m := msg.(type) {
case lnwire.ChannelAnnouncement:
gotChannelAnnouncement = true
case *lnwire.ChannelUpdate1:
case lnwire.ChannelUpdate:
// The channel update sent by the node should
// advertise the MinHTLC value required by the
@ -1225,31 +1225,33 @@ func assertChannelAnnouncements(t *testing.T, alice, bob *testNode,
baseFee := aliceCfg.DefaultRoutingPolicy.BaseFee
feeRate := aliceCfg.DefaultRoutingPolicy.FeeRate
require.EqualValues(t, 1, m.MessageFlags)
pol := m.ForwardingPolicy()
require.True(t, pol.HasMaxHTLC)
// We might expect a custom MinHTLC value.
if len(customMinHtlc) > 0 {
minHtlc = customMinHtlc[j]
}
require.Equal(t, minHtlc, m.HtlcMinimumMsat)
require.Equal(t, minHtlc, pol.MinHTLC)
// We might expect a custom MaxHltc value.
if len(customMaxHtlc) > 0 {
maxHtlc = customMaxHtlc[j]
}
require.Equal(t, maxHtlc, m.HtlcMaximumMsat)
require.Equal(t, maxHtlc, pol.MaxHTLC)
// We might expect a custom baseFee value.
if len(baseFees) > 0 {
baseFee = baseFees[j]
}
require.EqualValues(t, baseFee, m.BaseFee)
require.EqualValues(t, baseFee, pol.BaseFee)
// We might expect a custom feeRate value.
if len(feeRates) > 0 {
feeRate = feeRates[j]
}
require.EqualValues(t, feeRate, m.FeeRate)
require.EqualValues(t, feeRate, pol.FeeRate)
gotChannelUpdate = true
}

View File

@ -146,7 +146,7 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) {
// initialization needs to be done beyond just occupying a job slot.
case models.ChannelEdgePolicy:
return
case *lnwire.ChannelUpdate1:
case lnwire.ChannelUpdate:
return
case *lnwire.NodeAnnouncement:
// TODO(roasbeef): node ann needs to wait on existing channel updates
@ -201,11 +201,11 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error {
jobDesc = fmt.Sprintf("job=channeldb.LightningNode, pub=%s",
vertex)
case *lnwire.ChannelUpdate1:
signals, ok = v.chanEdgeDependencies[msg.ShortChannelID]
case lnwire.ChannelUpdate:
signals, ok = v.chanEdgeDependencies[msg.SCID()]
jobDesc = fmt.Sprintf("job=lnwire.ChannelUpdate, scid=%v",
msg.ShortChannelID.ToUint64())
msg.SCID().ToUint64())
case *lnwire.NodeAnnouncement:
vertex := route.Vertex(msg.NodeID)
@ -296,8 +296,8 @@ func (v *ValidationBarrier) SignalDependants(job interface{}, allow bool) {
delete(v.nodeAnnDependencies, route.Vertex(msg.PubKeyBytes))
case *lnwire.NodeAnnouncement:
delete(v.nodeAnnDependencies, route.Vertex(msg.NodeID))
case *lnwire.ChannelUpdate1:
delete(v.chanEdgeDependencies, msg.ShortChannelID)
case lnwire.ChannelUpdate:
delete(v.chanEdgeDependencies, msg.SCID())
case models.ChannelEdgePolicy:
delete(v.chanEdgeDependencies, msg.SCID())

View File

@ -1966,6 +1966,7 @@ out:
}
case *lnwire.ChannelUpdate1,
*lnwire.ChannelUpdate2,
*lnwire.ChannelAnnouncement1,
*lnwire.ChannelAnnouncement2,
*lnwire.NodeAnnouncement,
@ -2242,6 +2243,12 @@ func messageSummary(msg lnwire.Message) string {
msg.ShortChannelID.ToUint64(), msg.MessageFlags,
msg.ChannelFlags, time.Unix(int64(msg.Timestamp), 0))
case *lnwire.ChannelUpdate2:
return fmt.Sprintf("chain_hash=%v, short_chan_id=%v, "+
"is_disabled=%v, is_node_1=%v, block_height=%v",
msg.ChainHash, msg.ShortChannelID.Val.ToUint64(),
msg.IsDisabled(), msg.IsNode1(), msg.BlockHeight)
case *lnwire.NodeAnnouncement:
return fmt.Sprintf("node=%x, update_time=%v",
msg.NodeID, time.Unix(int64(msg.Timestamp), 0))