From 099f5566bce606572050b0186b26b662f7b4754b Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 29 May 2024 19:57:37 +0200 Subject: [PATCH] lnwire: add CustomRecords to shutdown message --- lnwire/lnwire_test.go | 9 +-- lnwire/shutdown.go | 48 +++++++------ lnwire/shutdown_test.go | 145 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 179 insertions(+), 23 deletions(-) create mode 100644 lnwire/shutdown_test.go diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 63742b4a5..374b85c2b 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -435,7 +435,7 @@ func randCustomRecords(t *testing.T, r *rand.Rand) CustomRecords { key := MinCustomRecordsTlvType + keyOffset // Values are byte slices of any length. - value := make([]byte, r.Intn(100)) + value := make([]byte, r.Intn(10)) _, err := r.Read(value) require.NoError(t, err) @@ -791,7 +791,6 @@ func TestLightningWireProtocol(t *testing.T) { req := Shutdown{ ChannelID: ChannelID(c), Address: shutdownAddr, - ExtraData: make([]byte, 0), } if r.Int31()%2 == 0 { @@ -953,12 +952,14 @@ func TestLightningWireProtocol(t *testing.T) { // Only create the slice if there will be any signatures // in it to prevent false positive test failures due to // an empty slice versus a nil slice. - numSigs := uint16(r.Int31n(1019)) + numSigs := uint16(r.Int31n(500)) if numSigs > 0 { req.HtlcSigs = make([]Sig, numSigs) } for i := 0; i < int(numSigs); i++ { - req.HtlcSigs[i], err = NewSigFromSignature(testSig) + req.HtlcSigs[i], err = NewSigFromSignature( + testSig, + ) if err != nil { t.Fatalf("unable to parse sig: %v", err) return diff --git a/lnwire/shutdown.go b/lnwire/shutdown.go index c5455651b..b9899fcfb 100644 --- a/lnwire/shutdown.go +++ b/lnwire/shutdown.go @@ -38,6 +38,11 @@ type Shutdown struct { // co-op sign offer. ShutdownNonce ShutdownNonceTLV + // CustomRecords maps TLV types to byte slices, storing arbitrary data + // intended for inclusion in the ExtraData field of the Shutdown + // message. + CustomRecords CustomRecords + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -56,7 +61,7 @@ func NewShutdown(cid ChannelID, addr DeliveryAddress) *Shutdown { // interface. var _ Message = (*Shutdown)(nil) -// Decode deserializes a serialized Shutdown stored in the passed io.Reader +// Decode deserializes a serialized Shutdown from the passed io.Reader, // observing the specified protocol version. // // This is part of the lnwire.Message interface. @@ -71,20 +76,23 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error { return err } + // Extract TLV records from the extra data field. musigNonce := s.ShutdownNonce.Zero() - typeMap, err := tlvRecords.ExtractRecords(&musigNonce) + + customRecords, parsed, extraData, err := ParseAndExtractCustomRecords( + tlvRecords, &musigNonce, + ) if err != nil { return err } - // Set the corresponding TLV types if they were included in the stream. - if val, ok := typeMap[s.ShutdownNonce.TlvType()]; ok && val == nil { + // Assign the parsed records back to the message. + if _, ok := parsed[musigNonce.TlvType()]; ok { s.ShutdownNonce = tlv.SomeRecordT(musigNonce) } - if len(tlvRecords) != 0 { - s.ExtraData = tlvRecords - } + s.CustomRecords = customRecords + s.ExtraData = extraData return nil } @@ -94,17 +102,6 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error { - recordProducers := make([]tlv.RecordProducer, 0, 1) - s.ShutdownNonce.WhenSome( - func(nonce tlv.RecordT[ShutdownNonceType, Musig2Nonce]) { - recordProducers = append(recordProducers, &nonce) - }, - ) - err := EncodeMessageExtraData(&s.ExtraData, recordProducers...) - if err != nil { - return err - } - if err := WriteChannelID(w, s.ChannelID); err != nil { return err } @@ -113,7 +110,20 @@ func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error { return err } - return WriteBytes(w, s.ExtraData) + // Only include nonce in extra data if present. + var records []tlv.RecordProducer + s.ShutdownNonce.WhenSome( + func(nonce tlv.RecordT[ShutdownNonceType, Musig2Nonce]) { + records = append(records, &nonce) + }, + ) + + extraData, err := MergeAndEncode(records, s.ExtraData, s.CustomRecords) + if err != nil { + return err + } + + return WriteBytes(w, extraData) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/shutdown_test.go b/lnwire/shutdown_test.go new file mode 100644 index 000000000..7275efc9f --- /dev/null +++ b/lnwire/shutdown_test.go @@ -0,0 +1,145 @@ +package lnwire + +import ( + "bytes" + "fmt" + "testing" + + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// testCaseShutdown is a test case for the Shutdown message. +type testCaseShutdown struct { + // Msg is the message to be encoded and decoded. + Msg Shutdown + + // ExpectEncodeError is a flag that indicates whether we expect the + // encoding of the message to fail. + ExpectEncodeError bool +} + +// generateShutdownTestCases generates a set of Shutdown message test cases. +func generateShutdownTestCases(t *testing.T) []testCaseShutdown { + // Firstly, we'll set basic values for the message fields. + // + // Generate random channel ID. + chanIDBytes, err := generateRandomBytes(32) + require.NoError(t, err) + + var chanID ChannelID + copy(chanID[:], chanIDBytes) + + // Generate random payment preimage. + paymentPreimageBytes, err := generateRandomBytes(32) + require.NoError(t, err) + + var paymentPreimage [32]byte + copy(paymentPreimage[:], paymentPreimageBytes) + + deliveryAddr, err := generateRandomBytes(16) + require.NoError(t, err) + + // Define custom records. + recordKey1 := uint64(MinCustomRecordsTlvType + 1) + recordValue1, err := generateRandomBytes(10) + require.NoError(t, err) + + recordKey2 := uint64(MinCustomRecordsTlvType + 2) + recordValue2, err := generateRandomBytes(10) + require.NoError(t, err) + + customRecords := CustomRecords{ + recordKey1: recordValue1, + recordKey2: recordValue2, + } + + dummyPubKey, err := pubkeyFromHex( + "0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d4" + + "8236c39", + ) + require.NoError(t, err) + + muSig2Nonce, err := musig2.GenNonces(musig2.WithPublicKey(dummyPubKey)) + require.NoError(t, err) + + // Construct an instance of extra data that contains records with TLV + // types below the minimum custom records threshold and that lack + // corresponding fields in the message struct. Content should persist in + // the extra data field after encoding and decoding. + var ( + recordBytes45 = []byte("recordBytes45") + tlvRecord45 = tlv.NewPrimitiveRecord[tlv.TlvType45]( + recordBytes45, + ) + + recordBytes55 = []byte("recordBytes55") + tlvRecord55 = tlv.NewPrimitiveRecord[tlv.TlvType55]( + recordBytes55, + ) + ) + + var extraData ExtraOpaqueData + err = extraData.PackRecords( + []tlv.RecordProducer{&tlvRecord45, &tlvRecord55}..., + ) + require.NoError(t, err) + + return []testCaseShutdown{ + { + Msg: Shutdown{ + ChannelID: chanID, + CustomRecords: customRecords, + ExtraData: extraData, + Address: deliveryAddr, + }, + }, + { + Msg: Shutdown{ + ChannelID: chanID, + CustomRecords: customRecords, + ExtraData: extraData, + Address: deliveryAddr, + ShutdownNonce: SomeShutdownNonce( + muSig2Nonce.PubNonce, + ), + }, + }, + } +} + +// TestShutdownEncodeDecode tests Shutdown message encoding and decoding for all +// supported field values. +func TestShutdownEncodeDecode(t *testing.T) { + t.Parallel() + + // Generate test cases. + testCases := generateShutdownTestCases(t) + + // Execute test cases. + for tcIdx, tc := range testCases { + t.Run(fmt.Sprintf("testcase-%d", tcIdx), func(t *testing.T) { + // Encode test case message. + var buf bytes.Buffer + err := tc.Msg.Encode(&buf, 0) + + // Check if we expect an encoding error. + if tc.ExpectEncodeError { + require.Error(t, err) + return + } + + require.NoError(t, err) + + // Decode the encoded message bytes message. + var actualMsg Shutdown + decodeReader := bytes.NewReader(buf.Bytes()) + err = actualMsg.Decode(decodeReader, 0) + require.NoError(t, err) + + // Compare the two messages to ensure equality. + require.Equal(t, tc.Msg, actualMsg) + }) + } +}