From bd84fd256e46c409aa2f679437aa6602c11db377 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Mon, 2 Sep 2024 14:33:49 +0200 Subject: [PATCH] lnwire: add custom records field to type `CommitSig` --- lnwire/commit_sig.go | 38 +++++---- lnwire/commit_sig_test.go | 168 ++++++++++++++++++++++++++++++++++++++ lnwire/lnwire_test.go | 2 + 3 files changed, 193 insertions(+), 15 deletions(-) create mode 100644 lnwire/commit_sig_test.go diff --git a/lnwire/commit_sig.go b/lnwire/commit_sig.go index 7deb64ae1..3a475e71f 100644 --- a/lnwire/commit_sig.go +++ b/lnwire/commit_sig.go @@ -45,6 +45,10 @@ type CommitSig struct { // being signed for. In this case, the above Sig type MUST be blank. PartialSig OptPartialSigWithNonceTLV + // CustomRecords maps TLV types to byte slices, storing arbitrary data + // intended for inclusion in the ExtraData field. + CustomRecords CustomRecords + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -53,9 +57,7 @@ type CommitSig struct { // NewCommitSig creates a new empty CommitSig message. func NewCommitSig() *CommitSig { - return &CommitSig{ - ExtraData: make([]byte, 0), - } + return &CommitSig{} } // A compile time check to ensure CommitSig implements the lnwire.Message @@ -67,34 +69,37 @@ var _ Message = (*CommitSig)(nil) // // This is part of the lnwire.Message interface. func (c *CommitSig) Decode(r io.Reader, pver uint32) error { + // msgExtraData is a temporary variable used to read the message extra + // data field from the reader. + var msgExtraData ExtraOpaqueData + err := ReadElements(r, &c.ChanID, &c.CommitSig, &c.HtlcSigs, + &msgExtraData, ) if err != nil { return err } - var tlvRecords ExtraOpaqueData - if err := ReadElements(r, &tlvRecords); err != nil { - return err - } - + // Extract TLV records from the extra data field. partialSig := c.PartialSig.Zero() - typeMap, err := tlvRecords.ExtractRecords(&partialSig) + + customRecords, parsed, extraData, err := ParseAndExtractCustomRecords( + msgExtraData, &partialSig, + ) if err != nil { return err } // Set the corresponding TLV types if they were included in the stream. - if val, ok := typeMap[c.PartialSig.TlvType()]; ok && val == nil { + if _, ok := parsed[partialSig.TlvType()]; ok { c.PartialSig = tlv.SomeRecordT(partialSig) } - if len(tlvRecords) != 0 { - c.ExtraData = tlvRecords - } + c.CustomRecords = customRecords + c.ExtraData = extraData return nil } @@ -108,7 +113,10 @@ func (c *CommitSig) Encode(w *bytes.Buffer, pver uint32) error { c.PartialSig.WhenSome(func(sig PartialSigWithNonceTLV) { recordProducers = append(recordProducers, &sig) }) - err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) + + extraData, err := MergeAndEncode( + recordProducers, c.ExtraData, c.CustomRecords, + ) if err != nil { return err } @@ -125,7 +133,7 @@ func (c *CommitSig) Encode(w *bytes.Buffer, pver uint32) error { return err } - return WriteBytes(w, c.ExtraData) + return WriteBytes(w, extraData) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/commit_sig_test.go b/lnwire/commit_sig_test.go new file mode 100644 index 000000000..0772a2fb8 --- /dev/null +++ b/lnwire/commit_sig_test.go @@ -0,0 +1,168 @@ +package lnwire + +import ( + "bytes" + "fmt" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// testCase is a test case for the CommitSig message. +type commitSigTestCase struct { + // Msg is the message to be encoded and decoded. + Msg CommitSig + + // ExpectEncodeError is a flag that indicates whether we expect the + // encoding of the message to fail. + ExpectEncodeError bool +} + +// generateCommitSigTestCases generates a set of CommitSig message test cases. +func generateCommitSigTestCases(t *testing.T) []commitSigTestCase { + // Firstly, we'll set basic values for the message fields. + // + // Generate random channel ID. + chanIDBytes, err := generateRandomBytes(32) + require.NoError(t, err) + + var chanID ChannelID + copy(chanID[:], chanIDBytes) + + // Generate random commit sig. + commitSigBytes, err := generateRandomBytes(64) + require.NoError(t, err) + + sig, err := NewSigFromSchnorrRawSignature(commitSigBytes) + require.NoError(t, err) + + sigScalar := new(btcec.ModNScalar) + sigScalar.SetByteSlice(sig.RawBytes()) + + var nonce [musig2.PubNonceSize]byte + copy(nonce[:], commitSigBytes) + + sigWithNonce := NewPartialSigWithNonce(nonce, *sigScalar) + partialSig := MaybePartialSigWithNonce(sigWithNonce) + + // Define custom records. + recordKey1 := uint64(MinCustomRecordsTlvType + 1) + recordValue1, err := generateRandomBytes(10) + require.NoError(t, err) + + recordKey2 := uint64(MinCustomRecordsTlvType + 2) + recordValue2, err := generateRandomBytes(10) + require.NoError(t, err) + + customRecords := CustomRecords{ + recordKey1: recordValue1, + recordKey2: recordValue2, + } + + // Construct an instance of extra data that contains records with TLV + // types below the minimum custom records threshold and that lack + // corresponding fields in the message struct. Content should persist in + // the extra data field after encoding and decoding. + var ( + recordBytes45 = []byte("recordBytes45") + tlvRecord45 = tlv.NewPrimitiveRecord[tlv.TlvType45]( + recordBytes45, + ) + + recordBytes55 = []byte("recordBytes55") + tlvRecord55 = tlv.NewPrimitiveRecord[tlv.TlvType55]( + recordBytes55, + ) + ) + + var extraData ExtraOpaqueData + err = extraData.PackRecords( + []tlv.RecordProducer{&tlvRecord45, &tlvRecord55}..., + ) + require.NoError(t, err) + + invalidCustomRecords := CustomRecords{ + MinCustomRecordsTlvType - 1: recordValue1, + } + + return []commitSigTestCase{ + { + Msg: CommitSig{ + ChanID: chanID, + CommitSig: sig, + PartialSig: partialSig, + CustomRecords: customRecords, + ExtraData: extraData, + }, + }, + // Add a test case where the blinding point field is not + // populated. + { + Msg: CommitSig{ + ChanID: chanID, + CommitSig: sig, + CustomRecords: customRecords, + }, + }, + // Add a test case where the custom records field is not + // populated. + { + Msg: CommitSig{ + ChanID: chanID, + CommitSig: sig, + PartialSig: partialSig, + }, + }, + // Add a case where the custom records are invalid. + { + Msg: CommitSig{ + ChanID: chanID, + CommitSig: sig, + PartialSig: partialSig, + CustomRecords: invalidCustomRecords, + }, + ExpectEncodeError: true, + }, + } +} + +// TestCommitSigEncodeDecode tests CommitSig message encoding and decoding for +// all supported field values. +func TestCommitSigEncodeDecode(t *testing.T) { + t.Parallel() + + // Generate test cases. + testCases := generateCommitSigTestCases(t) + + // Execute test cases. + for tcIdx, tc := range testCases { + t.Run(fmt.Sprintf("testcase-%d", tcIdx), func(t *testing.T) { + // Encode test case message. + var buf bytes.Buffer + err := tc.Msg.Encode(&buf, 0) + + // Check if we expect an encoding error. + if tc.ExpectEncodeError { + require.Error(t, err) + return + } + + require.NoError(t, err) + + // Decode the encoded message bytes message. + var actualMsg CommitSig + decodeReader := bytes.NewReader(buf.Bytes()) + err = actualMsg.Decode(decodeReader, 0) + require.NoError(t, err) + + // The signature type isn't serialized. + actualMsg.CommitSig.ForceSchnorr() + + // Compare the two messages to ensure equality. + require.Equal(t, tc.Msg, actualMsg) + }) + } +} diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 7eb434f45..e941962d5 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -945,6 +945,8 @@ func TestLightningWireProtocol(t *testing.T) { } } + req.CustomRecords = randCustomRecords(t, r) + // 50/50 chance to attach a partial sig. if r.Int31()%2 == 0 { req.PartialSig = somePartialSigWithNonce(t, r)