mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-05-29 17:19:33 +02:00
lnwire: add custom records field to type CommitSig
This commit is contained in:
parent
1e85c5054e
commit
bd84fd256e
@ -45,6 +45,10 @@ type CommitSig struct {
|
||||
// being signed for. In this case, the above Sig type MUST be blank.
|
||||
PartialSig OptPartialSigWithNonceTLV
|
||||
|
||||
// CustomRecords maps TLV types to byte slices, storing arbitrary data
|
||||
// intended for inclusion in the ExtraData field.
|
||||
CustomRecords CustomRecords
|
||||
|
||||
// ExtraData is the set of data that was appended to this message to
|
||||
// fill out the full maximum transport message size. These fields can
|
||||
// be used to specify optional data such as custom TLV fields.
|
||||
@ -53,9 +57,7 @@ type CommitSig struct {
|
||||
|
||||
// NewCommitSig creates a new empty CommitSig message.
|
||||
func NewCommitSig() *CommitSig {
|
||||
return &CommitSig{
|
||||
ExtraData: make([]byte, 0),
|
||||
}
|
||||
return &CommitSig{}
|
||||
}
|
||||
|
||||
// A compile time check to ensure CommitSig implements the lnwire.Message
|
||||
@ -67,34 +69,37 @@ var _ Message = (*CommitSig)(nil)
|
||||
//
|
||||
// This is part of the lnwire.Message interface.
|
||||
func (c *CommitSig) Decode(r io.Reader, pver uint32) error {
|
||||
// msgExtraData is a temporary variable used to read the message extra
|
||||
// data field from the reader.
|
||||
var msgExtraData ExtraOpaqueData
|
||||
|
||||
err := ReadElements(r,
|
||||
&c.ChanID,
|
||||
&c.CommitSig,
|
||||
&c.HtlcSigs,
|
||||
&msgExtraData,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var tlvRecords ExtraOpaqueData
|
||||
if err := ReadElements(r, &tlvRecords); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract TLV records from the extra data field.
|
||||
partialSig := c.PartialSig.Zero()
|
||||
typeMap, err := tlvRecords.ExtractRecords(&partialSig)
|
||||
|
||||
customRecords, parsed, extraData, err := ParseAndExtractCustomRecords(
|
||||
msgExtraData, &partialSig,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the corresponding TLV types if they were included in the stream.
|
||||
if val, ok := typeMap[c.PartialSig.TlvType()]; ok && val == nil {
|
||||
if _, ok := parsed[partialSig.TlvType()]; ok {
|
||||
c.PartialSig = tlv.SomeRecordT(partialSig)
|
||||
}
|
||||
|
||||
if len(tlvRecords) != 0 {
|
||||
c.ExtraData = tlvRecords
|
||||
}
|
||||
c.CustomRecords = customRecords
|
||||
c.ExtraData = extraData
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -108,7 +113,10 @@ func (c *CommitSig) Encode(w *bytes.Buffer, pver uint32) error {
|
||||
c.PartialSig.WhenSome(func(sig PartialSigWithNonceTLV) {
|
||||
recordProducers = append(recordProducers, &sig)
|
||||
})
|
||||
err := EncodeMessageExtraData(&c.ExtraData, recordProducers...)
|
||||
|
||||
extraData, err := MergeAndEncode(
|
||||
recordProducers, c.ExtraData, c.CustomRecords,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -125,7 +133,7 @@ func (c *CommitSig) Encode(w *bytes.Buffer, pver uint32) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return WriteBytes(w, c.ExtraData)
|
||||
return WriteBytes(w, extraData)
|
||||
}
|
||||
|
||||
// MsgType returns the integer uniquely identifying this message type on the
|
||||
|
168
lnwire/commit_sig_test.go
Normal file
168
lnwire/commit_sig_test.go
Normal file
@ -0,0 +1,168 @@
|
||||
package lnwire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec/v2"
|
||||
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testCase is a test case for the CommitSig message.
|
||||
type commitSigTestCase struct {
|
||||
// Msg is the message to be encoded and decoded.
|
||||
Msg CommitSig
|
||||
|
||||
// ExpectEncodeError is a flag that indicates whether we expect the
|
||||
// encoding of the message to fail.
|
||||
ExpectEncodeError bool
|
||||
}
|
||||
|
||||
// generateCommitSigTestCases generates a set of CommitSig message test cases.
|
||||
func generateCommitSigTestCases(t *testing.T) []commitSigTestCase {
|
||||
// Firstly, we'll set basic values for the message fields.
|
||||
//
|
||||
// Generate random channel ID.
|
||||
chanIDBytes, err := generateRandomBytes(32)
|
||||
require.NoError(t, err)
|
||||
|
||||
var chanID ChannelID
|
||||
copy(chanID[:], chanIDBytes)
|
||||
|
||||
// Generate random commit sig.
|
||||
commitSigBytes, err := generateRandomBytes(64)
|
||||
require.NoError(t, err)
|
||||
|
||||
sig, err := NewSigFromSchnorrRawSignature(commitSigBytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
sigScalar := new(btcec.ModNScalar)
|
||||
sigScalar.SetByteSlice(sig.RawBytes())
|
||||
|
||||
var nonce [musig2.PubNonceSize]byte
|
||||
copy(nonce[:], commitSigBytes)
|
||||
|
||||
sigWithNonce := NewPartialSigWithNonce(nonce, *sigScalar)
|
||||
partialSig := MaybePartialSigWithNonce(sigWithNonce)
|
||||
|
||||
// Define custom records.
|
||||
recordKey1 := uint64(MinCustomRecordsTlvType + 1)
|
||||
recordValue1, err := generateRandomBytes(10)
|
||||
require.NoError(t, err)
|
||||
|
||||
recordKey2 := uint64(MinCustomRecordsTlvType + 2)
|
||||
recordValue2, err := generateRandomBytes(10)
|
||||
require.NoError(t, err)
|
||||
|
||||
customRecords := CustomRecords{
|
||||
recordKey1: recordValue1,
|
||||
recordKey2: recordValue2,
|
||||
}
|
||||
|
||||
// Construct an instance of extra data that contains records with TLV
|
||||
// types below the minimum custom records threshold and that lack
|
||||
// corresponding fields in the message struct. Content should persist in
|
||||
// the extra data field after encoding and decoding.
|
||||
var (
|
||||
recordBytes45 = []byte("recordBytes45")
|
||||
tlvRecord45 = tlv.NewPrimitiveRecord[tlv.TlvType45](
|
||||
recordBytes45,
|
||||
)
|
||||
|
||||
recordBytes55 = []byte("recordBytes55")
|
||||
tlvRecord55 = tlv.NewPrimitiveRecord[tlv.TlvType55](
|
||||
recordBytes55,
|
||||
)
|
||||
)
|
||||
|
||||
var extraData ExtraOpaqueData
|
||||
err = extraData.PackRecords(
|
||||
[]tlv.RecordProducer{&tlvRecord45, &tlvRecord55}...,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
invalidCustomRecords := CustomRecords{
|
||||
MinCustomRecordsTlvType - 1: recordValue1,
|
||||
}
|
||||
|
||||
return []commitSigTestCase{
|
||||
{
|
||||
Msg: CommitSig{
|
||||
ChanID: chanID,
|
||||
CommitSig: sig,
|
||||
PartialSig: partialSig,
|
||||
CustomRecords: customRecords,
|
||||
ExtraData: extraData,
|
||||
},
|
||||
},
|
||||
// Add a test case where the blinding point field is not
|
||||
// populated.
|
||||
{
|
||||
Msg: CommitSig{
|
||||
ChanID: chanID,
|
||||
CommitSig: sig,
|
||||
CustomRecords: customRecords,
|
||||
},
|
||||
},
|
||||
// Add a test case where the custom records field is not
|
||||
// populated.
|
||||
{
|
||||
Msg: CommitSig{
|
||||
ChanID: chanID,
|
||||
CommitSig: sig,
|
||||
PartialSig: partialSig,
|
||||
},
|
||||
},
|
||||
// Add a case where the custom records are invalid.
|
||||
{
|
||||
Msg: CommitSig{
|
||||
ChanID: chanID,
|
||||
CommitSig: sig,
|
||||
PartialSig: partialSig,
|
||||
CustomRecords: invalidCustomRecords,
|
||||
},
|
||||
ExpectEncodeError: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TestCommitSigEncodeDecode tests CommitSig message encoding and decoding for
|
||||
// all supported field values.
|
||||
func TestCommitSigEncodeDecode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Generate test cases.
|
||||
testCases := generateCommitSigTestCases(t)
|
||||
|
||||
// Execute test cases.
|
||||
for tcIdx, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("testcase-%d", tcIdx), func(t *testing.T) {
|
||||
// Encode test case message.
|
||||
var buf bytes.Buffer
|
||||
err := tc.Msg.Encode(&buf, 0)
|
||||
|
||||
// Check if we expect an encoding error.
|
||||
if tc.ExpectEncodeError {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decode the encoded message bytes message.
|
||||
var actualMsg CommitSig
|
||||
decodeReader := bytes.NewReader(buf.Bytes())
|
||||
err = actualMsg.Decode(decodeReader, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The signature type isn't serialized.
|
||||
actualMsg.CommitSig.ForceSchnorr()
|
||||
|
||||
// Compare the two messages to ensure equality.
|
||||
require.Equal(t, tc.Msg, actualMsg)
|
||||
})
|
||||
}
|
||||
}
|
@ -945,6 +945,8 @@ func TestLightningWireProtocol(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
req.CustomRecords = randCustomRecords(t, r)
|
||||
|
||||
// 50/50 chance to attach a partial sig.
|
||||
if r.Int31()%2 == 0 {
|
||||
req.PartialSig = somePartialSigWithNonce(t, r)
|
||||
|
Loading…
x
Reference in New Issue
Block a user