From 8d1059f41c635ca7803e148dd29ec0f29da20e2a Mon Sep 17 00:00:00 2001 From: ffranr Date: Fri, 3 May 2024 16:22:05 +0100 Subject: [PATCH] lnwire: add custom records field to type `UpdateFulfillHtlc` - Introduce the field `CustomRecords` to the type `UpdateFulfillHtlc`. - Encode and decode the new field into the `ExtraData` field of the `update_fulfill_htlc` wire message. - Empty `ExtraData` field is set to `nil`. --- htlcswitch/payment_result_test.go | 1 - lnwire/lnwire_test.go | 23 +++++ lnwire/update_fulfill_htlc.go | 36 +++++++- lnwire/update_fulfill_htlc_test.go | 129 +++++++++++++++++++++++++++++ peer/brontide.go | 5 +- 5 files changed, 188 insertions(+), 6 deletions(-) create mode 100644 lnwire/update_fulfill_htlc_test.go diff --git a/htlcswitch/payment_result_test.go b/htlcswitch/payment_result_test.go index 99e8074e5..664197f76 100644 --- a/htlcswitch/payment_result_test.go +++ b/htlcswitch/payment_result_test.go @@ -38,7 +38,6 @@ func TestNetworkResultSerialization(t *testing.T) { ChanID: chanID, ID: 2, PaymentPreimage: preimage, - ExtraData: make([]byte, 0), } fail := &lnwire.UpdateFailHTLC{ diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index e7e765248..7eb434f45 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -1442,6 +1442,29 @@ func TestLightningWireProtocol(t *testing.T) { ) } + v[0] = reflect.ValueOf(*req) + }, + MsgUpdateFulfillHTLC: func(v []reflect.Value, r *rand.Rand) { + req := &UpdateFulfillHTLC{ + ID: r.Uint64(), + } + + _, err := r.Read(req.ChanID[:]) + require.NoError(t, err) + + _, err = r.Read(req.PaymentPreimage[:]) + require.NoError(t, err) + + req.CustomRecords = randCustomRecords(t, r) + + // Generate some random TLV records 50% of the time. + if r.Int31()%2 == 0 { + req.ExtraData = []byte{ + 0x01, 0x03, 1, 2, 3, + 0x02, 0x03, 4, 5, 6, + } + } + v[0] = reflect.ValueOf(*req) }, } diff --git a/lnwire/update_fulfill_htlc.go b/lnwire/update_fulfill_htlc.go index 275a37c87..35aaa2ff5 100644 --- a/lnwire/update_fulfill_htlc.go +++ b/lnwire/update_fulfill_htlc.go @@ -23,6 +23,10 @@ type UpdateFulfillHTLC struct { // HTLC. PaymentPreimage [32]byte + // CustomRecords maps TLV types to byte slices, storing arbitrary data + // intended for inclusion in the ExtraData field. + CustomRecords CustomRecords + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -49,12 +53,31 @@ var _ Message = (*UpdateFulfillHTLC)(nil) // // This is part of the lnwire.Message interface. func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, + // msgExtraData is a temporary variable used to read the message extra + // data field from the reader. + var msgExtraData ExtraOpaqueData + + if err := ReadElements(r, &c.ChanID, &c.ID, c.PaymentPreimage[:], - &c.ExtraData, + &msgExtraData, + ); err != nil { + return err + } + + // Extract custom records from the extra data field. + customRecords, _, extraData, err := ParseAndExtractCustomRecords( + msgExtraData, ) + if err != nil { + return err + } + + c.CustomRecords = customRecords + c.ExtraData = extraData + + return nil } // Encode serializes the target UpdateFulfillHTLC into the passed io.Writer @@ -74,7 +97,14 @@ func (c *UpdateFulfillHTLC) Encode(w *bytes.Buffer, pver uint32) error { return err } - return WriteBytes(w, c.ExtraData) + // Combine the custom records and the extra data, then encode the + // result as a byte slice. + extraData, err := MergeAndEncode(nil, c.ExtraData, c.CustomRecords) + if err != nil { + return err + } + + return WriteBytes(w, extraData) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/update_fulfill_htlc_test.go b/lnwire/update_fulfill_htlc_test.go new file mode 100644 index 000000000..e38b3a9b8 --- /dev/null +++ b/lnwire/update_fulfill_htlc_test.go @@ -0,0 +1,129 @@ +package lnwire + +import ( + "bytes" + "fmt" + "testing" + + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// testCaseUpdateFulfill is a test case for the UpdateFulfillHTLC message. +type testCaseUpdateFulfill struct { + // Msg is the message to be encoded and decoded. + Msg UpdateFulfillHTLC + + // ExpectEncodeError is a flag that indicates whether we expect the + // encoding of the message to fail. + ExpectEncodeError bool +} + +// generateTestCases generates a set of UpdateFulfillHTLC message test cases. +func generateUpdateFulfillTestCases(t *testing.T) []testCaseUpdateFulfill { + // Firstly, we'll set basic values for the message fields. + // + // Generate random channel ID. + chanIDBytes, err := generateRandomBytes(32) + require.NoError(t, err) + + var chanID ChannelID + copy(chanID[:], chanIDBytes) + + // Generate random payment preimage. + paymentPreimageBytes, err := generateRandomBytes(32) + require.NoError(t, err) + + var paymentPreimage [32]byte + copy(paymentPreimage[:], paymentPreimageBytes) + + // Define custom records. + recordKey1 := uint64(MinCustomRecordsTlvType + 1) + recordValue1, err := generateRandomBytes(10) + require.NoError(t, err) + + recordKey2 := uint64(MinCustomRecordsTlvType + 2) + recordValue2, err := generateRandomBytes(10) + require.NoError(t, err) + + customRecords := CustomRecords{ + recordKey1: recordValue1, + recordKey2: recordValue2, + } + + // Construct an instance of extra data that contains records with TLV + // types below the minimum custom records threshold and that lack + // corresponding fields in the message struct. Content should persist in + // the extra data field after encoding and decoding. + var ( + recordBytes45 = []byte("recordBytes45") + tlvRecord45 = tlv.NewPrimitiveRecord[tlv.TlvType45]( + recordBytes45, + ) + + recordBytes55 = []byte("recordBytes55") + tlvRecord55 = tlv.NewPrimitiveRecord[tlv.TlvType55]( + recordBytes55, + ) + ) + + var extraData ExtraOpaqueData + err = extraData.PackRecords( + []tlv.RecordProducer{&tlvRecord45, &tlvRecord55}..., + ) + require.NoError(t, err) + + return []testCaseUpdateFulfill{ + { + Msg: UpdateFulfillHTLC{ + ChanID: chanID, + ID: 42, + PaymentPreimage: paymentPreimage, + }, + }, + { + Msg: UpdateFulfillHTLC{ + ChanID: chanID, + ID: 42, + PaymentPreimage: paymentPreimage, + CustomRecords: customRecords, + ExtraData: extraData, + }, + }, + } +} + +// TestUpdateFulfillHtlcEncodeDecode tests UpdateFulfillHTLC message encoding +// and decoding for all supported field values. +func TestUpdateFulfillHtlcEncodeDecode(t *testing.T) { + t.Parallel() + + // Generate test cases. + testCases := generateUpdateFulfillTestCases(t) + + // Execute test cases. + for tcIdx, tc := range testCases { + t.Run(fmt.Sprintf("testcase-%d", tcIdx), func(t *testing.T) { + // Encode test case message. + var buf bytes.Buffer + err := tc.Msg.Encode(&buf, 0) + + // Check if we expect an encoding error. + if tc.ExpectEncodeError { + require.Error(t, err) + return + } + + require.NoError(t, err) + + // Decode the encoded message bytes message. + var actualMsg UpdateFulfillHTLC + decodeReader := bytes.NewReader(buf.Bytes()) + err = actualMsg.Decode(decodeReader, 0) + require.NoError(t, err) + + // Compare the two messages to ensure equality. + require.Equal(t, tc.Msg, actualMsg) + }) + } +} diff --git a/peer/brontide.go b/peer/brontide.go index 2e6b8c07a..3223e7f4b 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -2202,8 +2202,9 @@ func messageSummary(msg lnwire.Message) string { msg.ID, msg.Reason) case *lnwire.UpdateFulfillHTLC: - return fmt.Sprintf("chan_id=%v, id=%v, pre_image=%x", - msg.ChanID, msg.ID, msg.PaymentPreimage[:]) + return fmt.Sprintf("chan_id=%v, id=%v, pre_image=%x, "+ + "custom_records=%v", msg.ChanID, msg.ID, + msg.PaymentPreimage[:], msg.CustomRecords) case *lnwire.CommitSig: return fmt.Sprintf("chan_id=%v, num_htlcs=%v", msg.ChanID,