mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-06-01 02:30:28 +02:00
lnwire: add CustomRecords to shutdown message
This commit is contained in:
parent
9a972e1b0c
commit
099f5566bc
@ -435,7 +435,7 @@ func randCustomRecords(t *testing.T, r *rand.Rand) CustomRecords {
|
|||||||
key := MinCustomRecordsTlvType + keyOffset
|
key := MinCustomRecordsTlvType + keyOffset
|
||||||
|
|
||||||
// Values are byte slices of any length.
|
// Values are byte slices of any length.
|
||||||
value := make([]byte, r.Intn(100))
|
value := make([]byte, r.Intn(10))
|
||||||
_, err := r.Read(value)
|
_, err := r.Read(value)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@ -791,7 +791,6 @@ func TestLightningWireProtocol(t *testing.T) {
|
|||||||
req := Shutdown{
|
req := Shutdown{
|
||||||
ChannelID: ChannelID(c),
|
ChannelID: ChannelID(c),
|
||||||
Address: shutdownAddr,
|
Address: shutdownAddr,
|
||||||
ExtraData: make([]byte, 0),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Int31()%2 == 0 {
|
if r.Int31()%2 == 0 {
|
||||||
@ -953,12 +952,14 @@ func TestLightningWireProtocol(t *testing.T) {
|
|||||||
// Only create the slice if there will be any signatures
|
// Only create the slice if there will be any signatures
|
||||||
// in it to prevent false positive test failures due to
|
// in it to prevent false positive test failures due to
|
||||||
// an empty slice versus a nil slice.
|
// an empty slice versus a nil slice.
|
||||||
numSigs := uint16(r.Int31n(1019))
|
numSigs := uint16(r.Int31n(500))
|
||||||
if numSigs > 0 {
|
if numSigs > 0 {
|
||||||
req.HtlcSigs = make([]Sig, numSigs)
|
req.HtlcSigs = make([]Sig, numSigs)
|
||||||
}
|
}
|
||||||
for i := 0; i < int(numSigs); i++ {
|
for i := 0; i < int(numSigs); i++ {
|
||||||
req.HtlcSigs[i], err = NewSigFromSignature(testSig)
|
req.HtlcSigs[i], err = NewSigFromSignature(
|
||||||
|
testSig,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to parse sig: %v", err)
|
t.Fatalf("unable to parse sig: %v", err)
|
||||||
return
|
return
|
||||||
|
@ -38,6 +38,11 @@ type Shutdown struct {
|
|||||||
// co-op sign offer.
|
// co-op sign offer.
|
||||||
ShutdownNonce ShutdownNonceTLV
|
ShutdownNonce ShutdownNonceTLV
|
||||||
|
|
||||||
|
// CustomRecords maps TLV types to byte slices, storing arbitrary data
|
||||||
|
// intended for inclusion in the ExtraData field of the Shutdown
|
||||||
|
// message.
|
||||||
|
CustomRecords CustomRecords
|
||||||
|
|
||||||
// ExtraData is the set of data that was appended to this message to
|
// ExtraData is the set of data that was appended to this message to
|
||||||
// fill out the full maximum transport message size. These fields can
|
// fill out the full maximum transport message size. These fields can
|
||||||
// be used to specify optional data such as custom TLV fields.
|
// be used to specify optional data such as custom TLV fields.
|
||||||
@ -56,7 +61,7 @@ func NewShutdown(cid ChannelID, addr DeliveryAddress) *Shutdown {
|
|||||||
// interface.
|
// interface.
|
||||||
var _ Message = (*Shutdown)(nil)
|
var _ Message = (*Shutdown)(nil)
|
||||||
|
|
||||||
// Decode deserializes a serialized Shutdown stored in the passed io.Reader
|
// Decode deserializes a serialized Shutdown from the passed io.Reader,
|
||||||
// observing the specified protocol version.
|
// observing the specified protocol version.
|
||||||
//
|
//
|
||||||
// This is part of the lnwire.Message interface.
|
// This is part of the lnwire.Message interface.
|
||||||
@ -71,20 +76,23 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract TLV records from the extra data field.
|
||||||
musigNonce := s.ShutdownNonce.Zero()
|
musigNonce := s.ShutdownNonce.Zero()
|
||||||
typeMap, err := tlvRecords.ExtractRecords(&musigNonce)
|
|
||||||
|
customRecords, parsed, extraData, err := ParseAndExtractCustomRecords(
|
||||||
|
tlvRecords, &musigNonce,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the corresponding TLV types if they were included in the stream.
|
// Assign the parsed records back to the message.
|
||||||
if val, ok := typeMap[s.ShutdownNonce.TlvType()]; ok && val == nil {
|
if _, ok := parsed[musigNonce.TlvType()]; ok {
|
||||||
s.ShutdownNonce = tlv.SomeRecordT(musigNonce)
|
s.ShutdownNonce = tlv.SomeRecordT(musigNonce)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(tlvRecords) != 0 {
|
s.CustomRecords = customRecords
|
||||||
s.ExtraData = tlvRecords
|
s.ExtraData = extraData
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -94,17 +102,6 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error {
|
|||||||
//
|
//
|
||||||
// This is part of the lnwire.Message interface.
|
// This is part of the lnwire.Message interface.
|
||||||
func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error {
|
func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error {
|
||||||
recordProducers := make([]tlv.RecordProducer, 0, 1)
|
|
||||||
s.ShutdownNonce.WhenSome(
|
|
||||||
func(nonce tlv.RecordT[ShutdownNonceType, Musig2Nonce]) {
|
|
||||||
recordProducers = append(recordProducers, &nonce)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
err := EncodeMessageExtraData(&s.ExtraData, recordProducers...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := WriteChannelID(w, s.ChannelID); err != nil {
|
if err := WriteChannelID(w, s.ChannelID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -113,7 +110,20 @@ func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return WriteBytes(w, s.ExtraData)
|
// Only include nonce in extra data if present.
|
||||||
|
var records []tlv.RecordProducer
|
||||||
|
s.ShutdownNonce.WhenSome(
|
||||||
|
func(nonce tlv.RecordT[ShutdownNonceType, Musig2Nonce]) {
|
||||||
|
records = append(records, &nonce)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
extraData, err := MergeAndEncode(records, s.ExtraData, s.CustomRecords)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return WriteBytes(w, extraData)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MsgType returns the integer uniquely identifying this message type on the
|
// MsgType returns the integer uniquely identifying this message type on the
|
||||||
|
145
lnwire/shutdown_test.go
Normal file
145
lnwire/shutdown_test.go
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
package lnwire
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
|
||||||
|
"github.com/lightningnetwork/lnd/tlv"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testCaseShutdown is a test case for the Shutdown message.
|
||||||
|
type testCaseShutdown struct {
|
||||||
|
// Msg is the message to be encoded and decoded.
|
||||||
|
Msg Shutdown
|
||||||
|
|
||||||
|
// ExpectEncodeError is a flag that indicates whether we expect the
|
||||||
|
// encoding of the message to fail.
|
||||||
|
ExpectEncodeError bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateShutdownTestCases generates a set of Shutdown message test cases.
|
||||||
|
func generateShutdownTestCases(t *testing.T) []testCaseShutdown {
|
||||||
|
// Firstly, we'll set basic values for the message fields.
|
||||||
|
//
|
||||||
|
// Generate random channel ID.
|
||||||
|
chanIDBytes, err := generateRandomBytes(32)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var chanID ChannelID
|
||||||
|
copy(chanID[:], chanIDBytes)
|
||||||
|
|
||||||
|
// Generate random payment preimage.
|
||||||
|
paymentPreimageBytes, err := generateRandomBytes(32)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var paymentPreimage [32]byte
|
||||||
|
copy(paymentPreimage[:], paymentPreimageBytes)
|
||||||
|
|
||||||
|
deliveryAddr, err := generateRandomBytes(16)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Define custom records.
|
||||||
|
recordKey1 := uint64(MinCustomRecordsTlvType + 1)
|
||||||
|
recordValue1, err := generateRandomBytes(10)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
recordKey2 := uint64(MinCustomRecordsTlvType + 2)
|
||||||
|
recordValue2, err := generateRandomBytes(10)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
customRecords := CustomRecords{
|
||||||
|
recordKey1: recordValue1,
|
||||||
|
recordKey2: recordValue2,
|
||||||
|
}
|
||||||
|
|
||||||
|
dummyPubKey, err := pubkeyFromHex(
|
||||||
|
"0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d4" +
|
||||||
|
"8236c39",
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
muSig2Nonce, err := musig2.GenNonces(musig2.WithPublicKey(dummyPubKey))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Construct an instance of extra data that contains records with TLV
|
||||||
|
// types below the minimum custom records threshold and that lack
|
||||||
|
// corresponding fields in the message struct. Content should persist in
|
||||||
|
// the extra data field after encoding and decoding.
|
||||||
|
var (
|
||||||
|
recordBytes45 = []byte("recordBytes45")
|
||||||
|
tlvRecord45 = tlv.NewPrimitiveRecord[tlv.TlvType45](
|
||||||
|
recordBytes45,
|
||||||
|
)
|
||||||
|
|
||||||
|
recordBytes55 = []byte("recordBytes55")
|
||||||
|
tlvRecord55 = tlv.NewPrimitiveRecord[tlv.TlvType55](
|
||||||
|
recordBytes55,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
var extraData ExtraOpaqueData
|
||||||
|
err = extraData.PackRecords(
|
||||||
|
[]tlv.RecordProducer{&tlvRecord45, &tlvRecord55}...,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return []testCaseShutdown{
|
||||||
|
{
|
||||||
|
Msg: Shutdown{
|
||||||
|
ChannelID: chanID,
|
||||||
|
CustomRecords: customRecords,
|
||||||
|
ExtraData: extraData,
|
||||||
|
Address: deliveryAddr,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Msg: Shutdown{
|
||||||
|
ChannelID: chanID,
|
||||||
|
CustomRecords: customRecords,
|
||||||
|
ExtraData: extraData,
|
||||||
|
Address: deliveryAddr,
|
||||||
|
ShutdownNonce: SomeShutdownNonce(
|
||||||
|
muSig2Nonce.PubNonce,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestShutdownEncodeDecode tests Shutdown message encoding and decoding for all
|
||||||
|
// supported field values.
|
||||||
|
func TestShutdownEncodeDecode(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Generate test cases.
|
||||||
|
testCases := generateShutdownTestCases(t)
|
||||||
|
|
||||||
|
// Execute test cases.
|
||||||
|
for tcIdx, tc := range testCases {
|
||||||
|
t.Run(fmt.Sprintf("testcase-%d", tcIdx), func(t *testing.T) {
|
||||||
|
// Encode test case message.
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := tc.Msg.Encode(&buf, 0)
|
||||||
|
|
||||||
|
// Check if we expect an encoding error.
|
||||||
|
if tc.ExpectEncodeError {
|
||||||
|
require.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Decode the encoded message bytes message.
|
||||||
|
var actualMsg Shutdown
|
||||||
|
decodeReader := bytes.NewReader(buf.Bytes())
|
||||||
|
err = actualMsg.Decode(decodeReader, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Compare the two messages to ensure equality.
|
||||||
|
require.Equal(t, tc.Msg, actualMsg)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user