From 3d88017b382f8fc411a2347da7620c9d7370e0d0 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Tue, 2 Jan 2024 18:45:39 -0800 Subject: [PATCH] lnwire: add new closing_complete and closing_sig messages These two messages will be used to implement the new and improved co-op closing protocol. This PR also show cases how to use the new `tlv.OptionalRecord` type to define and handle TLV level parsing. I think we can make one additional helper function to clean up some of the boiler plate for the encode/decode. --- lnwire/closing_complete.go | 156 +++++++++++++++++++++++++++++++++++++ lnwire/closing_sig.go | 76 ++++++++++++++++++ lnwire/lnwire_test.go | 114 +++++++++++++++++++++++++++ lnwire/message.go | 10 +++ 4 files changed, 356 insertions(+) create mode 100644 lnwire/closing_complete.go create mode 100644 lnwire/closing_sig.go diff --git a/lnwire/closing_complete.go b/lnwire/closing_complete.go new file mode 100644 index 000000000..c3cd0cc4d --- /dev/null +++ b/lnwire/closing_complete.go @@ -0,0 +1,156 @@ +package lnwire + +import ( + "bytes" + "io" + + "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/tlv" +) + +// ClosingSigs houses the 3 possible signatures that can be sent when +// attempting to complete a cooperative channel closure. A signature will +// either include both outputs, or only one of the outputs from either side. +type ClosingSigs struct { + // CloserNoClosee is a signature that excludes the output of the + // clsoee. + CloserNoClosee tlv.OptionalRecordT[tlv.TlvType1, Sig] + + // NoCloserClosee is a signature that excludes the output of the + // closer. + NoCloserClosee tlv.OptionalRecordT[tlv.TlvType2, Sig] + + // CloserAndClosee is a signature that includes both outputs. + CloserAndClosee tlv.OptionalRecordT[tlv.TlvType3, Sig] +} + +// ClosingComplete is sent by either side to kick off the process of obtaining +// a valid signature on a c o-operative channel closure of their choice. +type ClosingComplete struct { + // ChannelID serves to identify which channel is to be closed. + ChannelID ChannelID + + // FeeSatoshis is the total fee in satoshis that the party to the + // channel would like to propose for the close transaction. + FeeSatoshis btcutil.Amount + + // Sequence is the sequence number to be used in the input spending the + // funding transaction. + Sequence uint32 + + // ClosingSigs houses the 3 possible signatures that can be sent. + ClosingSigs + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData +} + +// decodeClosingSigs decodes the closing sig TLV records in the passed +// ExtraOpaqueData. +func decodeClosingSigs(c *ClosingSigs, tlvRecords ExtraOpaqueData) error { + sig1 := tlv.ZeroRecordT[tlv.TlvType1, Sig]() + sig2 := tlv.ZeroRecordT[tlv.TlvType2, Sig]() + sig3 := tlv.ZeroRecordT[tlv.TlvType3, Sig]() + + typeMap, err := tlvRecords.ExtractRecords(&sig1, &sig2, &sig3) + if err != nil { + return err + } + + // TODO(roasbeef): helper func to made decode of the optional vals + // easier? + + if val, ok := typeMap[c.CloserNoClosee.TlvType()]; ok && val == nil { + c.CloserNoClosee = tlv.SomeRecordT(sig1) + } + if val, ok := typeMap[c.NoCloserClosee.TlvType()]; ok && val == nil { + c.NoCloserClosee = tlv.SomeRecordT(sig2) + } + if val, ok := typeMap[c.CloserAndClosee.TlvType()]; ok && val == nil { + c.CloserAndClosee = tlv.SomeRecordT(sig3) + } + + return nil +} + +// Decode deserializes a serialized ClosingComplete message stored in the +// passed io.Reader. +func (c *ClosingComplete) Decode(r io.Reader, _ uint32) error { + // First, read out all the fields that are hard coded into the message. + err := ReadElements(r, &c.ChannelID, &c.FeeSatoshis, &c.Sequence) + if err != nil { + return err + } + + // With the hard coded messages read, we'll now read out the TLV fields + // of the message. + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + if err := decodeClosingSigs(&c.ClosingSigs, tlvRecords); err != nil { + return err + } + + if len(tlvRecords) != 0 { + c.ExtraData = tlvRecords + } + + return nil +} + +// closingSigRecords returns the set of records that encode the closing sigs, +// if present. +func closingSigRecords(c *ClosingSigs) []tlv.RecordProducer { + recordProducers := make([]tlv.RecordProducer, 0, 3) + c.CloserNoClosee.WhenSome(func(sig tlv.RecordT[tlv.TlvType1, Sig]) { + recordProducers = append(recordProducers, &sig) + }) + c.NoCloserClosee.WhenSome(func(sig tlv.RecordT[tlv.TlvType2, Sig]) { + recordProducers = append(recordProducers, &sig) + }) + c.CloserAndClosee.WhenSome(func(sig tlv.RecordT[tlv.TlvType3, Sig]) { + recordProducers = append(recordProducers, &sig) + }) + + return recordProducers +} + +// Encode serializes the target ClosingComplete into the passed io.Writer. +func (c *ClosingComplete) Encode(w *bytes.Buffer, _ uint32) error { + if err := WriteChannelID(w, c.ChannelID); err != nil { + return err + } + + if err := WriteSatoshi(w, c.FeeSatoshis); err != nil { + return err + } + + if err := WriteUint32(w, c.Sequence); err != nil { + return err + } + + recordProducers := closingSigRecords(&c.ClosingSigs) + + err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) + if err != nil { + return err + } + + return WriteBytes(w, c.ExtraData) +} + +// MsgType returns the uint32 code which uniquely identifies this message as a +// ClosingComplete message on the wire. +// +// This is part of the lnwire.Message interface. +func (c *ClosingComplete) MsgType() MessageType { + return MsgClosingComplete +} + +// A compile time check to ensure ClosingComplete implements the lnwire.Message +// interface. +var _ Message = (*ClosingComplete)(nil) diff --git a/lnwire/closing_sig.go b/lnwire/closing_sig.go new file mode 100644 index 000000000..df160d12e --- /dev/null +++ b/lnwire/closing_sig.go @@ -0,0 +1,76 @@ +package lnwire + +import ( + "bytes" + "io" +) + +// ClosingSig is sent in response to a ClosingComplete message. It carries the +// signatures of the closee to the closer. +type ClosingSig struct { + // ChannelID serves to identify which channel is to be closed. + ChannelID ChannelID + + // ClosingSigs houses the 3 possible signatures that can be sent. + ClosingSigs + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData +} + +// Decode deserializes a serialized ClosingSig message stored in the passed +// io.Reader. +func (c *ClosingSig) Decode(r io.Reader, _ uint32) error { + // First, read out all the fields that are hard coded into the message. + err := ReadElements(r, &c.ChannelID) + if err != nil { + return err + } + + // With the hard coded messages read, we'll now read out the TLV fields + // of the message. + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + if err := decodeClosingSigs(&c.ClosingSigs, tlvRecords); err != nil { + return err + } + + if len(tlvRecords) != 0 { + c.ExtraData = tlvRecords + } + + return nil +} + +// Encode serializes the target ClosingSig into the passed io.Writer. +func (c *ClosingSig) Encode(w *bytes.Buffer, _ uint32) error { + if err := WriteChannelID(w, c.ChannelID); err != nil { + return err + } + + recordProducers := closingSigRecords(&c.ClosingSigs) + + err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) + if err != nil { + return err + } + + return WriteBytes(w, c.ExtraData) +} + +// MsgType returns the uint32 code which uniquely identifies this message as a +// ClosingSig message on the wire. +// +// This is part of the lnwire.Message interface. +func (c *ClosingSig) MsgType() MessageType { + return MsgClosingSig +} + +// A compile time check to ensure ClosingSig implements the lnwire.Message +// interface. +var _ Message = (*ClosingSig)(nil) diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index c3248c3c9..1246f92e7 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -22,6 +22,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tor" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1212,6 +1213,107 @@ func TestLightningWireProtocol(t *testing.T) { PaddingBytes: paddingBytes, } + v[0] = reflect.ValueOf(req) + }, + MsgClosingComplete: func(v []reflect.Value, r *rand.Rand) { + var c [32]byte + _, err := r.Read(c[:]) + if err != nil { + t.Fatalf("unable to generate chan id: %v", + err) + return + } + + req := ClosingComplete{ + ChannelID: ChannelID(c), + FeeSatoshis: btcutil.Amount(r.Int63()), + Sequence: uint32(r.Int63()), + ClosingSigs: ClosingSigs{}, + } + + if r.Intn(2) == 0 { + sig := tlv.ZeroRecordT[tlv.TlvType1, Sig]() + _, err := r.Read(sig.Val.bytes[:]) + if err != nil { + t.Fatalf("unable to generate sig: %v", + err) + return + } + + req.CloserNoClosee = tlv.SomeRecordT(sig) + } + if r.Intn(2) == 0 { + sig := tlv.ZeroRecordT[tlv.TlvType2, Sig]() + _, err := r.Read(sig.Val.bytes[:]) + if err != nil { + t.Fatalf("unable to generate sig: %v", + err) + return + } + + req.NoCloserClosee = tlv.SomeRecordT(sig) + } + if r.Intn(2) == 0 { + sig := tlv.ZeroRecordT[tlv.TlvType3, Sig]() + _, err := r.Read(sig.Val.bytes[:]) + if err != nil { + t.Fatalf("unable to generate sig: %v", + err) + return + } + + req.CloserAndClosee = tlv.SomeRecordT(sig) + } + + v[0] = reflect.ValueOf(req) + }, + MsgClosingSig: func(v []reflect.Value, r *rand.Rand) { + var c [32]byte + _, err := r.Read(c[:]) + if err != nil { + t.Fatalf("unable to generate chan id: %v", err) + return + } + + req := ClosingSig{ + ChannelID: ChannelID(c), + ClosingSigs: ClosingSigs{}, + } + + if r.Intn(2) == 0 { + sig := tlv.ZeroRecordT[tlv.TlvType1, Sig]() + _, err := r.Read(sig.Val.bytes[:]) + if err != nil { + t.Fatalf("unable to generate sig: %v", + err) + return + } + + req.CloserNoClosee = tlv.SomeRecordT(sig) + } + if r.Intn(2) == 0 { + sig := tlv.ZeroRecordT[tlv.TlvType2, Sig]() + _, err := r.Read(sig.Val.bytes[:]) + if err != nil { + t.Fatalf("unable to generate sig: %v", + err) + return + } + + req.NoCloserClosee = tlv.SomeRecordT(sig) + } + if r.Intn(2) == 0 { + sig := tlv.ZeroRecordT[tlv.TlvType3, Sig]() + _, err := r.Read(sig.Val.bytes[:]) + if err != nil { + t.Fatalf("unable to generate sig: %v", + err) + return + } + + req.CloserAndClosee = tlv.SomeRecordT(sig) + } + v[0] = reflect.ValueOf(req) }, } @@ -1424,6 +1526,18 @@ func TestLightningWireProtocol(t *testing.T) { return mainScenario(&m) }, }, + { + msgType: MsgClosingComplete, + scenario: func(m ClosingComplete) bool { + return mainScenario(&m) + }, + }, + { + msgType: MsgClosingSig, + scenario: func(m ClosingSig) bool { + return mainScenario(&m) + }, + }, } for _, test := range tests { var config *quick.Config diff --git a/lnwire/message.go b/lnwire/message.go index bc79ed003..634f78ab3 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -34,6 +34,8 @@ const ( MsgChannelReady = 36 MsgShutdown = 38 MsgClosingSigned = 39 + MsgClosingComplete = 40 + MsgClosingSig = 41 MsgDynPropose = 111 MsgDynAck = 113 MsgDynReject = 115 @@ -146,6 +148,10 @@ func (t MessageType) String() string { return "ReplyChannelRange" case MsgGossipTimestampRange: return "GossipTimestampRange" + case MsgClosingComplete: + return "ClosingComplete" + case MsgClosingSig: + return "ClosingSig" default: return "" } @@ -256,6 +262,10 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { msg = &ReplyChannelRange{} case MsgGossipTimestampRange: msg = &GossipTimestampRange{} + case MsgClosingComplete: + msg = &ClosingComplete{} + case MsgClosingSig: + msg = &ClosingSig{} default: // If the message is not within our custom range and has not // specifically been overridden, return an unknown message.