diff --git a/lnwire/channel_ready.go b/lnwire/channel_ready.go index bdcb95ce8..912a068bd 100644 --- a/lnwire/channel_ready.go +++ b/lnwire/channel_ready.go @@ -33,6 +33,16 @@ type ChannelReady struct { // to accept a new commitment state transition. NextLocalNonce OptMusig2NonceTLV + // AnnouncementNodeNonce is an optional field that stores a public + // nonce that will be used along with the node's ID key during signing + // of the ChannelAnnouncement2 message. + AnnouncementNodeNonce tlv.OptionalRecordT[tlv.TlvType0, Musig2Nonce] + + // AnnouncementBitcoinNonce is an optional field that stores a public + // nonce that will be used along with the node's bitcoin key during + // signing of the ChannelAnnouncement2 message. + AnnouncementBitcoinNonce tlv.OptionalRecordT[tlv.TlvType2, Musig2Nonce] + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -78,9 +88,11 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error { var ( aliasScid ShortChannelID localNonce = c.NextLocalNonce.Zero() + nodeNonce = tlv.ZeroRecordT[tlv.TlvType0, Musig2Nonce]() + btcNonce = tlv.ZeroRecordT[tlv.TlvType2, Musig2Nonce]() ) typeMap, err := tlvRecords.ExtractRecords( - &aliasScid, &localNonce, + &btcNonce, &aliasScid, &nodeNonce, &localNonce, ) if err != nil { return err @@ -94,6 +106,14 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error { if val, ok := typeMap[c.NextLocalNonce.TlvType()]; ok && val == nil { c.NextLocalNonce = tlv.SomeRecordT(localNonce) } + val, ok := typeMap[c.AnnouncementBitcoinNonce.TlvType()] + if ok && val == nil { + c.AnnouncementBitcoinNonce = tlv.SomeRecordT(btcNonce) + } + val, ok = typeMap[c.AnnouncementNodeNonce.TlvType()] + if ok && val == nil { + c.AnnouncementNodeNonce = tlv.SomeRecordT(nodeNonce) + } if len(tlvRecords) != 0 { c.ExtraData = tlvRecords @@ -117,13 +137,24 @@ func (c *ChannelReady) Encode(w *bytes.Buffer, _ uint32) error { } // We'll only encode the AliasScid in a TLV segment if it exists. - recordProducers := make([]tlv.RecordProducer, 0, 2) + recordProducers := make([]tlv.RecordProducer, 0, 4) if c.AliasScid != nil { recordProducers = append(recordProducers, c.AliasScid) } c.NextLocalNonce.WhenSome(func(localNonce Musig2NonceTLV) { recordProducers = append(recordProducers, &localNonce) }) + c.AnnouncementBitcoinNonce.WhenSome( + func(nonce tlv.RecordT[tlv.TlvType2, Musig2Nonce]) { + recordProducers = append(recordProducers, &nonce) + }, + ) + c.AnnouncementNodeNonce.WhenSome( + func(nonce tlv.RecordT[tlv.TlvType0, Musig2Nonce]) { + recordProducers = append(recordProducers, &nonce) + }, + ) + err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) if err != nil { return err diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 1a87e3edb..2a347e642 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -677,18 +677,13 @@ func TestLightningWireProtocol(t *testing.T) { }, MsgChannelReady: func(v []reflect.Value, r *rand.Rand) { var c [32]byte - if _, err := r.Read(c[:]); err != nil { - t.Fatalf("unable to generate chan id: %v", err) - return - } + _, err := r.Read(c[:]) + require.NoError(t, err) pubKey, err := randPubKey() - if err != nil { - t.Fatalf("unable to generate key: %v", err) - return - } + require.NoError(t, err) - req := NewChannelReady(ChannelID(c), pubKey) + req := NewChannelReady(c, pubKey) if r.Int31()%2 == 0 { scid := NewShortChanIDFromInt(uint64(r.Int63())) @@ -698,6 +693,24 @@ func TestLightningWireProtocol(t *testing.T) { req.NextLocalNonce = someLocalNonce[NonceRecordTypeT](r) } + if r.Int31()%2 == 0 { + nodeNonce := tlv.ZeroRecordT[ + tlv.TlvType0, Musig2Nonce, + ]() + nodeNonce.Val = randLocalNonce(r) + req.AnnouncementNodeNonce = tlv.SomeRecordT( + nodeNonce, + ) + + btcNonce := tlv.ZeroRecordT[ + tlv.TlvType2, Musig2Nonce, + ]() + btcNonce.Val = randLocalNonce(r) + req.AnnouncementBitcoinNonce = tlv.SomeRecordT( + btcNonce, + ) + } + v[0] = reflect.ValueOf(*req) }, MsgShutdown: func(v []reflect.Value, r *rand.Rand) {