mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-10-09 21:03:33 +02:00
lnwire: add custom records field to type UpdateAddHtlc
- Introduce the field `CustomRecords` to the type `UpdateAddHtlc`. - Encode and decode the new field into the `ExtraData` field of the `update_add_htlc` wire message.
This commit is contained in:
@@ -2,6 +2,7 @@ package lnwire
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
crand "crypto/rand"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -134,6 +135,27 @@ func randPubKey() (*btcec.PublicKey, error) {
|
|||||||
return priv.PubKey(), nil
|
return priv.PubKey(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// pubkeyFromHex parses a Bitcoin public key from a hex encoded string.
|
||||||
|
func pubkeyFromHex(keyHex string) (*btcec.PublicKey, error) {
|
||||||
|
pubKeyBytes, err := hex.DecodeString(keyHex)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return btcec.ParsePubKey(pubKeyBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateRandomBytes returns a slice of n random bytes.
|
||||||
|
func generateRandomBytes(n int) ([]byte, error) {
|
||||||
|
b := make([]byte, n)
|
||||||
|
_, err := crand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
|
||||||
func randRawKey() ([33]byte, error) {
|
func randRawKey() ([33]byte, error) {
|
||||||
var n [33]byte
|
var n [33]byte
|
||||||
|
|
||||||
@@ -389,6 +411,37 @@ func TestEmptyMessageUnknownType(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// randCustomRecords generates a random set of custom records for testing.
|
||||||
|
func randCustomRecords(t *testing.T, r *rand.Rand) CustomRecords {
|
||||||
|
var (
|
||||||
|
customRecords = CustomRecords{}
|
||||||
|
|
||||||
|
// We'll generate a random number of records, between 1 and 10.
|
||||||
|
numRecords = r.Intn(9) + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// For each record, we'll generate a random key and value.
|
||||||
|
for i := 0; i < numRecords; i++ {
|
||||||
|
// Keys must be equal to or greater than
|
||||||
|
// MinCustomRecordsTlvType.
|
||||||
|
keyOffset := uint64(r.Intn(100))
|
||||||
|
key := MinCustomRecordsTlvType + keyOffset
|
||||||
|
|
||||||
|
// Values are byte slices of any length.
|
||||||
|
value := make([]byte, r.Intn(100))
|
||||||
|
_, err := r.Read(value)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
customRecords[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the custom records as a sanity check.
|
||||||
|
err := customRecords.Validate()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return customRecords
|
||||||
|
}
|
||||||
|
|
||||||
// TestLightningWireProtocol uses the testing/quick package to create a series
|
// TestLightningWireProtocol uses the testing/quick package to create a series
|
||||||
// of fuzz tests to attempt to break a primary scenario which is implemented as
|
// of fuzz tests to attempt to break a primary scenario which is implemented as
|
||||||
// property based testing scenario.
|
// property based testing scenario.
|
||||||
@@ -1369,6 +1422,8 @@ func TestLightningWireProtocol(t *testing.T) {
|
|||||||
_, err = r.Read(req.OnionBlob[:])
|
_, err = r.Read(req.OnionBlob[:])
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req.CustomRecords = randCustomRecords(t, r)
|
||||||
|
|
||||||
// Generate a blinding point 50% of the time, since not
|
// Generate a blinding point 50% of the time, since not
|
||||||
// all update adds will use route blinding.
|
// all update adds will use route blinding.
|
||||||
if r.Int31()%2 == 0 {
|
if r.Int31()%2 == 0 {
|
||||||
|
@@ -72,6 +72,11 @@ type UpdateAddHTLC struct {
|
|||||||
// next hop for this htlc.
|
// next hop for this htlc.
|
||||||
BlindingPoint BlindingPointRecord
|
BlindingPoint BlindingPointRecord
|
||||||
|
|
||||||
|
// CustomRecords maps TLV types to byte slices, storing arbitrary data
|
||||||
|
// intended for inclusion in the ExtraData field of the UpdateAddHTLC
|
||||||
|
// message.
|
||||||
|
CustomRecords CustomRecords
|
||||||
|
|
||||||
// ExtraData is the set of data that was appended to this message to
|
// ExtraData is the set of data that was appended to this message to
|
||||||
// fill out the full maximum transport message size. These fields can
|
// fill out the full maximum transport message size. These fields can
|
||||||
// be used to specify optional data such as custom TLV fields.
|
// be used to specify optional data such as custom TLV fields.
|
||||||
@@ -92,6 +97,10 @@ var _ Message = (*UpdateAddHTLC)(nil)
|
|||||||
//
|
//
|
||||||
// This is part of the lnwire.Message interface.
|
// This is part of the lnwire.Message interface.
|
||||||
func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error {
|
func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error {
|
||||||
|
// msgExtraData is a temporary variable used to read the message extra
|
||||||
|
// data field from the reader.
|
||||||
|
var msgExtraData ExtraOpaqueData
|
||||||
|
|
||||||
if err := ReadElements(r,
|
if err := ReadElements(r,
|
||||||
&c.ChanID,
|
&c.ChanID,
|
||||||
&c.ID,
|
&c.ID,
|
||||||
@@ -99,26 +108,28 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error {
|
|||||||
c.PaymentHash[:],
|
c.PaymentHash[:],
|
||||||
&c.Expiry,
|
&c.Expiry,
|
||||||
c.OnionBlob[:],
|
c.OnionBlob[:],
|
||||||
&c.ExtraData,
|
&msgExtraData,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract TLV records from the extra data field.
|
||||||
blindingRecord := c.BlindingPoint.Zero()
|
blindingRecord := c.BlindingPoint.Zero()
|
||||||
tlvMap, err := c.ExtraData.ExtractRecords(&blindingRecord)
|
|
||||||
|
customRecords, parsed, extraData, err := ParseAndExtractCustomRecords(
|
||||||
|
msgExtraData, &blindingRecord,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if val, ok := tlvMap[c.BlindingPoint.TlvType()]; ok && val == nil {
|
// Assign the parsed records back to the message.
|
||||||
|
if parsed.Contains(blindingRecord.TlvType()) {
|
||||||
c.BlindingPoint = tlv.SomeRecordT(blindingRecord)
|
c.BlindingPoint = tlv.SomeRecordT(blindingRecord)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set extra data to nil if we didn't parse anything out of it so that
|
c.CustomRecords = customRecords
|
||||||
// we can use assert.Equal in tests.
|
c.ExtraData = extraData
|
||||||
if len(tlvMap) == 0 {
|
|
||||||
c.ExtraData = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -154,19 +165,18 @@ func (c *UpdateAddHTLC) Encode(w *bytes.Buffer, pver uint32) error {
|
|||||||
|
|
||||||
// Only include blinding point in extra data if present.
|
// Only include blinding point in extra data if present.
|
||||||
var records []tlv.RecordProducer
|
var records []tlv.RecordProducer
|
||||||
|
c.BlindingPoint.WhenSome(
|
||||||
c.BlindingPoint.WhenSome(func(b tlv.RecordT[BlindingPointTlvType,
|
func(b tlv.RecordT[BlindingPointTlvType, *btcec.PublicKey]) {
|
||||||
*btcec.PublicKey]) {
|
|
||||||
|
|
||||||
records = append(records, &b)
|
records = append(records, &b)
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
err := EncodeMessageExtraData(&c.ExtraData, records...)
|
extraData, err := MergeAndEncode(records, c.ExtraData, c.CustomRecords)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return WriteBytes(w, c.ExtraData)
|
return WriteBytes(w, extraData)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MsgType returns the integer uniquely identifying this message type on the
|
// MsgType returns the integer uniquely identifying this message type on the
|
||||||
|
188
lnwire/update_add_htlc_test.go
Normal file
188
lnwire/update_add_htlc_test.go
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
package lnwire
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/tlv"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testCase is a test case for the UpdateAddHTLC message.
|
||||||
|
type testCase struct {
|
||||||
|
// Msg is the message to be encoded and decoded.
|
||||||
|
Msg UpdateAddHTLC
|
||||||
|
|
||||||
|
// ExpectEncodeError is a flag that indicates whether we expect the
|
||||||
|
// encoding of the message to fail.
|
||||||
|
ExpectEncodeError bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateTestCases generates a set of UpdateAddHTLC message test cases.
|
||||||
|
func generateTestCases(t *testing.T) []testCase {
|
||||||
|
// Firstly, we'll set basic values for the message fields.
|
||||||
|
//
|
||||||
|
// Generate random channel ID.
|
||||||
|
chanIDBytes, err := generateRandomBytes(32)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var chanID ChannelID
|
||||||
|
copy(chanID[:], chanIDBytes)
|
||||||
|
|
||||||
|
// Generate random payment hash.
|
||||||
|
paymentHashBytes, err := generateRandomBytes(32)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var paymentHash [32]byte
|
||||||
|
copy(paymentHash[:], paymentHashBytes)
|
||||||
|
|
||||||
|
// Generate random onion blob.
|
||||||
|
onionBlobBytes, err := generateRandomBytes(OnionPacketSize)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var onionBlob [OnionPacketSize]byte
|
||||||
|
copy(onionBlob[:], onionBlobBytes)
|
||||||
|
|
||||||
|
// Define the blinding point.
|
||||||
|
blinding, err := pubkeyFromHex(
|
||||||
|
"0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d4" +
|
||||||
|
"8236c39",
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
blindingPoint := tlv.SomeRecordT(
|
||||||
|
tlv.NewPrimitiveRecord[BlindingPointTlvType](blinding),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Define custom records.
|
||||||
|
recordKey1 := uint64(MinCustomRecordsTlvType + 1)
|
||||||
|
recordValue1, err := generateRandomBytes(10)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recordKey2 := uint64(MinCustomRecordsTlvType + 2)
|
||||||
|
recordValue2, err := generateRandomBytes(10)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
customRecords := CustomRecords{
|
||||||
|
recordKey1: recordValue1,
|
||||||
|
recordKey2: recordValue2,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct an instance of extra data that contains records with TLV
|
||||||
|
// types below the minimum custom records threshold and that lack
|
||||||
|
// corresponding fields in the message struct. Content should persist in
|
||||||
|
// the extra data field after encoding and decoding.
|
||||||
|
var (
|
||||||
|
recordBytes45 = []byte("recordBytes45")
|
||||||
|
tlvRecord45 = tlv.NewPrimitiveRecord[tlv.TlvType45](
|
||||||
|
recordBytes45,
|
||||||
|
)
|
||||||
|
|
||||||
|
recordBytes55 = []byte("recordBytes55")
|
||||||
|
tlvRecord55 = tlv.NewPrimitiveRecord[tlv.TlvType55](
|
||||||
|
recordBytes55,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
var extraData ExtraOpaqueData
|
||||||
|
err = extraData.PackRecords(
|
||||||
|
[]tlv.RecordProducer{&tlvRecord45, &tlvRecord55}...,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
invalidCustomRecords := CustomRecords{
|
||||||
|
MinCustomRecordsTlvType - 1: recordValue1,
|
||||||
|
}
|
||||||
|
|
||||||
|
return []testCase{
|
||||||
|
{
|
||||||
|
Msg: UpdateAddHTLC{
|
||||||
|
ChanID: chanID,
|
||||||
|
ID: 42,
|
||||||
|
Amount: MilliSatoshi(1000),
|
||||||
|
PaymentHash: paymentHash,
|
||||||
|
Expiry: 43,
|
||||||
|
OnionBlob: onionBlob,
|
||||||
|
BlindingPoint: blindingPoint,
|
||||||
|
CustomRecords: customRecords,
|
||||||
|
ExtraData: extraData,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Add a test case where the blinding point field is not
|
||||||
|
// populated.
|
||||||
|
{
|
||||||
|
Msg: UpdateAddHTLC{
|
||||||
|
ChanID: chanID,
|
||||||
|
ID: 42,
|
||||||
|
Amount: MilliSatoshi(1000),
|
||||||
|
PaymentHash: paymentHash,
|
||||||
|
Expiry: 43,
|
||||||
|
OnionBlob: onionBlob,
|
||||||
|
CustomRecords: customRecords,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Add a test case where the custom records field is not
|
||||||
|
// populated.
|
||||||
|
{
|
||||||
|
Msg: UpdateAddHTLC{
|
||||||
|
ChanID: chanID,
|
||||||
|
ID: 42,
|
||||||
|
Amount: MilliSatoshi(1000),
|
||||||
|
PaymentHash: paymentHash,
|
||||||
|
Expiry: 43,
|
||||||
|
OnionBlob: onionBlob,
|
||||||
|
BlindingPoint: blindingPoint,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Add a case where the custom records are invalid.
|
||||||
|
{
|
||||||
|
Msg: UpdateAddHTLC{
|
||||||
|
ChanID: chanID,
|
||||||
|
ID: 42,
|
||||||
|
Amount: MilliSatoshi(1000),
|
||||||
|
PaymentHash: paymentHash,
|
||||||
|
Expiry: 43,
|
||||||
|
OnionBlob: onionBlob,
|
||||||
|
BlindingPoint: blindingPoint,
|
||||||
|
CustomRecords: invalidCustomRecords,
|
||||||
|
},
|
||||||
|
ExpectEncodeError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUpdateAddHtlcEncodeDecode tests UpdateAddHTLC message encoding and
|
||||||
|
// decoding for all supported field values.
|
||||||
|
func TestUpdateAddHtlcEncodeDecode(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Generate test cases.
|
||||||
|
testCases := generateTestCases(t)
|
||||||
|
|
||||||
|
// Execute test cases.
|
||||||
|
for tcIdx, tc := range testCases {
|
||||||
|
t.Run(fmt.Sprintf("testcase-%d", tcIdx), func(t *testing.T) {
|
||||||
|
// Encode test case message.
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := tc.Msg.Encode(&buf, 0)
|
||||||
|
|
||||||
|
// Check if we expect an encoding error.
|
||||||
|
if tc.ExpectEncodeError {
|
||||||
|
require.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Decode the encoded message bytes message.
|
||||||
|
var actualMsg UpdateAddHTLC
|
||||||
|
decodeReader := bytes.NewReader(buf.Bytes())
|
||||||
|
err = actualMsg.Decode(decodeReader, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Compare the two messages to ensure equality.
|
||||||
|
require.Equal(t, tc.Msg, actualMsg)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@@ -2193,9 +2193,9 @@ func messageSummary(msg lnwire.Message) string {
|
|||||||
)
|
)
|
||||||
|
|
||||||
return fmt.Sprintf("chan_id=%v, id=%v, amt=%v, expiry=%v, "+
|
return fmt.Sprintf("chan_id=%v, id=%v, amt=%v, expiry=%v, "+
|
||||||
"hash=%x, blinding_point=%x", msg.ChanID, msg.ID,
|
"hash=%x, blinding_point=%x, custom_records=%v",
|
||||||
msg.Amount, msg.Expiry, msg.PaymentHash[:],
|
msg.ChanID, msg.ID, msg.Amount, msg.Expiry,
|
||||||
blindingPoint)
|
msg.PaymentHash[:], blindingPoint, msg.CustomRecords)
|
||||||
|
|
||||||
case *lnwire.UpdateFailHTLC:
|
case *lnwire.UpdateFailHTLC:
|
||||||
return fmt.Sprintf("chan_id=%v, id=%v, reason=%x", msg.ChanID,
|
return fmt.Sprintf("chan_id=%v, id=%v, reason=%x", msg.ChanID,
|
||||||
|
Reference in New Issue
Block a user