diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 0e065dc3d..bebc19760 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -13,7 +13,9 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightninglabs/neutrino/cache" "github.com/lightninglabs/neutrino/cache/lru" @@ -192,14 +194,9 @@ type PinnedSyncers map[route.Vertex]struct{} // Config defines the configuration for the service. ALL elements within the // configuration MUST be non-nil for the service to carry out its duties. type Config struct { - // ChainHash is a hash that indicates which resident chain of the - // AuthenticatedGossiper. Any announcements that don't match this - // chain hash will be ignored. - // - // TODO(roasbeef): eventually make into map so can de-multiplex - // incoming announcements - // * also need to do same for Notifier - ChainHash chainhash.Hash + // ChainParams holds the chain parameters for the active network this + // node is participating on. + ChainParams *chaincfg.Params // Graph is the subsystem which is responsible for managing the // topology of lightning network. After incoming channel, node, channel @@ -599,7 +596,7 @@ func New(cfg Config, selfKeyDesc *keychain.KeyDescriptor) *AuthenticatedGossiper gossiper.vb = NewValidationBarrier(1000, gossiper.quit) gossiper.syncMgr = newSyncManager(&SyncManagerCfg{ - ChainHash: cfg.ChainHash, + ChainHash: *cfg.ChainParams.GenesisHash, ChanSeries: cfg.ChanSeries, RotateTicker: cfg.RotateTicker, HistoricalSyncTicker: cfg.HistoricalSyncTicker, @@ -2037,10 +2034,29 @@ func (d *AuthenticatedGossiper) processRejectedEdge(_ context.Context, } // fetchPKScript fetches the output script for the given SCID. -func (d *AuthenticatedGossiper) fetchPKScript(chanID *lnwire.ShortChannelID) ( - []byte, error) { +func (d *AuthenticatedGossiper) fetchPKScript(chanID lnwire.ShortChannelID) ( + txscript.ScriptClass, btcutil.Address, error) { - return lnwallet.FetchPKScriptWithQuit(d.cfg.ChainIO, chanID, d.quit) + pkScript, err := lnwallet.FetchPKScriptWithQuit( + d.cfg.ChainIO, chanID, d.quit, + ) + if err != nil { + return txscript.WitnessUnknownTy, nil, err + } + + scriptClass, addrs, _, err := txscript.ExtractPkScriptAddrs( + pkScript, d.cfg.ChainParams, + ) + if err != nil { + return txscript.WitnessUnknownTy, nil, err + } + + if len(addrs) != 1 { + return txscript.WitnessUnknownTy, nil, fmt.Errorf("expected "+ + "1 address, got: %d", len(addrs)) + } + + return scriptClass, addrs[0], nil } // addNode processes the given node announcement, and adds it to our channel @@ -2541,16 +2557,16 @@ func (d *AuthenticatedGossiper) handleChanAnnouncement(ctx context.Context, ops ...batch.SchedulerOption) ([]networkMsg, bool) { scid := ann.ShortChannelID + chainHash := d.cfg.ChainParams.GenesisHash log.Debugf("Processing ChannelAnnouncement1: peer=%v, short_chan_id=%v", nMsg.peer, scid.ToUint64()) // We'll ignore any channel announcements that target any chain other // than the set of chains we know of. - if !bytes.Equal(ann.ChainHash[:], d.cfg.ChainHash[:]) { + if !bytes.Equal(ann.ChainHash[:], chainHash[:]) { err := fmt.Errorf("ignoring ChannelAnnouncement1 from chain=%v"+ - ", gossiper on chain=%v", ann.ChainHash, - d.cfg.ChainHash) + ", gossiper on chain=%v", ann.ChainHash, chainHash) log.Errorf(err.Error()) key := newRejectCacheKey( @@ -2971,11 +2987,13 @@ func (d *AuthenticatedGossiper) handleChanUpdate(ctx context.Context, log.Debugf("Processing ChannelUpdate: peer=%v, short_chan_id=%v, ", nMsg.peer, upd.ShortChannelID.ToUint64()) + chainHash := d.cfg.ChainParams.GenesisHash + // We'll ignore any channel updates that target any chain other than // the set of chains we know of. - if !bytes.Equal(upd.ChainHash[:], d.cfg.ChainHash[:]) { + if !bytes.Equal(upd.ChainHash[:], chainHash[:]) { err := fmt.Errorf("ignoring ChannelUpdate from chain=%v, "+ - "gossiper on chain=%v", upd.ChainHash, d.cfg.ChainHash) + "gossiper on chain=%v", upd.ChainHash, chainHash) log.Errorf(err.Error()) key := newRejectCacheKey( @@ -3748,7 +3766,7 @@ func (d *AuthenticatedGossiper) validateFundingTransaction(_ context.Context, // Before we can add the channel to the channel graph, we need to obtain // the full funding outpoint that's encoded within the channel ID. fundingTx, err := lnwallet.FetchFundingTxWrapper( - d.cfg.ChainIO, &scid, d.quit, + d.cfg.ChainIO, scid, d.quit, ) if err != nil { //nolint:ll diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index d0afa7d31..63369f4d6 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -18,6 +18,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/davecgh/go-spew/spew" @@ -620,6 +621,7 @@ func createUpdateAnnouncement(blockHeight uint32, htlcMinMsat := lnwire.MilliSatoshi(100) a := &lnwire.ChannelUpdate1{ + ChainHash: *chaincfg.MainNetParams.GenesisHash, ShortChannelID: lnwire.ShortChannelID{ BlockHeight: blockHeight, }, @@ -772,6 +774,7 @@ func (ctx *testCtx) createAnnouncementWithoutProof(blockHeight uint32, } a := &lnwire.ChannelAnnouncement1{ + ChainHash: *chaincfg.MainNetParams.GenesisHash, ShortChannelID: lnwire.ShortChannelID{ BlockHeight: blockHeight, TxIndex: 0, @@ -938,8 +941,9 @@ func createTestCtx(t *testing.T, startHeight uint32, isChanPeer bool) ( } gossiper := New(Config{ - ChainIO: chain, - Notifier: notifier, + ChainIO: chain, + ChainParams: &chaincfg.MainNetParams, + Notifier: notifier, Broadcast: func(senders map[route.Vertex]struct{}, msgs ...lnwire.Message) error { @@ -1669,6 +1673,7 @@ func TestSignatureAnnouncementRetryAtStartup(t *testing.T) { //nolint:ll gossiper := New(Config{ + ChainParams: &chaincfg.MainNetParams, Notifier: tCtx.gossiper.cfg.Notifier, Broadcast: tCtx.gossiper.cfg.Broadcast, NotifyWhenOnline: tCtx.gossiper.reliableSender.cfg.NotifyWhenOnline, diff --git a/lnwallet/interface.go b/lnwallet/interface.go index 9255c513e..f5a717d3f 100644 --- a/lnwallet/interface.go +++ b/lnwallet/interface.go @@ -760,7 +760,7 @@ func SupportedWallets() []string { // FetchFundingTxWrapper is a wrapper around FetchFundingTx, except that it will // exit when the supplied quit channel is closed. -func FetchFundingTxWrapper(chain BlockChainIO, chanID *lnwire.ShortChannelID, +func FetchFundingTxWrapper(chain BlockChainIO, chanID lnwire.ShortChannelID, quit chan struct{}) (*wire.MsgTx, error) { txChan := make(chan *wire.MsgTx, 1) @@ -795,7 +795,7 @@ func FetchFundingTxWrapper(chain BlockChainIO, chanID *lnwire.ShortChannelID, // TODO(roasbeef): replace with call to GetBlockTransaction? (would allow to // later use getblocktxn). func FetchFundingTx(chain BlockChainIO, - chanID *lnwire.ShortChannelID) (*wire.MsgTx, error) { + chanID lnwire.ShortChannelID) (*wire.MsgTx, error) { // First fetch the block hash by the block number encoded, then use // that hash to fetch the block itself. @@ -826,7 +826,7 @@ func FetchFundingTx(chain BlockChainIO, // FetchPKScriptWithQuit fetches the output script for the given SCID and exits // early with an error if the provided quit channel is closed before // completion. -func FetchPKScriptWithQuit(chain BlockChainIO, chanID *lnwire.ShortChannelID, +func FetchPKScriptWithQuit(chain BlockChainIO, chanID lnwire.ShortChannelID, quit chan struct{}) ([]byte, error) { tx, err := FetchFundingTxWrapper(chain, chanID, quit) @@ -835,7 +835,7 @@ func FetchPKScriptWithQuit(chain BlockChainIO, chanID *lnwire.ShortChannelID, } outputLocator := chanvalidate.ShortChanIDChanLocator{ - ID: *chanID, + ID: chanID, } output, _, err := outputLocator.Locate(tx) diff --git a/lnwire/announcement_signatures_2.go b/lnwire/announcement_signatures_2.go index 6e893dafd..04e4c0a9f 100644 --- a/lnwire/announcement_signatures_2.go +++ b/lnwire/announcement_signatures_2.go @@ -3,6 +3,8 @@ package lnwire import ( "bytes" "io" + + "github.com/lightningnetwork/lnd/tlv" ) // AnnounceSignatures2 is a direct message between two endpoints of a @@ -14,27 +16,40 @@ type AnnounceSignatures2 struct { // Channel id is better for users and debugging and short channel id is // used for quick test on existence of the particular utxo inside the // blockchain, because it contains information about block. - ChannelID ChannelID + ChannelID tlv.RecordT[tlv.TlvType0, ChannelID] // ShortChannelID is the unique description of the funding transaction. // It is constructed with the most significant 3 bytes as the block // height, the next 3 bytes indicating the transaction index within the // block, and the least significant two bytes indicating the output // index which pays to the channel. - ShortChannelID ShortChannelID + ShortChannelID tlv.RecordT[tlv.TlvType2, ShortChannelID] // PartialSignature is the combination of the partial Schnorr signature // created for the node's bitcoin key with the partial signature created // for the node's node ID key. - PartialSignature PartialSig + PartialSignature tlv.RecordT[tlv.TlvType4, PartialSig] - // ExtraOpaqueData is the set of data that was appended to this - // message, some of which we may not actually know how to iterate or - // parse. By holding onto this data, we ensure that we're able to - // properly validate the set of signatures that cover these new fields, - // and ensure we're able to make upgrades to the network in a forwards - // compatible manner. - ExtraOpaqueData ExtraOpaqueData + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraSignedFields +} + +// NewAnnSigs2 is a constructor for AnnounceSignatures2. +func NewAnnSigs2(chanID ChannelID, scid ShortChannelID, + partialSig PartialSig) *AnnounceSignatures2 { + + return &AnnounceSignatures2{ + ChannelID: tlv.NewRecordT[tlv.TlvType0, ChannelID](chanID), + ShortChannelID: tlv.NewRecordT[tlv.TlvType2, ShortChannelID]( + scid, + ), + PartialSignature: tlv.NewRecordT[tlv.TlvType4, PartialSig]( + partialSig, + ), + ExtraSignedFields: make(ExtraSignedFields), + } } // A compile time check to ensure AnnounceSignatures2 implements the @@ -45,17 +60,30 @@ var _ Message = (*AnnounceSignatures2)(nil) // lnwire.SizeableMessage interface. var _ SizeableMessage = (*AnnounceSignatures2)(nil) +// A compile time check to ensure ChannelAnnouncement2 implements the +// lnwire.PureTLVMessage interface. +var _ PureTLVMessage = (*AnnounceSignatures2)(nil) + // Decode deserializes a serialized AnnounceSignatures2 stored in the passed // io.Reader observing the specified protocol version. // // This is part of the lnwire.Message interface. func (a *AnnounceSignatures2) Decode(r io.Reader, _ uint32) error { - return ReadElements(r, - &a.ChannelID, - &a.ShortChannelID, - &a.PartialSignature, - &a.ExtraOpaqueData, - ) + stream, err := tlv.NewStream(ProduceRecordsSorted( + &a.ChannelID, &a.ShortChannelID, &a.PartialSignature, + )...) + if err != nil { + return err + } + + typeMap, err := stream.DecodeWithParsedTypesP2P(r) + if err != nil { + return err + } + + a.ExtraSignedFields = ExtraSignedFieldsFromTypeMap(typeMap) + + return nil } // Encode serializes the target AnnounceSignatures2 into the passed io.Writer @@ -63,19 +91,7 @@ func (a *AnnounceSignatures2) Decode(r io.Reader, _ uint32) error { // // This is part of the lnwire.Message interface. func (a *AnnounceSignatures2) Encode(w *bytes.Buffer, _ uint32) error { - if err := WriteChannelID(w, a.ChannelID); err != nil { - return err - } - - if err := WriteShortChannelID(w, a.ShortChannelID); err != nil { - return err - } - - if err := WriteElement(w, a.PartialSignature); err != nil { - return err - } - - return WriteBytes(w, a.ExtraOpaqueData) + return EncodePureTLVMessage(a, w) } // MsgType returns the integer uniquely identifying this message type on the @@ -93,16 +109,34 @@ func (a *AnnounceSignatures2) SerializedSize() (uint32, error) { return MessageSerializedSize(a) } +// AllRecords returns all the TLV records for the message. This will include all +// the records we know about along with any that we don't know about but that +// fall in the signed TLV range. +// +// NOTE: this is part of the PureTLVMessage interface. +func (a *AnnounceSignatures2) AllRecords() []tlv.Record { + recordProducers := []tlv.RecordProducer{ + &a.ChannelID, &a.ShortChannelID, + &a.PartialSignature, + } + + recordProducers = append(recordProducers, RecordsAsProducers( + tlv.MapToRecords(a.ExtraSignedFields), + )...) + + return ProduceRecordsSorted(recordProducers...) +} + // SCID returns the ShortChannelID of the channel. // // NOTE: this is part of the AnnounceSignatures interface. func (a *AnnounceSignatures2) SCID() ShortChannelID { - return a.ShortChannelID + return a.ShortChannelID.Val } // ChanID returns the ChannelID identifying the channel. // // NOTE: this is part of the AnnounceSignatures interface. func (a *AnnounceSignatures2) ChanID() ChannelID { - return a.ChannelID + return a.ChannelID.Val } diff --git a/lnwire/announcement_signatures_2_test.go b/lnwire/announcement_signatures_2_test.go new file mode 100644 index 000000000..6b945edcf --- /dev/null +++ b/lnwire/announcement_signatures_2_test.go @@ -0,0 +1,78 @@ +package lnwire + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// TestAnnSigs2EncodeDecode tests the encoding and decoding of the +// AnnounceSignatures2 message using hardcoded byte slices. +func TestAnnSigs2EncodeDecode(t *testing.T) { + t.Parallel() + + // We'll create a raw byte stream that represents a valid + // AnnounceSignatures2 message with various known and unknown fields in + // the signed TLV ranges. + var rawBytes []byte + + // ChannelID. + rawBytes = append(rawBytes, []byte{ + 0x00, // type + 0x20, // length + }...) + rawBytes = append(rawBytes, make([]byte, 32)...) // value + + // ShortChannelID. + rawBytes = append(rawBytes, []byte{ + 0x02, // type + 0x08, // length + 0, 0, 1, 0, 0, 2, 0, 3, // value + }...) + + // PartialSignature. + rawBytes = append(rawBytes, []byte{ + 0x04, // type + 0x20, // length + }...) + rawBytes = append(rawBytes, make([]byte, 32)...) // value + + // Extra field in the first signed range. + rawBytes = append(rawBytes, []byte{ + 0x30, // type + 0x02, // length + 0xab, 0xcd, // value + }...) + + w := new(bytes.Buffer) + var buf [8]byte + err := tlv.WriteVarInt(w, pureTLVSignedSecondRangeStart+1, &buf) + require.NoError(t, err) + + // Extra field in the second signed range. + rawBytes = append(rawBytes, w.Bytes()...) // type + rawBytes = append(rawBytes, []byte{ + 0x02, // length + 0x79, 0x79, // value + }...) + + // Now, create a new empty message and decode the raw bytes into it. + msg := &AnnounceSignatures2{} + r := bytes.NewReader(rawBytes) + err = msg.Decode(r, 0) + require.NoError(t, err) + + // At this point, we expect 2 extra signed fields. + require.Len(t, msg.ExtraSignedFields, 2) + + // Next, encode the message back into a new byte buffer. + var b bytes.Buffer + err = msg.Encode(&b, 0) + require.NoError(t, err) + + // The re-encoded bytes should be exactly the same as the original raw + // bytes. + require.Equal(t, rawBytes, b.Bytes()) +} diff --git a/lnwire/channel_announcement_2.go b/lnwire/channel_announcement_2.go index 95af69eda..94474d957 100644 --- a/lnwire/channel_announcement_2.go +++ b/lnwire/channel_announcement_2.go @@ -12,9 +12,6 @@ import ( // ChannelAnnouncement2 message is used to announce the existence of a taproot // channel between two peers in the network. type ChannelAnnouncement2 struct { - // Signature is a Schnorr signature over the TLV stream of the message. - Signature Sig - // ChainHash denotes the target chain that this channel was opened // within. This value should be the genesis hash of the target chain. ChainHash tlv.RecordT[tlv.TlvType0, chainhash.Hash] @@ -59,74 +56,14 @@ type ChannelAnnouncement2 struct { // the funding output is a pure 2-of-2 MuSig aggregate public key. MerkleRootHash tlv.OptionalRecordT[tlv.TlvType16, [32]byte] - // ExtraOpaqueData is the set of data that was appended to this - // message, some of which we may not actually know how to iterate or - // parse. By holding onto this data, we ensure that we're able to - // properly validate the set of signatures that cover these new fields, - // and ensure we're able to make upgrades to the network in a forwards - // compatible manner. - ExtraOpaqueData ExtraOpaqueData -} + // Signature is a Schnorr signature over serialised signed-range TLV + // stream of the message. + Signature tlv.RecordT[tlv.TlvType160, Sig] -// Decode deserializes a serialized AnnounceSignatures1 stored in the passed -// io.Reader observing the specified protocol version. -// -// This is part of the lnwire.Message interface. -func (c *ChannelAnnouncement2) Decode(r io.Reader, _ uint32) error { - err := ReadElement(r, &c.Signature) - if err != nil { - return err - } - c.Signature.ForceSchnorr() - - return c.DecodeTLVRecords(r) -} - -// DecodeTLVRecords decodes only the TLV section of the message. -func (c *ChannelAnnouncement2) DecodeTLVRecords(r io.Reader) error { - // First extract into extra opaque data. - var tlvRecords ExtraOpaqueData - if err := ReadElements(r, &tlvRecords); err != nil { - return err - } - - var ( - chainHash = tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() - btcKey1 = tlv.ZeroRecordT[tlv.TlvType12, [33]byte]() - btcKey2 = tlv.ZeroRecordT[tlv.TlvType14, [33]byte]() - merkleRootHash = tlv.ZeroRecordT[tlv.TlvType16, [32]byte]() - ) - typeMap, err := tlvRecords.ExtractRecords( - &chainHash, &c.Features, &c.ShortChannelID, &c.Capacity, - &c.NodeID1, &c.NodeID2, &btcKey1, &btcKey2, &merkleRootHash, - ) - if err != nil { - return err - } - - // By default, the chain-hash is the bitcoin mainnet genesis block hash. - c.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash - if _, ok := typeMap[c.ChainHash.TlvType()]; ok { - c.ChainHash.Val = chainHash.Val - } - - if _, ok := typeMap[c.BitcoinKey1.TlvType()]; ok { - c.BitcoinKey1 = tlv.SomeRecordT(btcKey1) - } - - if _, ok := typeMap[c.BitcoinKey2.TlvType()]; ok { - c.BitcoinKey2 = tlv.SomeRecordT(btcKey2) - } - - if _, ok := typeMap[c.MerkleRootHash.TlvType()]; ok { - c.MerkleRootHash = tlv.SomeRecordT(merkleRootHash) - } - - if len(tlvRecords) != 0 { - c.ExtraOpaqueData = tlvRecords - } - - return c.ExtraOpaqueData.ValidateTLV() + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraSignedFields } // Encode serializes the target AnnounceSignatures1 into the passed io.Writer @@ -134,21 +71,27 @@ func (c *ChannelAnnouncement2) DecodeTLVRecords(r io.Reader) error { // // This is part of the lnwire.Message interface. func (c *ChannelAnnouncement2) Encode(w *bytes.Buffer, _ uint32) error { - _, err := w.Write(c.Signature.RawBytes()) - if err != nil { - return err - } - _, err = c.DataToSign() - if err != nil { - return err - } - - return WriteBytes(w, c.ExtraOpaqueData) + return EncodePureTLVMessage(c, w) } -// DataToSign encodes the data to be signed into the ExtraOpaqueData member and -// returns it. -func (c *ChannelAnnouncement2) DataToSign() ([]byte, error) { +// AllRecords returns all the TLV records for the message. This will include all +// the records we know about along with any that we don't know about but that +// fall in the signed TLV range. +// +// NOTE: this is part of the PureTLVMessage interface. +func (c *ChannelAnnouncement2) AllRecords() []tlv.Record { + recordProducers := append( + c.allNonSignatureRecordProducers(), &c.Signature, + ) + + return ProduceRecordsSorted(recordProducers...) +} + +// allNonSignatureRecordProducers returns all the TLV record producers for the +// message except the signature record producer. +// +//nolint:ll +func (c *ChannelAnnouncement2) allNonSignatureRecordProducers() []tlv.RecordProducer { // The chain-hash record is only included if it is _not_ equal to the // bitcoin mainnet genisis block hash. var recordProducers []tlv.RecordProducer @@ -178,12 +121,126 @@ func (c *ChannelAnnouncement2) DataToSign() ([]byte, error) { }, ) - err := EncodeMessageExtraData(&c.ExtraOpaqueData, recordProducers...) + recordProducers = append(recordProducers, RecordsAsProducers( + tlv.MapToRecords(c.ExtraSignedFields), + )...) + + return recordProducers +} + +// Decode deserializes a serialized AnnounceSignatures1 stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *ChannelAnnouncement2) Decode(r io.Reader, _ uint32) error { + var ( + chainHash = tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() + btcKey1 = tlv.ZeroRecordT[tlv.TlvType12, [33]byte]() + btcKey2 = tlv.ZeroRecordT[tlv.TlvType14, [33]byte]() + merkleRootHash = tlv.ZeroRecordT[tlv.TlvType16, [32]byte]() + ) + stream, err := tlv.NewStream(ProduceRecordsSorted( + &chainHash, + &c.Features, + &c.ShortChannelID, + &c.Capacity, + &c.NodeID1, + &c.NodeID2, + &btcKey1, + &btcKey2, + &merkleRootHash, + &c.Signature, + )...) if err != nil { - return nil, err + return err + } + c.Signature.Val.ForceSchnorr() + + typeMap, err := stream.DecodeWithParsedTypesP2P(r) + if err != nil { + return err } - return c.ExtraOpaqueData, nil + // By default, the chain-hash is the bitcoin mainnet genesis block hash. + c.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash + if _, ok := typeMap[c.ChainHash.TlvType()]; ok { + c.ChainHash.Val = chainHash.Val + } + + if _, ok := typeMap[c.BitcoinKey1.TlvType()]; ok { + c.BitcoinKey1 = tlv.SomeRecordT(btcKey1) + } + + if _, ok := typeMap[c.BitcoinKey2.TlvType()]; ok { + c.BitcoinKey2 = tlv.SomeRecordT(btcKey2) + } + + if _, ok := typeMap[c.MerkleRootHash.TlvType()]; ok { + c.MerkleRootHash = tlv.SomeRecordT(merkleRootHash) + } + + c.ExtraSignedFields = ExtraSignedFieldsFromTypeMap(typeMap) + + return nil +} + +// DecodeNonSigTLVRecords decodes only the TLV section of the message. +func (c *ChannelAnnouncement2) DecodeNonSigTLVRecords(r io.Reader) error { + var ( + chainHash = tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() + btcKey1 = tlv.ZeroRecordT[tlv.TlvType12, [33]byte]() + btcKey2 = tlv.ZeroRecordT[tlv.TlvType14, [33]byte]() + merkleRootHash = tlv.ZeroRecordT[tlv.TlvType16, [32]byte]() + ) + stream, err := tlv.NewStream(ProduceRecordsSorted( + &chainHash, + &c.Features, + &c.ShortChannelID, + &c.Capacity, + &c.NodeID1, + &c.NodeID2, + &btcKey1, + &btcKey2, + &merkleRootHash, + )...) + if err != nil { + return err + } + + typeMap, err := stream.DecodeWithParsedTypesP2P(r) + if err != nil { + return err + } + + // By default, the chain-hash is the bitcoin mainnet genesis block hash. + c.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash + if _, ok := typeMap[c.ChainHash.TlvType()]; ok { + c.ChainHash.Val = chainHash.Val + } + + if _, ok := typeMap[c.BitcoinKey1.TlvType()]; ok { + c.BitcoinKey1 = tlv.SomeRecordT(btcKey1) + } + + if _, ok := typeMap[c.BitcoinKey2.TlvType()]; ok { + c.BitcoinKey2 = tlv.SomeRecordT(btcKey2) + } + + if _, ok := typeMap[c.MerkleRootHash.TlvType()]; ok { + c.MerkleRootHash = tlv.SomeRecordT(merkleRootHash) + } + + c.ExtraSignedFields = ExtraSignedFieldsFromTypeMap(typeMap) + + return nil +} + +// EncodeAllNonSigFields encodes the entire message to the given writer but +// excludes the signature field. +func (c *ChannelAnnouncement2) EncodeAllNonSigFields(w io.Writer) error { + return EncodeRecordsTo( + w, ProduceRecordsSorted(c.allNonSignatureRecordProducers()...), + ) } // MsgType returns the integer uniquely identifying this message type on the @@ -209,6 +266,10 @@ var _ Message = (*ChannelAnnouncement2)(nil) // lnwire.SizeableMessage interface. var _ SizeableMessage = (*ChannelAnnouncement2)(nil) +// A compile time check to ensure ChannelAnnouncement2 implements the +// lnwire.PureTLVMessage interface. +var _ PureTLVMessage = (*ChannelAnnouncement2)(nil) + // Node1KeyBytes returns the bytes representing the public key of node 1 in the // channel. // diff --git a/lnwire/channel_announcement_2_test.go b/lnwire/channel_announcement_2_test.go new file mode 100644 index 000000000..40d255b55 --- /dev/null +++ b/lnwire/channel_announcement_2_test.go @@ -0,0 +1,103 @@ +package lnwire + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// TestChanAnn2EncodeDecode tests the encoding and decoding of the +// ChannelAnnouncement2 message using hardcoded byte slices. +func TestChanAnn2EncodeDecode(t *testing.T) { + t.Parallel() + + // We'll create a raw byte stream that represents a valid + // ChannelAnnouncement2 message with various known and unknown fields in + // the signed TLV ranges along with the signature in the unsigned range. + rawBytes := []byte{ + // ChainHash record (optional, not mainnet). + 0x00, // type. + 0x20, // length. + 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, + 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, + 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, + + // Features record. + 0x02, // type. + 0x02, // length. + 0x1, 0x2, // value. + + // ShortChannelID record. + 0x04, // type. + 0x08, // length. + 0x0, 0x0, 0x1, 0x0, 0x0, 0x2, 0x0, 0x3, // value. + + // Unknown TLV record. + 0x05, // type. + 0x02, // length. + 0xab, 0xcd, // value. + + // Capacity record. + 0x06, // type. + 0x08, // length. + 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x86, 0xa0, // value: 100000. + + // NodeID1 record. + 0x08, // type. + 0x21, // length. + 0x2, 0x28, 0xf2, 0xaf, 0xa, 0xbe, 0x32, 0x24, 0x3, 0x48, 0xf, + 0xb3, 0xee, 0x17, 0x2f, 0x7f, 0x16, 0x1, 0xe6, 0x7d, 0x1d, 0xa6, + 0xca, 0xd4, 0xb, 0x54, 0xc4, 0x46, 0x8d, 0x48, 0x23, 0x6c, 0x39, + + // NodeID2 record. + 0x0a, // type. + 0x21, // length. + 0x3, 0x28, 0xf2, 0xaf, 0xa, 0xbe, 0x32, 0x24, 0x3, 0x48, 0xf, + 0xb3, 0xee, 0x17, 0x2f, 0x7f, 0x16, 0x1, 0xe6, 0x7d, 0x1d, 0xa6, + 0xca, 0xd4, 0xb, 0x54, 0xc4, 0x46, 0x8d, 0x48, 0x23, 0x6c, 0x39, + + // Unknown TLV record. + 0x6f, // type. + 0x2, // length. + 0x79, 0x79, // value. + + // Signature. + 0xa0, // type. + 0x40, // length. + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, + 0xc, 0xd, 0xe, 0xf, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, + 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, + 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, + 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, + 0x3f, // value. + } + secondSignedRangeType := new(bytes.Buffer) + var buf [8]byte + err := tlv.WriteVarInt( + secondSignedRangeType, pureTLVSignedSecondRangeStart+1, &buf, + ) + require.NoError(t, err) + rawBytes = append(rawBytes, secondSignedRangeType.Bytes()...) // type. + rawBytes = append(rawBytes, []byte{ + 0x02, // length. + 0x79, 0x79, // value. + }...) + + // Now, create a new empty message and decode the raw bytes into it. + msg := &ChannelAnnouncement2{} + r := bytes.NewReader(rawBytes) + err = msg.Decode(r, 0) + require.NoError(t, err) + + // Next, encode the message back into a new byte buffer. + var b bytes.Buffer + err = msg.Encode(&b, 0) + require.NoError(t, err) + + // The re-encoded bytes should be exactly the same as the original raw + // bytes. + require.Equal(t, rawBytes, b.Bytes()) +} diff --git a/lnwire/channel_id.go b/lnwire/channel_id.go index 1615eb747..5c9eca34f 100644 --- a/lnwire/channel_id.go +++ b/lnwire/channel_id.go @@ -3,10 +3,12 @@ package lnwire import ( "encoding/binary" "encoding/hex" + "io" "math" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/tlv" ) const ( @@ -36,6 +38,40 @@ func (c ChannelID) String() string { return hex.EncodeToString(c[:]) } +// Record returns a TLV record that can be used to encode/decode a ChannelID +// to/from a TLV stream. +func (c *ChannelID) Record() tlv.Record { + return tlv.MakeStaticRecord(0, c, 32, encodeChannelID, decodeChannelID) +} + +func encodeChannelID(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*ChannelID); ok { + bigSize := [32]byte(*v) + + return tlv.EBytes32(w, &bigSize, buf) + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.ChannelID") +} + +func decodeChannelID(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if v, ok := val.(*ChannelID); ok { + var id [32]byte + err := tlv.DBytes32(r, &id, buf, l) + if err != nil { + return err + } + + *v = id + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "lnwire.ChannelID", l, l) +} + // NewChanIDFromOutPoint converts a target OutPoint into a ChannelID that is // usable within the network. In order to convert the OutPoint into a ChannelID, // we XOR the lower 2-bytes of the txid within the OutPoint with the big-endian diff --git a/lnwire/channel_update_2.go b/lnwire/channel_update_2.go index 343af6b1e..b832bc58d 100644 --- a/lnwire/channel_update_2.go +++ b/lnwire/channel_update_2.go @@ -22,10 +22,6 @@ const ( // HTLCs and other parameters. This message is also used to redeclare initially // set channel parameters. type ChannelUpdate2 struct { - // Signature is used to validate the announced data and prove the - // ownership of node id. - Signature Sig - // ChainHash denotes the target chain that this channel was opened // within. This value should be the genesis hash of the target chain. // Along with the short channel ID, this uniquely identifies the @@ -74,10 +70,22 @@ type ChannelUpdate2 struct { // millionth of a satoshi. FeeProportionalMillionths tlv.RecordT[tlv.TlvType18, uint32] - // ExtraOpaqueData is the set of data that was appended to this message - // to fill out the full maximum transport message size. These fields can - // be used to specify optional data such as custom TLV fields. - ExtraOpaqueData ExtraOpaqueData + // Signature is used to validate the announced data and prove the + // ownership of node id. + Signature tlv.RecordT[tlv.TlvType160, Sig] + + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraSignedFields +} + +// Encode serializes the target ChannelUpdate2 into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *ChannelUpdate2) Encode(w *bytes.Buffer, _ uint32) error { + return EncodePureTLVMessage(c, w) } // Decode deserializes a serialized ChannelUpdate2 stored in the passed @@ -85,17 +93,6 @@ type ChannelUpdate2 struct { // // This is part of the lnwire.Message interface. func (c *ChannelUpdate2) Decode(r io.Reader, _ uint32) error { - err := ReadElement(r, &c.Signature) - if err != nil { - return err - } - c.Signature.ForceSchnorr() - - return c.DecodeTLVRecords(r) -} - -// DecodeTLVRecords decodes only the TLV section of the message. -func (c *ChannelUpdate2) DecodeTLVRecords(r io.Reader) error { // First extract into extra opaque data. var tlvRecords ExtraOpaqueData if err := ReadElements(r, &tlvRecords); err != nil { @@ -111,10 +108,12 @@ func (c *ChannelUpdate2) DecodeTLVRecords(r io.Reader) error { &secondPeer, &c.CLTVExpiryDelta, &c.HTLCMinimumMsat, &c.HTLCMaximumMsat, &c.FeeBaseMsat, &c.FeeProportionalMillionths, + &c.Signature, ) if err != nil { return err } + c.Signature.Val.ForceSchnorr() // By default, the chain-hash is the bitcoin mainnet genesis block hash. c.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash @@ -150,38 +149,21 @@ func (c *ChannelUpdate2) DecodeTLVRecords(r io.Reader) error { c.FeeProportionalMillionths.Val = defaultFeeProportionalMillionths //nolint:ll } - if len(tlvRecords) != 0 { - c.ExtraOpaqueData = tlvRecords - } + c.ExtraSignedFields = ExtraSignedFieldsFromTypeMap(typeMap) - return c.ExtraOpaqueData.ValidateTLV() + return nil } -// Encode serializes the target ChannelUpdate2 into the passed io.Writer -// observing the protocol version specified. +// AllRecords returns all the TLV records for the message. This will include all +// the records we know about along with any that we don't know about but that +// fall in the signed TLV range. // -// This is part of the lnwire.Message interface. -func (c *ChannelUpdate2) Encode(w *bytes.Buffer, _ uint32) error { - _, err := w.Write(c.Signature.RawBytes()) - if err != nil { - return err - } +// NOTE: this is part of the PureTLVMessage interface. +func (c *ChannelUpdate2) AllRecords() []tlv.Record { + var recordProducers []tlv.RecordProducer - _, err = c.DataToSign() - if err != nil { - return err - } - - return WriteBytes(w, c.ExtraOpaqueData) -} - -// DataToSign is used to retrieve part of the announcement message which should -// be signed. For the ChannelUpdate2 message, this includes the serialised TLV -// records. -func (c *ChannelUpdate2) DataToSign() ([]byte, error) { // The chain-hash record is only included if it is _not_ equal to the // bitcoin mainnet genisis block hash. - var recordProducers []tlv.RecordProducer if !c.ChainHash.Val.IsEqual(chaincfg.MainNetParams.GenesisHash) { hash := tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() hash.Val = c.ChainHash.Val @@ -190,7 +172,7 @@ func (c *ChannelUpdate2) DataToSign() ([]byte, error) { } recordProducers = append(recordProducers, - &c.ShortChannelID, &c.BlockHeight, + &c.ShortChannelID, &c.BlockHeight, &c.Signature, ) // Only include the disable flags if any bit is set. @@ -225,12 +207,11 @@ func (c *ChannelUpdate2) DataToSign() ([]byte, error) { ) } - err := EncodeMessageExtraData(&c.ExtraOpaqueData, recordProducers...) - if err != nil { - return nil, err - } + recordProducers = append(recordProducers, RecordsAsProducers( + tlv.MapToRecords(c.ExtraSignedFields), + )...) - return c.ExtraOpaqueData, nil + return ProduceRecordsSorted(recordProducers...) } // MsgType returns the integer uniquely identifying this message type on the @@ -248,14 +229,14 @@ func (c *ChannelUpdate2) SerializedSize() (uint32, error) { return MessageSerializedSize(c) } -func (c *ChannelUpdate2) ExtraData() ExtraOpaqueData { - return c.ExtraOpaqueData -} - // A compile time check to ensure ChannelUpdate2 implements the // lnwire.Message interface. var _ Message = (*ChannelUpdate2)(nil) +// A compile time check to ensure ChannelUpdate2 implements the +// lnwire.PureTLVMessage interface. +var _ PureTLVMessage = (*ChannelUpdate2)(nil) + // SCID returns the ShortChannelID of the channel that the update applies to. // // NOTE: this is part of the ChannelUpdate interface. diff --git a/lnwire/channel_update_2_test.go b/lnwire/channel_update_2_test.go new file mode 100644 index 000000000..4e771d89f --- /dev/null +++ b/lnwire/channel_update_2_test.go @@ -0,0 +1,119 @@ +package lnwire + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// TestChanUpdate2EncodeDecode tests the encoding and decoding of the +// ChannelUpdate2 message using hardcoded byte slices. +func TestChanUpdate2EncodeDecode(t *testing.T) { + t.Parallel() + + // We'll create a raw byte stream that represents a valid ChannelUpdate2 + // message. This includes the signature and a TLV stream with both known + // and unknown records. + rawBytes := []byte{ + // ChainHash record (optional, not mainnet). + 0x0, // type. + 0x20, // length. + 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, + 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, + 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, + + // ShortChannelID record. + 0x2, // type. + 0x8, // length. + 0x0, 0x0, 0x1, 0x0, 0x0, 0x2, 0x0, 0x3, // value. + + // BlockHeight record. + 0x4, // type. + 0x4, // length. + 0x0, 0x0, 0x1, 0x0, // value. + + // DisabledFlags record. + 0x6, // type. + 0x1, // length. + 0x1, // value. + + // SecondPeer record. + 0x8, // type. + 0x0, // length. + + // Unknown odd-type TLV record. + 0x9, // type. + 0x2, // length. + 0xab, 0xcd, // value. + + // CLTVExpiryDelta record. + 0xa, // type. + 0x2, // length. + 0x0, 0x10, // value. + + // HTLCMinimumMsat record. + 0xc, // type. + 0x5, // length. + 0xfe, 0x0, 0xf, 0x42, 0x40, // value (BigSize: 1_000_000). + + // HTLCMaximumMsat record. + 0xe, // type. + 0x5, // length. + 0xfe, 0x0, 0xf, 0x42, 0x40, // value (BigSize: 1_000_000). + + // FeeBaseMsat record. + 0x10, // type. + 0x4, // length. + 0x0, 0x0, 0x1, 0x0, // value. + + // FeeProportionalMillionths record. + 0x12, // type. + 0x4, // length. + 0x0, 0x0, 0x1, 0x0, // value. + + // Extra Opaque Data - Unknown Record. + 0x14, // type. + 0x2, // length. + 0x79, 0x79, // value. + + // Signature. + 0xa0, // type. + 0x40, // length. + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, + 0xc, 0xd, 0xe, 0xf, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, + 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, + 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, + 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, + 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, + 0x3f, // value + } + + secondSignedRangeType := new(bytes.Buffer) + var buf [8]byte + err := tlv.WriteVarInt( + secondSignedRangeType, pureTLVSignedSecondRangeStart+1, &buf, + ) + require.NoError(t, err) + rawBytes = append(rawBytes, secondSignedRangeType.Bytes()...) // type. + rawBytes = append(rawBytes, []byte{ + 0x02, // length. + 0x79, 0x79, // value. + }...) + + // Now, create a new empty message and decode the raw bytes into it. + msg := &ChannelUpdate2{} + r := bytes.NewReader(rawBytes) + err = msg.Decode(r, 0) + require.NoError(t, err) + + // Next, encode the message back into a new byte buffer. + var b bytes.Buffer + err = msg.Encode(&b, 0) + require.NoError(t, err) + + // The re-encoded bytes should be exactly the same as the original raw + // bytes. + require.Equal(t, rawBytes, b.Bytes()) +} diff --git a/lnwire/pure_tlv.go b/lnwire/pure_tlv.go new file mode 100644 index 000000000..8e6f7bd9f --- /dev/null +++ b/lnwire/pure_tlv.go @@ -0,0 +1,105 @@ +package lnwire + +import ( + "bytes" + + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // pureTLVUnsignedRangeOneStart defines the start of the first unsigned + // TLV range used for pure TLV messages. The range is inclusive of this + // number. + pureTLVUnsignedRangeOneStart = 160 + + // pureTLVSignedSecondRangeStart defines the start of the second signed + // TLV range used for pure TLV messages. The range is inclusive of this + // number. Note that the first range is the inclusive range of 0-159. + pureTLVSignedSecondRangeStart = 1000000000 + + // pureTLVUnsignedRangeTwoStart defines the start of the second unsigned + // TLV range used for pure TLV message. + pureTLVUnsignedRangeTwoStart = 3000000000 +) + +// PureTLVMessage describes an LN message that is a pure TLV stream. If the +// message includes a signature, it will sign all the TLV records in the +// inclusive ranges: 0 to 159 and 1000000000 to 2999999999. +type PureTLVMessage interface { + // AllRecords returns all the TLV records for the message. This will + // include all the records we know about along with any that we don't + // know about but that fall in the signed TLV range. + AllRecords() []tlv.Record +} + +// EncodePureTLVMessage encodes the given PureTLVMessage to the given buffer. +func EncodePureTLVMessage(msg PureTLVMessage, buf *bytes.Buffer) error { + return EncodeRecordsTo(buf, msg.AllRecords()) +} + +// SerialiseFieldsToSign serialises all the records from the given +// PureTLVMessage that fall within the signed TLV range. +func SerialiseFieldsToSign(msg PureTLVMessage) ([]byte, error) { + // Filter out all the fields not in the signed ranges. + var signedRecords []tlv.Record + for _, record := range msg.AllRecords() { + if InUnsignedRange(record.Type()) { + continue + } + + signedRecords = append(signedRecords, record) + } + + var buf bytes.Buffer + if err := EncodeRecordsTo(&buf, signedRecords); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// InUnsignedRange returns true if the given TLV type falls outside the TLV +// ranges that the signature of a pure TLV message will cover. +func InUnsignedRange(t tlv.Type) bool { + return (t >= pureTLVUnsignedRangeOneStart && + t < pureTLVSignedSecondRangeStart) || + t >= pureTLVUnsignedRangeTwoStart +} + +// ExtraSignedFields is a type that stores a map from TLV types in the signed +// range (for PureMessages) to their corresponding serialised values. This type +// can be used to keep around data that we don't yet understand but that we need +// for re-composing the wire message since the signature covers these fields. +type ExtraSignedFields map[uint64][]byte + +// ExtraSignedFieldsFromTypeMap is a helper that can be used alongside calls to +// the tlv.Stream DecodeWithParsedTypesP2P or DecodeWithParsedTypes methods to +// extract the tlv type and value pairs in the defined PureTLVMessage signed +// range which we have not handled with any of our defined Records. These +// methods will return a tlv.TypeMap containing the records that were extracted +// from an io.Reader. If the record was know and handled by a defined record, +// then the value accompanying the record's type in the map will be nil. +// Otherwise, if the record was unhandled, it will be non-nil. +func ExtraSignedFieldsFromTypeMap(m tlv.TypeMap) ExtraSignedFields { + extraFields := make(ExtraSignedFields) + for t, v := range m { + // If the value in the type map is nil, then it indicates that + // we know this type, and it was handled by one of the records + // we passed to the decode function vai the TLV stream. + if v == nil { + continue + } + + // No need to keep this field if it is unknown to us and is not + // in the sign range. + if InUnsignedRange(t) { + continue + } + + // Otherwise, this is an un-handled type, so we keep track of + // it for signature validation and re-encoding later on. + extraFields[uint64(t)] = v + } + + return extraFields +} diff --git a/lnwire/pure_tlv_test.go b/lnwire/pure_tlv_test.go new file mode 100644 index 000000000..a81a89ecb --- /dev/null +++ b/lnwire/pure_tlv_test.go @@ -0,0 +1,389 @@ +package lnwire + +import ( + "bytes" + "io" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/require" +) + +// TestPureTLVMessages tests the forwards compatibility of two versions of the +// same Lightning Network message that uses the Pure TLV format. This in essence +// tests that and older client is able to verify the signature over relevant +// data in a newer client's message. +func TestPureTLVMessage(t *testing.T) { + t.Parallel() + + var ( + _, pkA = btcec.PrivKeyFromBytes([]byte{1}) + _, pkB = btcec.PrivKeyFromBytes([]byte{2}) + capacity = MilliSatoshi(100) + ) + + // Test encode and decode of MsgV1 as is. + t.Run("Encode and Decode of MsgV1", func(t *testing.T) { + t.Parallel() + + msgOld := newMsgV1(pkA, &capacity) + + buf := bytes.NewBuffer(nil) + require.NoError(t, msgOld.Encode(buf, 0)) + + var msgOld2 MsgV1 + require.NoError(t, msgOld2.Decode(buf, 0)) + + require.Equal(t, msgOld, &msgOld2) + }) + + // Test encode and decode of MsgV2 as is. + t.Run("Encode and Decode of MsgV2", func(t *testing.T) { + t.Parallel() + + msgNew := newMsgV2( + pkA, &capacity, pkB, []byte{1, 2, 3, 4}, 90, 100, true, + ) + + buf := bytes.NewBuffer(nil) + require.NoError(t, msgNew.Encode(buf, 0)) + + var msgNew2 MsgV2 + require.NoError(t, msgNew2.Decode(buf, 0)) + + require.Equal(t, msgNew, &msgNew2) + }) + + // Create a MsgV2 and decode it into a MsgV1. Both the new client + // (MsgV2) and old client (MsgV1) should be able to generate the same + // digest that will be used to create and validate the signture. + t.Run("Encode MsgV2 and decode via MsgV1", func(t *testing.T) { + t.Parallel() + + var ( + buf = bytes.NewBuffer(nil) + msgV2 = newMsgV2( + pkA, &capacity, pkB, []byte{1, 2, 3, 4}, 100, + 90, true, + ) + ) + require.NoError(t, msgV2.Encode(buf, 0)) + + // Get the serialised bytes that would be signed for msgV2. + signData1, err := SerialiseFieldsToSign(msgV2) + require.NoError(t, err) + + // Decoding via the old message should store some of the extra + // fields. + var msgV1 MsgV1 + require.NoError(t, msgV1.Decode(buf, 0)) + require.NotEmpty(t, msgV1.ExtraSignedFields) + + // Show that the extra fields map contains unknown fields in the + // signed range but not unknown fields in the unsigned range. + _, ok := msgV1.ExtraSignedFields[uint64(msgV2.Num.TlvType())] //nolint:ll + require.True(t, ok) + _, ok = msgV1.ExtraSignedFields[uint64(msgV2.Other.TlvType())] //nolint:ll + require.False(t, ok) + + // The serialised bytes to verify the signature against should + // be the same though. + signData2, err := SerialiseFieldsToSign(&msgV1) + require.NoError(t, err) + + require.Equal(t, signData1, signData2) + + // Re-encoding via the old message should keep the extra fields. + buf = bytes.NewBuffer(nil) + require.NoError(t, msgV1.Encode(buf, 0)) + + var msgV1ReEncoded MsgV1 + require.NoError(t, msgV1ReEncoded.Decode(buf, 0)) + + require.Equal(t, &msgV1, &msgV1ReEncoded) + }) +} + +// MsgV1 represents a more minimal, first version of a Lightning Network +// message. +type MsgV1 struct { + // Two known fields in the signed range. + NodeKey tlv.RecordT[tlv.TlvType0, *btcec.PublicKey] + Capacity tlv.OptionalRecordT[tlv.TlvType1, MilliSatoshi] + + // Signature in the unsigned range. + Signature tlv.RecordT[tlv.TlvType160, Sig] + + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraSignedFields +} + +var _ Message = (*MsgV1)(nil) +var _ PureTLVMessage = (*MsgV1)(nil) + +// newMsgV1 is a constructor for MsgV1. +func newMsgV1(nodeKey *btcec.PublicKey, capacity *MilliSatoshi) *MsgV1 { + newMsg := &MsgV1{ + NodeKey: tlv.NewPrimitiveRecord[tlv.TlvType0]( + nodeKey, + ), + Signature: tlv.NewRecordT[tlv.TlvType160]( + testSchnorrSig, + ), + ExtraSignedFields: make(ExtraSignedFields), + } + + if capacity != nil { + newMsg.Capacity = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType1](*capacity), + ) + } + + return newMsg +} + +// Decode deserializes a serialized MsgV1 in the passed io.Reader. +// +// This is part of the lnwire.Message interface. +func (g *MsgV1) Decode(r io.Reader, _ uint32) error { + var capacity = tlv.ZeroRecordT[tlv.TlvType1, MilliSatoshi]() + stream, err := tlv.NewStream( + ProduceRecordsSorted( + &g.NodeKey, + &capacity, + &g.Signature, + )..., + ) + if err != nil { + return err + } + g.Signature.Val.ForceSchnorr() + + typeMap, err := stream.DecodeWithParsedTypesP2P(r) + if err != nil { + return err + } + + if _, ok := typeMap[g.Capacity.TlvType()]; ok { + g.Capacity = tlv.SomeRecordT(capacity) + } + + g.ExtraSignedFields = ExtraSignedFieldsFromTypeMap(typeMap) + + return nil +} + +// Encode serializes the target MsgV1 into the passed buffer. +// +// This is part of the lnwire.Message interface. +func (g *MsgV1) Encode(buf *bytes.Buffer, _ uint32) error { + return EncodePureTLVMessage(g, buf) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (g *MsgV1) MsgType() MessageType { + return 7777 +} + +// AllRecords returns all the TLV records for the message. This will +// include all the records we know about along with any that we don't +// know about but that fall in the signed TLV range. +// +// This is part of the PureTLVMessage interface. +func (g *MsgV1) AllRecords() []tlv.Record { + recordProducers := []tlv.RecordProducer{ + &g.NodeKey, + &g.Signature, + } + recordProducers = append( + recordProducers, + RecordsAsProducers( + tlv.MapToRecords(g.ExtraSignedFields), + )..., + ) + + g.Capacity.WhenSome( + func(capacity tlv.RecordT[tlv.TlvType1, MilliSatoshi]) { + recordProducers = append(recordProducers, &capacity) + }, + ) + + return ProduceRecordsSorted(recordProducers...) +} + +// MsgV2 represents a newer version of MsgV1 which contains more fields both in +// the unsigned and signed TLV ranges. +type MsgV2 struct { + NodeKey tlv.RecordT[tlv.TlvType0, *btcec.PublicKey] + Capacity tlv.OptionalRecordT[tlv.TlvType1, MilliSatoshi] + + // An additional fields (optional) in the signed range. + BitcoinKey tlv.OptionalRecordT[tlv.TlvType3, *btcec.PublicKey] + + // A zero length TLV in the signed range. + SecondPeer tlv.OptionalRecordT[tlv.TlvType5, TrueBoolean] + + // Signature in the unsigned range. + Signature tlv.RecordT[tlv.TlvType160, Sig] + + // Another field in the unsigned range. An older node can throw this + // away. + SPVProof tlv.RecordT[tlv.TlvType161, []byte] + + // A new field in the second signed range. An older node should keep + // this since it is part of the serialised message that is signed. + Num tlv.RecordT[tlv.TlvType1000000000, uint8] + + // Another field in the second unsigned-range. Older nodes may throw + // this away and it won't affect the digest used for signature creation + // and validation. + Other tlv.RecordT[tlv.TlvType3000000000, uint8] + + // Any extra fields in the signed range that we do not yet know about, + // but we need to keep them for signature validation and to produce a + // valid message. + ExtraSignedFields +} + +// newMsgV2 is a constructor for MsgV2. +func newMsgV2(nodeKey *btcec.PublicKey, capacity *MilliSatoshi, + btcKey *btcec.PublicKey, spvProof []byte, num, other uint8, + secondPeer bool) *MsgV2 { + + newMsg := &MsgV2{ + NodeKey: tlv.NewPrimitiveRecord[tlv.TlvType0](nodeKey), + SPVProof: tlv.NewPrimitiveRecord[tlv.TlvType161](spvProof), + Num: tlv.NewPrimitiveRecord[tlv.TlvType1000000000](num), + Other: tlv.NewPrimitiveRecord[tlv.TlvType3000000000](num), + Signature: tlv.NewRecordT[tlv.TlvType160]( + testSchnorrSig, + ), + ExtraSignedFields: make(ExtraSignedFields), + } + + if secondPeer { + newMsg.SecondPeer = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType5](TrueBoolean{}), + ) + } + + if capacity != nil { + newMsg.Capacity = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType1](*capacity), + ) + } + + if btcKey != nil { + newMsg.BitcoinKey = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType3](btcKey), + ) + } + + return newMsg +} + +// Decode deserializes a serialized MsgV2 in the passed io.Reader. +// +// This is part of the lnwire.Message interface. +func (g *MsgV2) Decode(r io.Reader, _ uint32) error { + var ( + capacity = tlv.ZeroRecordT[tlv.TlvType1, MilliSatoshi]() + btcKey = tlv.ZeroRecordT[tlv.TlvType3, *btcec.PublicKey]() + secondPeer = tlv.ZeroRecordT[tlv.TlvType5, TrueBoolean]() + ) + + stream, err := tlv.NewStream( + ProduceRecordsSorted( + &g.NodeKey, + &capacity, + &btcKey, + &secondPeer, + &g.Signature, + &g.SPVProof, + &g.Num, + &g.Other, + )..., + ) + if err != nil { + return err + } + g.Signature.Val.ForceSchnorr() + + typeMap, err := stream.DecodeWithParsedTypesP2P(r) + if err != nil { + return err + } + + if _, ok := typeMap[g.Capacity.TlvType()]; ok { + g.Capacity = tlv.SomeRecordT(capacity) + } + + if _, ok := typeMap[g.SecondPeer.TlvType()]; ok { + g.SecondPeer = tlv.SomeRecordT(secondPeer) + } + + if _, ok := typeMap[g.BitcoinKey.TlvType()]; ok { + g.BitcoinKey = tlv.SomeRecordT(btcKey) + } + + g.ExtraSignedFields = ExtraSignedFieldsFromTypeMap(typeMap) + + return nil +} + +// Encode serializes the target MsgV2 into the passed buffer. +// +// This is part of the lnwire.Message interface. +func (g *MsgV2) Encode(buf *bytes.Buffer, _ uint32) error { + return EncodePureTLVMessage(g, buf) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (g *MsgV2) MsgType() MessageType { + return 7779 +} + +// AllRecords returns all the TLV records for the message. This will +// include all the records we know about along with any that we don't +// know about but that fall in the signed TLV range. +// +// This is part of the PureTLVMessage interface. +func (g *MsgV2) AllRecords() []tlv.Record { + recordProducers := []tlv.RecordProducer{ + &g.NodeKey, + &g.Signature, + &g.SPVProof, + &g.Num, + &g.Other, + } + recordProducers = append(recordProducers, RecordsAsProducers( + tlv.MapToRecords(g.ExtraSignedFields), + )...) + + g.Capacity.WhenSome( + func(c tlv.RecordT[tlv.TlvType1, MilliSatoshi]) { + recordProducers = append(recordProducers, &c) + }, + ) + g.BitcoinKey.WhenSome( + func(key tlv.RecordT[tlv.TlvType3, *btcec.PublicKey]) { + recordProducers = append(recordProducers, &key) + }, + ) + g.SecondPeer.WhenSome( + func(second tlv.RecordT[tlv.TlvType5, TrueBoolean]) { + recordProducers = append(recordProducers, &second) + }, + ) + + return ProduceRecordsSorted(recordProducers...) +} diff --git a/lnwire/test_message.go b/lnwire/test_message.go index 7eb712b74..8f946f190 100644 --- a/lnwire/test_message.go +++ b/lnwire/test_message.go @@ -129,12 +129,29 @@ var _ TestMessage = (*AnnounceSignatures2)(nil) // // This is part of the TestMessage interface. func (a *AnnounceSignatures2) RandTestMessage(t *rapid.T) Message { - return &AnnounceSignatures2{ - ChannelID: RandChannelID(t), - ShortChannelID: RandShortChannelID(t), - PartialSignature: *RandPartialSig(t), - ExtraOpaqueData: RandExtraOpaqueData(t, nil), + var ( + chanID = RandChannelID(t) + scid = RandShortChannelID(t) + pSig = RandPartialSig(t) + ) + + msg := &AnnounceSignatures2{ + ChannelID: tlv.NewRecordT[tlv.TlvType0, ChannelID]( + chanID, + ), + ShortChannelID: tlv.NewRecordT[tlv.TlvType2](scid), + PartialSignature: tlv.NewRecordT[tlv.TlvType4, PartialSig]( + *pSig, + ), + ExtraSignedFields: make(map[uint64][]byte), } + + randRecs, _ := RandSignedRangeRecords(t) + if len(randRecs) > 0 { + msg.ExtraSignedFields = ExtraSignedFields(randRecs) + } + + return msg } // A compile time check to ensure ChannelAnnouncement1 implements the @@ -213,7 +230,6 @@ func (c *ChannelAnnouncement2) RandTestMessage(t *rapid.T) Message { copy(chainHashObj[:], chainHash[:]) msg := &ChannelAnnouncement2{ - Signature: RandSignature(t), ChainHash: tlv.NewPrimitiveRecord[tlv.TlvType0, chainhash.Hash]( chainHashObj, ), @@ -232,10 +248,16 @@ func (c *ChannelAnnouncement2) RandTestMessage(t *rapid.T) Message { NodeID2: tlv.NewPrimitiveRecord[tlv.TlvType10, [33]byte]( nodeID2, ), - ExtraOpaqueData: RandExtraOpaqueData(t, nil), + ExtraSignedFields: make(map[uint64][]byte), } - msg.Signature.ForceSchnorr() + msg.Signature.Val = RandSignature(t) + msg.Signature.Val.ForceSchnorr() + + randRecs, _ := RandSignedRangeRecords(t) + if len(randRecs) > 0 { + msg.ExtraSignedFields = ExtraSignedFields(randRecs) + } // Randomly include optional fields if rapid.Bool().Draw(t, "includeBitcoinKey1") { @@ -411,7 +433,7 @@ func (a *ChannelUpdate1) RandTestMessage(t *rapid.T) Message { // include an inbound fee, then we will also set the record in the // extra opaque data. var ( - customRecords, _ = RandCustomRecords(t, nil, false) + customRecords, _ = RandCustomRecords(t, nil) inboundFee tlv.OptionalRecordT[tlv.TlvType55555, Fee] ) includeInboundFee := rapid.Bool().Draw(t, "includeInboundFee") @@ -513,7 +535,6 @@ func (c *ChannelUpdate2) RandTestMessage(t *rapid.T) Message { //nolint:ll msg := &ChannelUpdate2{ - Signature: RandSignature(t), ChainHash: tlv.NewPrimitiveRecord[tlv.TlvType0, chainhash.Hash]( chainHashObj, ), @@ -541,10 +562,11 @@ func (c *ChannelUpdate2) RandTestMessage(t *rapid.T) Message { FeeProportionalMillionths: tlv.NewPrimitiveRecord[tlv.TlvType18, uint32]( feeProportionalMillionths, ), - ExtraOpaqueData: RandExtraOpaqueData(t, nil), + ExtraSignedFields: make(map[uint64][]byte), } - msg.Signature.ForceSchnorr() + msg.Signature.Val = RandSignature(t) + msg.Signature.Val.ForceSchnorr() if rapid.Bool().Draw(t, "isSecondPeer") { msg.SecondPeer = tlv.SomeRecordT( @@ -728,7 +750,7 @@ var _ TestMessage = (*CommitSig)(nil) // // This is part of the TestMessage interface. func (c *CommitSig) RandTestMessage(t *rapid.T) Message { - cr, _ := RandCustomRecords(t, nil, true) + cr, _ := RandCustomRecords(t, nil) sig := &CommitSig{ ChanID: RandChannelID(t), CommitSig: RandSignature(t), @@ -1606,7 +1628,7 @@ func (s *Shutdown) RandTestMessage(t *rapid.T) Message { shutdownNonce = SomeShutdownNonce(RandMusig2Nonce(t)) } - cr, _ := RandCustomRecords(t, nil, true) + cr, _ := RandCustomRecords(t, nil) return &Shutdown{ ChannelID: RandChannelID(t), @@ -1663,7 +1685,7 @@ func (c *UpdateAddHTLC) RandTestMessage(t *rapid.T) Message { numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords") if numRecords > 0 { - msg.CustomRecords, _ = RandCustomRecords(t, nil, true) + msg.CustomRecords, _ = RandCustomRecords(t, nil) } // 50/50 chance to add a blinding point @@ -1744,7 +1766,7 @@ func (c *UpdateFulfillHTLC) RandTestMessage(t *rapid.T) Message { PaymentPreimage: RandPaymentPreimage(t), } - cr, ignoreRecords := RandCustomRecords(t, nil, true) + cr, ignoreRecords := RandCustomRecords(t, nil) msg.CustomRecords = cr randData := RandExtraOpaqueData(t, ignoreRecords) diff --git a/lnwire/test_utils.go b/lnwire/test_utils.go index 07c9d795b..4c88687d1 100644 --- a/lnwire/test_utils.go +++ b/lnwire/test_utils.go @@ -198,23 +198,37 @@ func RandNetAddrs(t *rapid.T) []net.Addr { } // RandCustomRecords generates random custom TLV records. -func RandCustomRecords(t *rapid.T, - ignoreRecords fn.Set[uint64], - custom bool) (CustomRecords, fn.Set[uint64]) { +func RandCustomRecords(t *rapid.T, ignoreRecords fn.Set[uint64]) (CustomRecords, + fn.Set[uint64]) { - numRecords := rapid.IntRange(0, 5).Draw(t, "numCustomRecords") + customRecords, set := RandTLVRecords( + t, ignoreRecords, MinCustomRecordsTlvType, + ) + + // Validate the custom records as a sanity check. + require.NoError(t, customRecords.Validate()) + + return customRecords, set +} + +// RandSignedRangeRecords generates a random set of signed records in the +// second "signed" tlv range for pure TLV messages. +func RandSignedRangeRecords(t *rapid.T) (CustomRecords, fn.Set[uint64]) { + return RandTLVRecords(t, nil, pureTLVSignedSecondRangeStart) +} + +// RandTLVRecords generates custom TLV records. +func RandTLVRecords(t *rapid.T, ignoreRecords fn.Set[uint64], + rangeStart int) (CustomRecords, fn.Set[uint64]) { + + numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords") customRecords := make(CustomRecords) if numRecords == 0 { return nil, nil } - rangeStart := 0 - rangeStop := int(CustomTypeStart) - if custom { - rangeStart = 70_000 - rangeStop = 100_000 - } + rangeStop := rangeStart + 30_000 ignoreSet := fn.NewSet[uint64]() for i := 0; i < numRecords; i++ { @@ -258,7 +272,7 @@ func RandExtraOpaqueData(t *rapid.T, ignoreRecords fn.Set[uint64]) ExtraOpaqueData { // Make some random records. - cRecords, _ := RandCustomRecords(t, ignoreRecords, false) + cRecords, _ := RandTLVRecords(t, ignoreRecords, 0) if cRecords == nil { return ExtraOpaqueData{} } diff --git a/netann/channel_announcement.go b/netann/channel_announcement.go index 83ee55db6..9bb21c401 100644 --- a/netann/channel_announcement.go +++ b/netann/channel_announcement.go @@ -7,7 +7,9 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" + "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" @@ -104,7 +106,8 @@ func CreateChanAnnouncement(chanProof *models.ChannelAuthProof, // FetchPkScript defines a function that can be used to fetch the output script // for the transaction with the given SCID. -type FetchPkScript func(*lnwire.ShortChannelID) ([]byte, error) +type FetchPkScript func(lnwire.ShortChannelID) (txscript.ScriptClass, + btcutil.Address, error) // ValidateChannelAnn validates the channel announcement. func ValidateChannelAnn(a lnwire.ChannelAnnouncement, @@ -198,24 +201,124 @@ func validateChannelAnn1(a *lnwire.ChannelAnnouncement1) error { func validateChannelAnn2(a *lnwire.ChannelAnnouncement2, fetchPkScript FetchPkScript) error { + // Next, we fetch the funding transaction's PK script. We need this so + // that we know what type of channel we will be validating: P2WSH or + // P2TR. + scriptClass, scriptAddr, err := fetchPkScript(a.ShortChannelID.Val) + if err != nil { + return err + } + + var keys []*btcec.PublicKey + + switch scriptClass { + case txscript.WitnessV0ScriptHashTy: + keys, err = chanAnn2P2WSHMuSig2Keys(a) + if err != nil { + return err + } + case txscript.WitnessV1TaprootTy: + keys, err = chanAnn2P2TRMuSig2Keys(a, scriptAddr) + if err != nil { + return err + } + default: + return fmt.Errorf("invalid on-chain pk script type for "+ + "channel_announcement_2: %s", scriptClass) + } + + // Do a MuSig2 aggregation of the keys to obtain the aggregate key that + // the signature will be validated against. + aggKey, _, _, err := musig2.AggregateKeys(keys, true) + if err != nil { + return err + } + + // Get the message that the signature should have signed. dataHash, err := ChanAnn2DigestToSign(a) if err != nil { return err } - sig, err := a.Signature.ToSignature() + // Obtain the signature. + sig, err := a.Signature.Val.ToSignature() if err != nil { return err } + // Check that the signature is valid for the aggregate key given the + // message digest. + if !sig.Verify(dataHash.CloneBytes(), aggKey.FinalKey) { + return fmt.Errorf("invalid sig") + } + + return nil +} + +// chanAnn2P2WSHMuSig2Keys returns the set of keys that should be used to +// construct the aggregate key that the signature in an +// lnwire.ChannelAnnouncement2 message should be verified against in the case +// where the channel being announced is a P2WSH channel. +func chanAnn2P2WSHMuSig2Keys(a *lnwire.ChannelAnnouncement2) ( + []*btcec.PublicKey, error) { + nodeKey1, err := btcec.ParsePubKey(a.NodeID1.Val[:]) if err != nil { - return err + return nil, err } nodeKey2, err := btcec.ParsePubKey(a.NodeID2.Val[:]) if err != nil { - return err + return nil, err + } + + btcKeyMissingErrString := "bitcoin key %d missing for announcement " + + "of a P2WSH channel" + + btcKey1Bytes, err := a.BitcoinKey1.UnwrapOrErr( + fmt.Errorf(btcKeyMissingErrString, 1), + ) + if err != nil { + return nil, err + } + + btcKey1, err := btcec.ParsePubKey(btcKey1Bytes.Val[:]) + if err != nil { + return nil, err + } + + btcKey2Bytes, err := a.BitcoinKey2.UnwrapOrErr( + fmt.Errorf(btcKeyMissingErrString, 2), + ) + if err != nil { + return nil, err + } + + btcKey2, err := btcec.ParsePubKey(btcKey2Bytes.Val[:]) + if err != nil { + return nil, err + } + + return []*btcec.PublicKey{ + nodeKey1, nodeKey2, btcKey1, btcKey2, + }, nil +} + +// chanAnn2P2TRMuSig2Keys returns the set of keys that should be used to +// construct the aggregate key that the signature in an +// lnwire.ChannelAnnouncement2 message should be verified against in the case +// where the channel being announced is a P2TR channel. +func chanAnn2P2TRMuSig2Keys(a *lnwire.ChannelAnnouncement2, + scriptAddr btcutil.Address) ([]*btcec.PublicKey, error) { + + nodeKey1, err := btcec.ParsePubKey(a.NodeID1.Val[:]) + if err != nil { + return nil, err + } + + nodeKey2, err := btcec.ParsePubKey(a.NodeID2.Val[:]) + if err != nil { + return nil, err } keys := []*btcec.PublicKey{ @@ -236,49 +339,36 @@ func validateChannelAnn2(a *lnwire.ChannelAnnouncement2, bitcoinKey1, err := btcec.ParsePubKey(btcKey1.Val[:]) if err != nil { - return err + return nil, err } bitcoinKey2, err := btcec.ParsePubKey(btcKey2.Val[:]) if err != nil { - return err + return nil, err } keys = append(keys, bitcoinKey1, bitcoinKey2) } else { - // If bitcoin keys are not provided, then we need to get the - // on-chain output key since this will be the 3rd key in the - // 3-of-3 MuSig2 signature. - pkScript, err := fetchPkScript(&a.ShortChannelID.Val) + // If bitcoin keys are not provided, then the on-chain output + // key is considered the 3rd key in the 3-of-3 MuSig2 signature. + outputKey, err := schnorr.ParsePubKey( + scriptAddr.ScriptAddress(), + ) if err != nil { - return err - } - - outputKey, err := schnorr.ParsePubKey(pkScript[2:]) - if err != nil { - return err + return nil, err } keys = append(keys, outputKey) } - aggKey, _, _, err := musig2.AggregateKeys(keys, true) - if err != nil { - return err - } - - if !sig.Verify(dataHash.CloneBytes(), aggKey.FinalKey) { - return fmt.Errorf("invalid sig") - } - - return nil + return keys, nil } // ChanAnn2DigestToSign computes the digest of the message to be signed. func ChanAnn2DigestToSign(a *lnwire.ChannelAnnouncement2) (*chainhash.Hash, error) { - data, err := a.DataToSign() + data, err := lnwire.SerialiseFieldsToSign(a) if err != nil { return nil, err } diff --git a/netann/channel_announcement_test.go b/netann/channel_announcement_test.go index fbd2b2d2b..38949e046 100644 --- a/netann/channel_announcement_test.go +++ b/netann/channel_announcement_test.go @@ -9,6 +9,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/input" @@ -73,20 +74,25 @@ func TestChanAnnounce2Validation(t *testing.T) { t.Parallel() t.Run( - "test 4-of-4 MuSig2 channel announcement", - test4of4MuSig2ChanAnnouncement, + "test 4-of-4 MuSig2 P2TR channel announcement", + test4of4MuSig2P2TRChanAnnouncement, ) t.Run( - "test 3-of-3 MuSig2 channel announcement", + "test 3-of-3 MuSig2 P2TR channel announcement", test3of3MuSig2ChanAnnouncement, ) + + t.Run( + "test 4-of-4 MuSig2 P2WSH channel announcement", + test4of4MuSig2P2WSHChanAnnouncement, + ) } -// test4of4MuSig2ChanAnnouncement covers the case where both bitcoin keys are -// present in the channel announcement. In this case, the signature should be -// a 4-of-4 MuSig2. -func test4of4MuSig2ChanAnnouncement(t *testing.T) { +// test4of4MuSig2P2TRChanAnnouncement covers the case where the funding +// transaction PK script is a P2WSH. In this case, the signature should be valid +// for the MuSig2 4-of-4 aggregation of the node keys and the bitcoin keys. +func test4of4MuSig2P2WSHChanAnnouncement(t *testing.T) { t.Parallel() // Generate the keys for node 1 and node2. @@ -159,10 +165,138 @@ func test4of4MuSig2ChanAnnouncement(t *testing.T) { sig, err := lnwire.NewSigFromSignature(s) require.NoError(t, err) - ann.Signature = sig + ann.Signature.Val = sig + + // Create an accurate representation of what the on-chain pk script will + // look like. For this case, it is only important that we get the + // correct script class. + multiSigScript, err := input.GenMultiSigScript( + node1.btcPub.SerializeCompressed(), + node2.btcPub.SerializeCompressed(), + ) + require.NoError(t, err) + + scriptHash, err := input.WitnessScriptHash(multiSigScript) + require.NoError(t, err) + pkAddr, err := btcutil.NewAddressScriptHash( + scriptHash, &chaincfg.MainNetParams, + ) + require.NoError(t, err) + + // Create a mock tx fetcher that returns the expected script class and + // pk address. + fetchTx := func(lnwire.ShortChannelID) (txscript.ScriptClass, + btcutil.Address, error) { + + return txscript.WitnessV0ScriptHashTy, pkAddr, nil + } // Validate the announcement. - require.NoError(t, ValidateChannelAnn(ann, nil)) + require.NoError(t, ValidateChannelAnn(ann, fetchTx)) +} + +// test4of4MuSig2P2TRChanAnnouncement covers the case where both bitcoin keys +// are present in the channel announcement 2 and the funding transaction PK +// script is a P2TR. In this case, the signature should be a 4-of-4 MuSig2. +func test4of4MuSig2P2TRChanAnnouncement(t *testing.T) { + t.Parallel() + + // Generate the keys for node 1 and node2. + node1, node2 := genChanAnnKeys(t) + + // Build the unsigned channel announcement. + ann := buildUnsignedChanAnnouncement(node1, node2, true) + + // Serialise the bytes that need to be signed. + msg, err := ChanAnn2DigestToSign(ann) + require.NoError(t, err) + + var msgBytes [32]byte + copy(msgBytes[:], msg.CloneBytes()) + + // Generate the 4 nonces required for producing the signature. + var ( + node1NodeNonce = genNonceForPubKey(t, node1.nodePub) + node1BtcNonce = genNonceForPubKey(t, node1.btcPub) + node2NodeNonce = genNonceForPubKey(t, node2.nodePub) + node2BtcNonce = genNonceForPubKey(t, node2.btcPub) + ) + + nonceAgg, err := musig2.AggregateNonces([][66]byte{ + node1NodeNonce.PubNonce, + node1BtcNonce.PubNonce, + node2NodeNonce.PubNonce, + node2BtcNonce.PubNonce, + }) + require.NoError(t, err) + + pubKeys := []*btcec.PublicKey{ + node1.nodePub, node2.nodePub, node1.btcPub, node2.btcPub, + } + + // Let Node1 sign the announcement message with its node key. + psA1, err := musig2.Sign( + node1NodeNonce.SecNonce, node1.nodePriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Let Node1 sign the announcement message with its bitcoin key. + psA2, err := musig2.Sign( + node1BtcNonce.SecNonce, node1.btcPriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Let Node2 sign the announcement message with its node key. + psB1, err := musig2.Sign( + node2NodeNonce.SecNonce, node2.nodePriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Let Node2 sign the announcement message with its bitcoin key. + psB2, err := musig2.Sign( + node2BtcNonce.SecNonce, node2.btcPriv, nonceAgg, pubKeys, + msgBytes, musig2.WithSortedKeys(), + ) + require.NoError(t, err) + + // Finally, combine the partial signatures from Node1 and Node2 and add + // the signature to the announcement message. + s := musig2.CombineSigs(psA1.R, []*musig2.PartialSignature{ + psA1, psA2, psB1, psB2, + }) + + sig, err := lnwire.NewSigFromSignature(s) + require.NoError(t, err) + + ann.Signature.Val = sig + + // Create an accurate representation of what the on-chain pk script will + // look like. For this case, it is only important that we get the + // correct script class. + combinedKey, _, _, err := musig2.AggregateKeys( + []*btcec.PublicKey{node1.btcPub, node2.btcPub}, true, + ) + require.NoError(t, err) + + pkAddr, err := btcutil.NewAddressTaproot( + combinedKey.FinalKey.SerializeCompressed()[1:], + &chaincfg.MainNetParams, + ) + require.NoError(t, err) + + // Create a mock tx fetcher that returns the expected script class and + // pk address. + fetchTx := func(lnwire.ShortChannelID) (txscript.ScriptClass, + btcutil.Address, error) { + + return txscript.WitnessV1TaprootTy, pkAddr, nil + } + + // Validate the announcement. + require.NoError(t, ValidateChannelAnn(ann, fetchTx)) } // test3of3MuSig2ChanAnnouncement covers the case where no bitcoin keys are @@ -217,14 +351,17 @@ func test3of3MuSig2ChanAnnouncement(t *testing.T) { }) require.NoError(t, err) - pkScript, err := input.PayToTaprootScript(outputKey) + pkAddr, err := btcutil.NewAddressTaproot( + outputKey.SerializeCompressed()[1:], &chaincfg.MainNetParams, + ) require.NoError(t, err) - // We'll pass in a mock tx fetcher that will return the funding output - // containing this key. This is needed since the output key can not be - // determined from the channel announcement itself. - fetchTx := func(chanID *lnwire.ShortChannelID) ([]byte, error) { - return pkScript, nil + // Create a mock tx fetcher that returns the expected script class + // and pk address. + fetchTx := func(lnwire.ShortChannelID) (txscript.ScriptClass, + btcutil.Address, error) { + + return txscript.WitnessV1TaprootTy, pkAddr, nil } pubKeys := []*btcec.PublicKey{node1.nodePub, node2.nodePub, outputKey} @@ -259,7 +396,7 @@ func test3of3MuSig2ChanAnnouncement(t *testing.T) { sig, err := lnwire.NewSigFromSignature(s) require.NoError(t, err) - ann.Signature = sig + ann.Signature.Val = sig // Validate the announcement. require.NoError(t, ValidateChannelAnn(ann, fetchTx)) diff --git a/netann/channel_update.go b/netann/channel_update.go index c0adc81a2..7e87fd77b 100644 --- a/netann/channel_update.go +++ b/netann/channel_update.go @@ -242,7 +242,7 @@ func verifyChannelUpdate2Signature(c *lnwire.ChannelUpdate2, return fmt.Errorf("unable to reconstruct message data: %w", err) } - nodeSig, err := c.Signature.ToSignature() + nodeSig, err := c.Signature.Val.ToSignature() if err != nil { return err } @@ -330,7 +330,7 @@ func ChanUpdate2DigestTag() []byte { // chanUpdate2DigestToSign computes the digest of the ChannelUpdate2 message to // be signed. func chanUpdate2DigestToSign(c *lnwire.ChannelUpdate2) ([]byte, error) { - data, err := c.DataToSign() + data, err := lnwire.SerialiseFieldsToSign(c) if err != nil { return nil, err } diff --git a/server.go b/server.go index 83dd9a4d4..44be180ea 100644 --- a/server.go +++ b/server.go @@ -1063,7 +1063,7 @@ func newServer(ctx context.Context, cfg *Config, listenAddrs []net.Addr, Graph: s.graphBuilder, ChainIO: s.cc.ChainIO, Notifier: s.cc.ChainNotifier, - ChainHash: *s.cfg.ActiveNetParams.GenesisHash, + ChainParams: s.cfg.ActiveNetParams.Params, Broadcast: s.BroadcastMessage, ChanSeries: chanSeries, NotifyWhenOnline: s.NotifyWhenOnline,