From 2cfa89c719d98048bd8745b7dc54bf0362484dcf Mon Sep 17 00:00:00 2001 From: ffranr Date: Sat, 13 Apr 2024 12:29:41 +0100 Subject: [PATCH] lnwire: add custom records field to type `UpdateAddHtlc` - Introduce the field `CustomRecords` to the type `UpdateAddHtlc`. - Encode and decode the new field into the `ExtraData` field of the `update_add_htlc` wire message. --- lnwire/lnwire_test.go | 55 ++++++++++ lnwire/update_add_htlc.go | 42 +++++--- lnwire/update_add_htlc_test.go | 188 +++++++++++++++++++++++++++++++++ peer/brontide.go | 6 +- 4 files changed, 272 insertions(+), 19 deletions(-) create mode 100644 lnwire/update_add_htlc_test.go diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 122a99660..e7e765248 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -2,6 +2,7 @@ package lnwire import ( "bytes" + crand "crypto/rand" "encoding/binary" "encoding/hex" "fmt" @@ -134,6 +135,27 @@ func randPubKey() (*btcec.PublicKey, error) { return priv.PubKey(), nil } +// pubkeyFromHex parses a Bitcoin public key from a hex encoded string. +func pubkeyFromHex(keyHex string) (*btcec.PublicKey, error) { + pubKeyBytes, err := hex.DecodeString(keyHex) + if err != nil { + return nil, err + } + + return btcec.ParsePubKey(pubKeyBytes) +} + +// generateRandomBytes returns a slice of n random bytes. +func generateRandomBytes(n int) ([]byte, error) { + b := make([]byte, n) + _, err := crand.Read(b) + if err != nil { + return nil, err + } + + return b, nil +} + func randRawKey() ([33]byte, error) { var n [33]byte @@ -389,6 +411,37 @@ func TestEmptyMessageUnknownType(t *testing.T) { } } +// randCustomRecords generates a random set of custom records for testing. +func randCustomRecords(t *testing.T, r *rand.Rand) CustomRecords { + var ( + customRecords = CustomRecords{} + + // We'll generate a random number of records, between 1 and 10. + numRecords = r.Intn(9) + 1 + ) + + // For each record, we'll generate a random key and value. + for i := 0; i < numRecords; i++ { + // Keys must be equal to or greater than + // MinCustomRecordsTlvType. + keyOffset := uint64(r.Intn(100)) + key := MinCustomRecordsTlvType + keyOffset + + // Values are byte slices of any length. + value := make([]byte, r.Intn(100)) + _, err := r.Read(value) + require.NoError(t, err) + + customRecords[key] = value + } + + // Validate the custom records as a sanity check. + err := customRecords.Validate() + require.NoError(t, err) + + return customRecords +} + // TestLightningWireProtocol uses the testing/quick package to create a series // of fuzz tests to attempt to break a primary scenario which is implemented as // property based testing scenario. @@ -1369,6 +1422,8 @@ func TestLightningWireProtocol(t *testing.T) { _, err = r.Read(req.OnionBlob[:]) require.NoError(t, err) + req.CustomRecords = randCustomRecords(t, r) + // Generate a blinding point 50% of the time, since not // all update adds will use route blinding. if r.Int31()%2 == 0 { diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index 8a40710e8..0a377e710 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -72,6 +72,11 @@ type UpdateAddHTLC struct { // next hop for this htlc. BlindingPoint BlindingPointRecord + // CustomRecords maps TLV types to byte slices, storing arbitrary data + // intended for inclusion in the ExtraData field of the UpdateAddHTLC + // message. + CustomRecords CustomRecords + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -92,6 +97,10 @@ var _ Message = (*UpdateAddHTLC)(nil) // // This is part of the lnwire.Message interface. func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { + // msgExtraData is a temporary variable used to read the message extra + // data field from the reader. + var msgExtraData ExtraOpaqueData + if err := ReadElements(r, &c.ChanID, &c.ID, @@ -99,26 +108,28 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { c.PaymentHash[:], &c.Expiry, c.OnionBlob[:], - &c.ExtraData, + &msgExtraData, ); err != nil { return err } + // Extract TLV records from the extra data field. blindingRecord := c.BlindingPoint.Zero() - tlvMap, err := c.ExtraData.ExtractRecords(&blindingRecord) + + customRecords, parsed, extraData, err := ParseAndExtractCustomRecords( + msgExtraData, &blindingRecord, + ) if err != nil { return err } - if val, ok := tlvMap[c.BlindingPoint.TlvType()]; ok && val == nil { + // Assign the parsed records back to the message. + if parsed.Contains(blindingRecord.TlvType()) { c.BlindingPoint = tlv.SomeRecordT(blindingRecord) } - // Set extra data to nil if we didn't parse anything out of it so that - // we can use assert.Equal in tests. - if len(tlvMap) == 0 { - c.ExtraData = nil - } + c.CustomRecords = customRecords + c.ExtraData = extraData return nil } @@ -154,19 +165,18 @@ func (c *UpdateAddHTLC) Encode(w *bytes.Buffer, pver uint32) error { // Only include blinding point in extra data if present. var records []tlv.RecordProducer + c.BlindingPoint.WhenSome( + func(b tlv.RecordT[BlindingPointTlvType, *btcec.PublicKey]) { + records = append(records, &b) + }, + ) - c.BlindingPoint.WhenSome(func(b tlv.RecordT[BlindingPointTlvType, - *btcec.PublicKey]) { - - records = append(records, &b) - }) - - err := EncodeMessageExtraData(&c.ExtraData, records...) + extraData, err := MergeAndEncode(records, c.ExtraData, c.CustomRecords) if err != nil { return err } - return WriteBytes(w, c.ExtraData) + return WriteBytes(w, extraData) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/update_add_htlc_test.go b/lnwire/update_add_htlc_test.go new file mode 100644 index 000000000..53f8921bd --- /dev/null +++ b/lnwire/update_add_htlc_test.go @@ -0,0 +1,188 @@ +package lnwire + +import ( + "bytes" + "fmt" + "testing" + + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// testCase is a test case for the UpdateAddHTLC message. +type testCase struct { + // Msg is the message to be encoded and decoded. + Msg UpdateAddHTLC + + // ExpectEncodeError is a flag that indicates whether we expect the + // encoding of the message to fail. + ExpectEncodeError bool +} + +// generateTestCases generates a set of UpdateAddHTLC message test cases. +func generateTestCases(t *testing.T) []testCase { + // Firstly, we'll set basic values for the message fields. + // + // Generate random channel ID. + chanIDBytes, err := generateRandomBytes(32) + require.NoError(t, err) + + var chanID ChannelID + copy(chanID[:], chanIDBytes) + + // Generate random payment hash. + paymentHashBytes, err := generateRandomBytes(32) + require.NoError(t, err) + + var paymentHash [32]byte + copy(paymentHash[:], paymentHashBytes) + + // Generate random onion blob. + onionBlobBytes, err := generateRandomBytes(OnionPacketSize) + require.NoError(t, err) + + var onionBlob [OnionPacketSize]byte + copy(onionBlob[:], onionBlobBytes) + + // Define the blinding point. + blinding, err := pubkeyFromHex( + "0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d4" + + "8236c39", + ) + require.NoError(t, err) + + blindingPoint := tlv.SomeRecordT( + tlv.NewPrimitiveRecord[BlindingPointTlvType](blinding), + ) + + // Define custom records. + recordKey1 := uint64(MinCustomRecordsTlvType + 1) + recordValue1, err := generateRandomBytes(10) + require.NoError(t, err) + + recordKey2 := uint64(MinCustomRecordsTlvType + 2) + recordValue2, err := generateRandomBytes(10) + require.NoError(t, err) + + customRecords := CustomRecords{ + recordKey1: recordValue1, + recordKey2: recordValue2, + } + + // Construct an instance of extra data that contains records with TLV + // types below the minimum custom records threshold and that lack + // corresponding fields in the message struct. Content should persist in + // the extra data field after encoding and decoding. + var ( + recordBytes45 = []byte("recordBytes45") + tlvRecord45 = tlv.NewPrimitiveRecord[tlv.TlvType45]( + recordBytes45, + ) + + recordBytes55 = []byte("recordBytes55") + tlvRecord55 = tlv.NewPrimitiveRecord[tlv.TlvType55]( + recordBytes55, + ) + ) + + var extraData ExtraOpaqueData + err = extraData.PackRecords( + []tlv.RecordProducer{&tlvRecord45, &tlvRecord55}..., + ) + require.NoError(t, err) + + invalidCustomRecords := CustomRecords{ + MinCustomRecordsTlvType - 1: recordValue1, + } + + return []testCase{ + { + Msg: UpdateAddHTLC{ + ChanID: chanID, + ID: 42, + Amount: MilliSatoshi(1000), + PaymentHash: paymentHash, + Expiry: 43, + OnionBlob: onionBlob, + BlindingPoint: blindingPoint, + CustomRecords: customRecords, + ExtraData: extraData, + }, + }, + // Add a test case where the blinding point field is not + // populated. + { + Msg: UpdateAddHTLC{ + ChanID: chanID, + ID: 42, + Amount: MilliSatoshi(1000), + PaymentHash: paymentHash, + Expiry: 43, + OnionBlob: onionBlob, + CustomRecords: customRecords, + }, + }, + // Add a test case where the custom records field is not + // populated. + { + Msg: UpdateAddHTLC{ + ChanID: chanID, + ID: 42, + Amount: MilliSatoshi(1000), + PaymentHash: paymentHash, + Expiry: 43, + OnionBlob: onionBlob, + BlindingPoint: blindingPoint, + }, + }, + // Add a case where the custom records are invalid. + { + Msg: UpdateAddHTLC{ + ChanID: chanID, + ID: 42, + Amount: MilliSatoshi(1000), + PaymentHash: paymentHash, + Expiry: 43, + OnionBlob: onionBlob, + BlindingPoint: blindingPoint, + CustomRecords: invalidCustomRecords, + }, + ExpectEncodeError: true, + }, + } +} + +// TestUpdateAddHtlcEncodeDecode tests UpdateAddHTLC message encoding and +// decoding for all supported field values. +func TestUpdateAddHtlcEncodeDecode(t *testing.T) { + t.Parallel() + + // Generate test cases. + testCases := generateTestCases(t) + + // Execute test cases. + for tcIdx, tc := range testCases { + t.Run(fmt.Sprintf("testcase-%d", tcIdx), func(t *testing.T) { + // Encode test case message. + var buf bytes.Buffer + err := tc.Msg.Encode(&buf, 0) + + // Check if we expect an encoding error. + if tc.ExpectEncodeError { + require.Error(t, err) + return + } + + require.NoError(t, err) + + // Decode the encoded message bytes message. + var actualMsg UpdateAddHTLC + decodeReader := bytes.NewReader(buf.Bytes()) + err = actualMsg.Decode(decodeReader, 0) + require.NoError(t, err) + + // Compare the two messages to ensure equality. + require.Equal(t, tc.Msg, actualMsg) + }) + } +} diff --git a/peer/brontide.go b/peer/brontide.go index bbb4aa287..2e6b8c07a 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -2193,9 +2193,9 @@ func messageSummary(msg lnwire.Message) string { ) return fmt.Sprintf("chan_id=%v, id=%v, amt=%v, expiry=%v, "+ - "hash=%x, blinding_point=%x", msg.ChanID, msg.ID, - msg.Amount, msg.Expiry, msg.PaymentHash[:], - blindingPoint) + "hash=%x, blinding_point=%x, custom_records=%v", + msg.ChanID, msg.ID, msg.Amount, msg.Expiry, + msg.PaymentHash[:], blindingPoint, msg.CustomRecords) case *lnwire.UpdateFailHTLC: return fmt.Sprintf("chan_id=%v, id=%v, reason=%x", msg.ChanID,