diff --git a/lnwire/closing_signed.go b/lnwire/closing_signed.go index 8e11c8699..3e3651964 100644 --- a/lnwire/closing_signed.go +++ b/lnwire/closing_signed.go @@ -5,6 +5,7 @@ import ( "io" "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/tlv" ) // ClosingSigned is sent by both parties to a channel once the channel is clear @@ -29,6 +30,14 @@ type ClosingSigned struct { // Signature is for the proposed channel close transaction. Signature Sig + // PartialSig is used to transmit a musig2 extended partial signature + // that signs the latest fee offer. The nonce isn't sent along side, as + // that has already been sent in the initial shutdown message. + // + // NOTE: This field is only populated if a musig2 taproot channel is + // being signed for. In this case, the above Sig type MUST be blank. + PartialSig *PartialSig + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -55,9 +64,36 @@ var _ Message = (*ClosingSigned)(nil) // // This is part of the lnwire.Message interface. func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error { - return ReadElements( - r, &c.ChannelID, &c.FeeSatoshis, &c.Signature, &c.ExtraData, + err := ReadElements( + r, &c.ChannelID, &c.FeeSatoshis, &c.Signature, ) + if err != nil { + return err + } + + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + var ( + partialSig PartialSig + ) + typeMap, err := tlvRecords.ExtractRecords(&partialSig) + if err != nil { + return err + } + + // Set the corresponding TLV types if they were included in the stream. + if val, ok := typeMap[PartialSigRecordType]; ok && val == nil { + c.PartialSig = &partialSig + } + + if len(tlvRecords) != 0 { + c.ExtraData = tlvRecords + } + + return nil } // Encode serializes the target ClosingSigned into the passed io.Writer @@ -65,6 +101,15 @@ func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (c *ClosingSigned) Encode(w *bytes.Buffer, pver uint32) error { + recordProducers := make([]tlv.RecordProducer, 0, 1) + if c.PartialSig != nil { + recordProducers = append(recordProducers, c.PartialSig) + } + err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) + if err != nil { + return err + } + if err := WriteChannelID(w, c.ChannelID); err != nil { return err } diff --git a/lnwire/shutdown.go b/lnwire/shutdown.go index 2adb6a082..93cc50e39 100644 --- a/lnwire/shutdown.go +++ b/lnwire/shutdown.go @@ -3,8 +3,56 @@ package lnwire import ( "bytes" "io" + + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" + "github.com/lightningnetwork/lnd/tlv" ) +const ( + // ShutdownNonceRecordType is the type of the shutdown nonce TLV record. + ShutdownNonceRecordType = 8 +) + +// ShutdownNonce is the type of the nonce we send during the shutdown flow. +// Unlike the other nonces, this nonce is symmetric w.r.t the message being +// signed (there's only one message for shutdown: the co-op close txn). +type ShutdownNonce Musig2Nonce + +// Record returns a TLV record that can be used to encode/decode the musig2 +// nonce from a given TLV stream. +func (s *ShutdownNonce) Record() tlv.Record { + return tlv.MakeStaticRecord( + ShutdownNonceRecordType, s, musig2.PubNonceSize, + shutdownNonceTypeEncoder, shutdownNonceTypeDecoder, + ) +} + +// shutdownNonceTypeEncoder is a custom TLV encoder for the Musig2Nonce type. +func shutdownNonceTypeEncoder(w io.Writer, val interface{}, + buf *[8]byte) error { + + if v, ok := val.(*ShutdownNonce); ok { + _, err := w.Write(v[:]) + return err + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.Musig2Nonce") +} + +// shutdownNonceTypeDecoder is a custom TLV decoder for the Musig2Nonce record. +func shutdownNonceTypeDecoder(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if v, ok := val.(*ShutdownNonce); ok { + _, err := io.ReadFull(r, v[:]) + return err + } + + return tlv.NewTypeForDecodingErr( + val, "lnwire.ShutdownNonce", l, musig2.PubNonceSize, + ) +} + // Shutdown is sent by either side in order to initiate the cooperative closure // of a channel. This message is sparse as both sides implicitly have the // information necessary to construct a transaction that will send the settled @@ -17,6 +65,10 @@ type Shutdown struct { // Address is the script to which the channel funds will be paid. Address DeliveryAddress + // ShutdownNonce is the nonce the sender will use to sign the first + // co-op sign offer. + ShutdownNonce *ShutdownNonce + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -40,7 +92,32 @@ var _ Message = (*Shutdown)(nil) // // This is part of the lnwire.Message interface. func (s *Shutdown) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, &s.ChannelID, &s.Address, &s.ExtraData) + err := ReadElements(r, &s.ChannelID, &s.Address) + if err != nil { + return err + } + + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + var musigNonce ShutdownNonce + typeMap, err := tlvRecords.ExtractRecords(&musigNonce) + if err != nil { + return err + } + + // Set the corresponding TLV types if they were included in the stream. + if val, ok := typeMap[ShutdownNonceRecordType]; ok && val == nil { + s.ShutdownNonce = &musigNonce + } + + if len(tlvRecords) != 0 { + s.ExtraData = tlvRecords + } + + return nil } // Encode serializes the target Shutdown into the passed io.Writer observing @@ -48,6 +125,15 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error { + recordProducers := make([]tlv.RecordProducer, 0, 1) + if s.ShutdownNonce != nil { + recordProducers = append(recordProducers, s.ShutdownNonce) + } + err := EncodeMessageExtraData(&s.ExtraData, recordProducers...) + if err != nil { + return err + } + if err := WriteChannelID(w, s.ChannelID); err != nil { return err }