lnwire+netann: update ChannelAnnouncement2 structure

Such that all fields are now TLV (including the signature).
This commit is contained in:
Elle Mouton
2024-10-11 16:26:47 +02:00
parent 5e7ca548aa
commit 1f71f14587
6 changed files with 292 additions and 109 deletions

View File

@@ -12,9 +12,6 @@ import (
// ChannelAnnouncement2 message is used to announce the existence of a taproot // ChannelAnnouncement2 message is used to announce the existence of a taproot
// channel between two peers in the network. // channel between two peers in the network.
type ChannelAnnouncement2 struct { type ChannelAnnouncement2 struct {
// Signature is a Schnorr signature over the TLV stream of the message.
Signature Sig
// ChainHash denotes the target chain that this channel was opened // ChainHash denotes the target chain that this channel was opened
// within. This value should be the genesis hash of the target chain. // within. This value should be the genesis hash of the target chain.
ChainHash tlv.RecordT[tlv.TlvType0, chainhash.Hash] ChainHash tlv.RecordT[tlv.TlvType0, chainhash.Hash]
@@ -59,74 +56,14 @@ type ChannelAnnouncement2 struct {
// the funding output is a pure 2-of-2 MuSig aggregate public key. // the funding output is a pure 2-of-2 MuSig aggregate public key.
MerkleRootHash tlv.OptionalRecordT[tlv.TlvType16, [32]byte] MerkleRootHash tlv.OptionalRecordT[tlv.TlvType16, [32]byte]
// ExtraOpaqueData is the set of data that was appended to this // Signature is a Schnorr signature over serialised signed-range TLV
// message, some of which we may not actually know how to iterate or // stream of the message.
// parse. By holding onto this data, we ensure that we're able to Signature tlv.RecordT[tlv.TlvType160, Sig]
// properly validate the set of signatures that cover these new fields,
// and ensure we're able to make upgrades to the network in a forwards
// compatible manner.
ExtraOpaqueData ExtraOpaqueData
}
// Decode deserializes a serialized AnnounceSignatures1 stored in the passed // Any extra fields in the signed range that we do not yet know about,
// io.Reader observing the specified protocol version. // but we need to keep them for signature validation and to produce a
// // valid message.
// This is part of the lnwire.Message interface. ExtraSignedFields
func (c *ChannelAnnouncement2) Decode(r io.Reader, _ uint32) error {
err := ReadElement(r, &c.Signature)
if err != nil {
return err
}
c.Signature.ForceSchnorr()
return c.DecodeTLVRecords(r)
}
// DecodeTLVRecords decodes only the TLV section of the message.
func (c *ChannelAnnouncement2) DecodeTLVRecords(r io.Reader) error {
// First extract into extra opaque data.
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
var (
chainHash = tlv.ZeroRecordT[tlv.TlvType0, [32]byte]()
btcKey1 = tlv.ZeroRecordT[tlv.TlvType12, [33]byte]()
btcKey2 = tlv.ZeroRecordT[tlv.TlvType14, [33]byte]()
merkleRootHash = tlv.ZeroRecordT[tlv.TlvType16, [32]byte]()
)
typeMap, err := tlvRecords.ExtractRecords(
&chainHash, &c.Features, &c.ShortChannelID, &c.Capacity,
&c.NodeID1, &c.NodeID2, &btcKey1, &btcKey2, &merkleRootHash,
)
if err != nil {
return err
}
// By default, the chain-hash is the bitcoin mainnet genesis block hash.
c.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash
if _, ok := typeMap[c.ChainHash.TlvType()]; ok {
c.ChainHash.Val = chainHash.Val
}
if _, ok := typeMap[c.BitcoinKey1.TlvType()]; ok {
c.BitcoinKey1 = tlv.SomeRecordT(btcKey1)
}
if _, ok := typeMap[c.BitcoinKey2.TlvType()]; ok {
c.BitcoinKey2 = tlv.SomeRecordT(btcKey2)
}
if _, ok := typeMap[c.MerkleRootHash.TlvType()]; ok {
c.MerkleRootHash = tlv.SomeRecordT(merkleRootHash)
}
if len(tlvRecords) != 0 {
c.ExtraOpaqueData = tlvRecords
}
return c.ExtraOpaqueData.ValidateTLV()
} }
// Encode serializes the target AnnounceSignatures1 into the passed io.Writer // Encode serializes the target AnnounceSignatures1 into the passed io.Writer
@@ -134,21 +71,27 @@ func (c *ChannelAnnouncement2) DecodeTLVRecords(r io.Reader) error {
// //
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (c *ChannelAnnouncement2) Encode(w *bytes.Buffer, _ uint32) error { func (c *ChannelAnnouncement2) Encode(w *bytes.Buffer, _ uint32) error {
_, err := w.Write(c.Signature.RawBytes()) return EncodePureTLVMessage(c, w)
if err != nil {
return err
}
_, err = c.DataToSign()
if err != nil {
return err
}
return WriteBytes(w, c.ExtraOpaqueData)
} }
// DataToSign encodes the data to be signed into the ExtraOpaqueData member and // AllRecords returns all the TLV records for the message. This will include all
// returns it. // the records we know about along with any that we don't know about but that
func (c *ChannelAnnouncement2) DataToSign() ([]byte, error) { // fall in the signed TLV range.
//
// NOTE: this is part of the PureTLVMessage interface.
func (c *ChannelAnnouncement2) AllRecords() []tlv.Record {
recordProducers := append(
c.allNonSignatureRecordProducers(), &c.Signature,
)
return ProduceRecordsSorted(recordProducers...)
}
// allNonSignatureRecordProducers returns all the TLV record producers for the
// message except the signature record producer.
//
//nolint:ll
func (c *ChannelAnnouncement2) allNonSignatureRecordProducers() []tlv.RecordProducer {
// The chain-hash record is only included if it is _not_ equal to the // The chain-hash record is only included if it is _not_ equal to the
// bitcoin mainnet genisis block hash. // bitcoin mainnet genisis block hash.
var recordProducers []tlv.RecordProducer var recordProducers []tlv.RecordProducer
@@ -178,12 +121,126 @@ func (c *ChannelAnnouncement2) DataToSign() ([]byte, error) {
}, },
) )
err := EncodeMessageExtraData(&c.ExtraOpaqueData, recordProducers...) recordProducers = append(recordProducers, RecordsAsProducers(
tlv.MapToRecords(c.ExtraSignedFields),
)...)
return recordProducers
}
// Decode deserializes a serialized AnnounceSignatures1 stored in the passed
// io.Reader observing the specified protocol version.
//
// This is part of the lnwire.Message interface.
func (c *ChannelAnnouncement2) Decode(r io.Reader, _ uint32) error {
var (
chainHash = tlv.ZeroRecordT[tlv.TlvType0, [32]byte]()
btcKey1 = tlv.ZeroRecordT[tlv.TlvType12, [33]byte]()
btcKey2 = tlv.ZeroRecordT[tlv.TlvType14, [33]byte]()
merkleRootHash = tlv.ZeroRecordT[tlv.TlvType16, [32]byte]()
)
stream, err := tlv.NewStream(ProduceRecordsSorted(
&chainHash,
&c.Features,
&c.ShortChannelID,
&c.Capacity,
&c.NodeID1,
&c.NodeID2,
&btcKey1,
&btcKey2,
&merkleRootHash,
&c.Signature,
)...)
if err != nil { if err != nil {
return nil, err return err
}
c.Signature.Val.ForceSchnorr()
typeMap, err := stream.DecodeWithParsedTypesP2P(r)
if err != nil {
return err
} }
return c.ExtraOpaqueData, nil // By default, the chain-hash is the bitcoin mainnet genesis block hash.
c.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash
if _, ok := typeMap[c.ChainHash.TlvType()]; ok {
c.ChainHash.Val = chainHash.Val
}
if _, ok := typeMap[c.BitcoinKey1.TlvType()]; ok {
c.BitcoinKey1 = tlv.SomeRecordT(btcKey1)
}
if _, ok := typeMap[c.BitcoinKey2.TlvType()]; ok {
c.BitcoinKey2 = tlv.SomeRecordT(btcKey2)
}
if _, ok := typeMap[c.MerkleRootHash.TlvType()]; ok {
c.MerkleRootHash = tlv.SomeRecordT(merkleRootHash)
}
c.ExtraSignedFields = ExtraSignedFieldsFromTypeMap(typeMap)
return nil
}
// DecodeNonSigTLVRecords decodes only the TLV section of the message.
func (c *ChannelAnnouncement2) DecodeNonSigTLVRecords(r io.Reader) error {
var (
chainHash = tlv.ZeroRecordT[tlv.TlvType0, [32]byte]()
btcKey1 = tlv.ZeroRecordT[tlv.TlvType12, [33]byte]()
btcKey2 = tlv.ZeroRecordT[tlv.TlvType14, [33]byte]()
merkleRootHash = tlv.ZeroRecordT[tlv.TlvType16, [32]byte]()
)
stream, err := tlv.NewStream(ProduceRecordsSorted(
&chainHash,
&c.Features,
&c.ShortChannelID,
&c.Capacity,
&c.NodeID1,
&c.NodeID2,
&btcKey1,
&btcKey2,
&merkleRootHash,
)...)
if err != nil {
return err
}
typeMap, err := stream.DecodeWithParsedTypesP2P(r)
if err != nil {
return err
}
// By default, the chain-hash is the bitcoin mainnet genesis block hash.
c.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash
if _, ok := typeMap[c.ChainHash.TlvType()]; ok {
c.ChainHash.Val = chainHash.Val
}
if _, ok := typeMap[c.BitcoinKey1.TlvType()]; ok {
c.BitcoinKey1 = tlv.SomeRecordT(btcKey1)
}
if _, ok := typeMap[c.BitcoinKey2.TlvType()]; ok {
c.BitcoinKey2 = tlv.SomeRecordT(btcKey2)
}
if _, ok := typeMap[c.MerkleRootHash.TlvType()]; ok {
c.MerkleRootHash = tlv.SomeRecordT(merkleRootHash)
}
c.ExtraSignedFields = ExtraSignedFieldsFromTypeMap(typeMap)
return nil
}
// EncodeAllNonSigFields encodes the entire message to the given writer but
// excludes the signature field.
func (c *ChannelAnnouncement2) EncodeAllNonSigFields(w io.Writer) error {
return EncodeRecordsTo(
w, ProduceRecordsSorted(c.allNonSignatureRecordProducers()...),
)
} }
// MsgType returns the integer uniquely identifying this message type on the // MsgType returns the integer uniquely identifying this message type on the
@@ -209,6 +266,10 @@ var _ Message = (*ChannelAnnouncement2)(nil)
// lnwire.SizeableMessage interface. // lnwire.SizeableMessage interface.
var _ SizeableMessage = (*ChannelAnnouncement2)(nil) var _ SizeableMessage = (*ChannelAnnouncement2)(nil)
// A compile time check to ensure ChannelAnnouncement2 implements the
// lnwire.PureTLVMessage interface.
var _ PureTLVMessage = (*ChannelAnnouncement2)(nil)
// Node1KeyBytes returns the bytes representing the public key of node 1 in the // Node1KeyBytes returns the bytes representing the public key of node 1 in the
// channel. // channel.
// //

View File

@@ -0,0 +1,103 @@
package lnwire
import (
"bytes"
"testing"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require"
)
// TestChanAnn2EncodeDecode tests the encoding and decoding of the
// ChannelAnnouncement2 message using hardcoded byte slices.
func TestChanAnn2EncodeDecode(t *testing.T) {
t.Parallel()
// We'll create a raw byte stream that represents a valid
// ChannelAnnouncement2 message with various known and unknown fields in
// the signed TLV ranges along with the signature in the unsigned range.
rawBytes := []byte{
// ChainHash record (optional, not mainnet).
0x00, // type.
0x20, // length.
0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1,
0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1,
0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1,
// Features record.
0x02, // type.
0x02, // length.
0x1, 0x2, // value.
// ShortChannelID record.
0x04, // type.
0x08, // length.
0x0, 0x0, 0x1, 0x0, 0x0, 0x2, 0x0, 0x3, // value.
// Unknown TLV record.
0x05, // type.
0x02, // length.
0xab, 0xcd, // value.
// Capacity record.
0x06, // type.
0x08, // length.
0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x86, 0xa0, // value: 100000.
// NodeID1 record.
0x08, // type.
0x21, // length.
0x2, 0x28, 0xf2, 0xaf, 0xa, 0xbe, 0x32, 0x24, 0x3, 0x48, 0xf,
0xb3, 0xee, 0x17, 0x2f, 0x7f, 0x16, 0x1, 0xe6, 0x7d, 0x1d, 0xa6,
0xca, 0xd4, 0xb, 0x54, 0xc4, 0x46, 0x8d, 0x48, 0x23, 0x6c, 0x39,
// NodeID2 record.
0x0a, // type.
0x21, // length.
0x3, 0x28, 0xf2, 0xaf, 0xa, 0xbe, 0x32, 0x24, 0x3, 0x48, 0xf,
0xb3, 0xee, 0x17, 0x2f, 0x7f, 0x16, 0x1, 0xe6, 0x7d, 0x1d, 0xa6,
0xca, 0xd4, 0xb, 0x54, 0xc4, 0x46, 0x8d, 0x48, 0x23, 0x6c, 0x39,
// Unknown TLV record.
0x6f, // type.
0x2, // length.
0x79, 0x79, // value.
// Signature.
0xa0, // type.
0x40, // length.
0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb,
0xc, 0xd, 0xe, 0xf, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16,
0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a,
0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34,
0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e,
0x3f, // value.
}
secondSignedRangeType := new(bytes.Buffer)
var buf [8]byte
err := tlv.WriteVarInt(
secondSignedRangeType, pureTLVSignedSecondRangeStart+1, &buf,
)
require.NoError(t, err)
rawBytes = append(rawBytes, secondSignedRangeType.Bytes()...) // type.
rawBytes = append(rawBytes, []byte{
0x02, // length.
0x79, 0x79, // value.
}...)
// Now, create a new empty message and decode the raw bytes into it.
msg := &ChannelAnnouncement2{}
r := bytes.NewReader(rawBytes)
err = msg.Decode(r, 0)
require.NoError(t, err)
// Next, encode the message back into a new byte buffer.
var b bytes.Buffer
err = msg.Encode(&b, 0)
require.NoError(t, err)
// The re-encoded bytes should be exactly the same as the original raw
// bytes.
require.Equal(t, rawBytes, b.Bytes())
}

View File

@@ -213,7 +213,6 @@ func (c *ChannelAnnouncement2) RandTestMessage(t *rapid.T) Message {
copy(chainHashObj[:], chainHash[:]) copy(chainHashObj[:], chainHash[:])
msg := &ChannelAnnouncement2{ msg := &ChannelAnnouncement2{
Signature: RandSignature(t),
ChainHash: tlv.NewPrimitiveRecord[tlv.TlvType0, chainhash.Hash]( ChainHash: tlv.NewPrimitiveRecord[tlv.TlvType0, chainhash.Hash](
chainHashObj, chainHashObj,
), ),
@@ -232,10 +231,16 @@ func (c *ChannelAnnouncement2) RandTestMessage(t *rapid.T) Message {
NodeID2: tlv.NewPrimitiveRecord[tlv.TlvType10, [33]byte]( NodeID2: tlv.NewPrimitiveRecord[tlv.TlvType10, [33]byte](
nodeID2, nodeID2,
), ),
ExtraOpaqueData: RandExtraOpaqueData(t, nil), ExtraSignedFields: make(map[uint64][]byte),
} }
msg.Signature.ForceSchnorr() msg.Signature.Val = RandSignature(t)
msg.Signature.Val.ForceSchnorr()
randRecs, _ := RandSignedRangeRecords(t)
if len(randRecs) > 0 {
msg.ExtraSignedFields = ExtraSignedFields(randRecs)
}
// Randomly include optional fields // Randomly include optional fields
if rapid.Bool().Draw(t, "includeBitcoinKey1") { if rapid.Bool().Draw(t, "includeBitcoinKey1") {
@@ -411,7 +416,7 @@ func (a *ChannelUpdate1) RandTestMessage(t *rapid.T) Message {
// include an inbound fee, then we will also set the record in the // include an inbound fee, then we will also set the record in the
// extra opaque data. // extra opaque data.
var ( var (
customRecords, _ = RandCustomRecords(t, nil, false) customRecords, _ = RandCustomRecords(t, nil)
inboundFee tlv.OptionalRecordT[tlv.TlvType55555, Fee] inboundFee tlv.OptionalRecordT[tlv.TlvType55555, Fee]
) )
includeInboundFee := rapid.Bool().Draw(t, "includeInboundFee") includeInboundFee := rapid.Bool().Draw(t, "includeInboundFee")
@@ -728,7 +733,7 @@ var _ TestMessage = (*CommitSig)(nil)
// //
// This is part of the TestMessage interface. // This is part of the TestMessage interface.
func (c *CommitSig) RandTestMessage(t *rapid.T) Message { func (c *CommitSig) RandTestMessage(t *rapid.T) Message {
cr, _ := RandCustomRecords(t, nil, true) cr, _ := RandCustomRecords(t, nil)
sig := &CommitSig{ sig := &CommitSig{
ChanID: RandChannelID(t), ChanID: RandChannelID(t),
CommitSig: RandSignature(t), CommitSig: RandSignature(t),
@@ -1606,7 +1611,7 @@ func (s *Shutdown) RandTestMessage(t *rapid.T) Message {
shutdownNonce = SomeShutdownNonce(RandMusig2Nonce(t)) shutdownNonce = SomeShutdownNonce(RandMusig2Nonce(t))
} }
cr, _ := RandCustomRecords(t, nil, true) cr, _ := RandCustomRecords(t, nil)
return &Shutdown{ return &Shutdown{
ChannelID: RandChannelID(t), ChannelID: RandChannelID(t),
@@ -1663,7 +1668,7 @@ func (c *UpdateAddHTLC) RandTestMessage(t *rapid.T) Message {
numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords") numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords")
if numRecords > 0 { if numRecords > 0 {
msg.CustomRecords, _ = RandCustomRecords(t, nil, true) msg.CustomRecords, _ = RandCustomRecords(t, nil)
} }
// 50/50 chance to add a blinding point // 50/50 chance to add a blinding point
@@ -1744,7 +1749,7 @@ func (c *UpdateFulfillHTLC) RandTestMessage(t *rapid.T) Message {
PaymentPreimage: RandPaymentPreimage(t), PaymentPreimage: RandPaymentPreimage(t),
} }
cr, ignoreRecords := RandCustomRecords(t, nil, true) cr, ignoreRecords := RandCustomRecords(t, nil)
msg.CustomRecords = cr msg.CustomRecords = cr
randData := RandExtraOpaqueData(t, ignoreRecords) randData := RandExtraOpaqueData(t, ignoreRecords)

View File

@@ -198,23 +198,37 @@ func RandNetAddrs(t *rapid.T) []net.Addr {
} }
// RandCustomRecords generates random custom TLV records. // RandCustomRecords generates random custom TLV records.
func RandCustomRecords(t *rapid.T, func RandCustomRecords(t *rapid.T, ignoreRecords fn.Set[uint64]) (CustomRecords,
ignoreRecords fn.Set[uint64], fn.Set[uint64]) {
custom bool) (CustomRecords, fn.Set[uint64]) {
numRecords := rapid.IntRange(0, 5).Draw(t, "numCustomRecords") customRecords, set := RandTLVRecords(
t, ignoreRecords, MinCustomRecordsTlvType,
)
// Validate the custom records as a sanity check.
require.NoError(t, customRecords.Validate())
return customRecords, set
}
// RandSignedRangeRecords generates a random set of signed records in the
// second "signed" tlv range for pure TLV messages.
func RandSignedRangeRecords(t *rapid.T) (CustomRecords, fn.Set[uint64]) {
return RandTLVRecords(t, nil, pureTLVSignedSecondRangeStart)
}
// RandTLVRecords generates custom TLV records.
func RandTLVRecords(t *rapid.T, ignoreRecords fn.Set[uint64],
rangeStart int) (CustomRecords, fn.Set[uint64]) {
numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords")
customRecords := make(CustomRecords) customRecords := make(CustomRecords)
if numRecords == 0 { if numRecords == 0 {
return nil, nil return nil, nil
} }
rangeStart := 0 rangeStop := rangeStart + 30_000
rangeStop := int(CustomTypeStart)
if custom {
rangeStart = 70_000
rangeStop = 100_000
}
ignoreSet := fn.NewSet[uint64]() ignoreSet := fn.NewSet[uint64]()
for i := 0; i < numRecords; i++ { for i := 0; i < numRecords; i++ {
@@ -258,7 +272,7 @@ func RandExtraOpaqueData(t *rapid.T,
ignoreRecords fn.Set[uint64]) ExtraOpaqueData { ignoreRecords fn.Set[uint64]) ExtraOpaqueData {
// Make some random records. // Make some random records.
cRecords, _ := RandCustomRecords(t, ignoreRecords, false) cRecords, _ := RandTLVRecords(t, ignoreRecords, 0)
if cRecords == nil { if cRecords == nil {
return ExtraOpaqueData{} return ExtraOpaqueData{}
} }

View File

@@ -203,7 +203,7 @@ func validateChannelAnn2(a *lnwire.ChannelAnnouncement2,
return err return err
} }
sig, err := a.Signature.ToSignature() sig, err := a.Signature.Val.ToSignature()
if err != nil { if err != nil {
return err return err
} }
@@ -278,7 +278,7 @@ func validateChannelAnn2(a *lnwire.ChannelAnnouncement2,
func ChanAnn2DigestToSign(a *lnwire.ChannelAnnouncement2) (*chainhash.Hash, func ChanAnn2DigestToSign(a *lnwire.ChannelAnnouncement2) (*chainhash.Hash,
error) { error) {
data, err := a.DataToSign() data, err := lnwire.SerialiseFieldsToSign(a)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -159,7 +159,7 @@ func test4of4MuSig2ChanAnnouncement(t *testing.T) {
sig, err := lnwire.NewSigFromSignature(s) sig, err := lnwire.NewSigFromSignature(s)
require.NoError(t, err) require.NoError(t, err)
ann.Signature = sig ann.Signature.Val = sig
// Validate the announcement. // Validate the announcement.
require.NoError(t, ValidateChannelAnn(ann, nil)) require.NoError(t, ValidateChannelAnn(ann, nil))
@@ -259,7 +259,7 @@ func test3of3MuSig2ChanAnnouncement(t *testing.T) {
sig, err := lnwire.NewSigFromSignature(s) sig, err := lnwire.NewSigFromSignature(s)
require.NoError(t, err) require.NoError(t, err)
ann.Signature = sig ann.Signature.Val = sig
// Validate the announcement. // Validate the announcement.
require.NoError(t, ValidateChannelAnn(ann, fetchTx)) require.NoError(t, ValidateChannelAnn(ann, fetchTx))