lnwire: add new closing_complete and closing_sig messages

These two messages will be used to implement the new and improved co-op
closing protocol. This PR also show cases how to use the new
`tlv.OptionalRecord` type to define and handle TLV level parsing.

I think we can make one additional helper function to clean up some of
the boiler plate for the encode/decode.
This commit is contained in:
Olaoluwa Osuntokun
2024-01-02 18:45:39 -08:00
parent 34fd35bc63
commit 3d88017b38
4 changed files with 356 additions and 0 deletions

156
lnwire/closing_complete.go Normal file
View File

@@ -0,0 +1,156 @@
package lnwire
import (
"bytes"
"io"
"github.com/btcsuite/btcd/btcutil"
"github.com/lightningnetwork/lnd/tlv"
)
// ClosingSigs houses the 3 possible signatures that can be sent when
// attempting to complete a cooperative channel closure. A signature will
// either include both outputs, or only one of the outputs from either side.
type ClosingSigs struct {
// CloserNoClosee is a signature that excludes the output of the
// clsoee.
CloserNoClosee tlv.OptionalRecordT[tlv.TlvType1, Sig]
// NoCloserClosee is a signature that excludes the output of the
// closer.
NoCloserClosee tlv.OptionalRecordT[tlv.TlvType2, Sig]
// CloserAndClosee is a signature that includes both outputs.
CloserAndClosee tlv.OptionalRecordT[tlv.TlvType3, Sig]
}
// ClosingComplete is sent by either side to kick off the process of obtaining
// a valid signature on a c o-operative channel closure of their choice.
type ClosingComplete struct {
// ChannelID serves to identify which channel is to be closed.
ChannelID ChannelID
// FeeSatoshis is the total fee in satoshis that the party to the
// channel would like to propose for the close transaction.
FeeSatoshis btcutil.Amount
// Sequence is the sequence number to be used in the input spending the
// funding transaction.
Sequence uint32
// ClosingSigs houses the 3 possible signatures that can be sent.
ClosingSigs
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// decodeClosingSigs decodes the closing sig TLV records in the passed
// ExtraOpaqueData.
func decodeClosingSigs(c *ClosingSigs, tlvRecords ExtraOpaqueData) error {
sig1 := tlv.ZeroRecordT[tlv.TlvType1, Sig]()
sig2 := tlv.ZeroRecordT[tlv.TlvType2, Sig]()
sig3 := tlv.ZeroRecordT[tlv.TlvType3, Sig]()
typeMap, err := tlvRecords.ExtractRecords(&sig1, &sig2, &sig3)
if err != nil {
return err
}
// TODO(roasbeef): helper func to made decode of the optional vals
// easier?
if val, ok := typeMap[c.CloserNoClosee.TlvType()]; ok && val == nil {
c.CloserNoClosee = tlv.SomeRecordT(sig1)
}
if val, ok := typeMap[c.NoCloserClosee.TlvType()]; ok && val == nil {
c.NoCloserClosee = tlv.SomeRecordT(sig2)
}
if val, ok := typeMap[c.CloserAndClosee.TlvType()]; ok && val == nil {
c.CloserAndClosee = tlv.SomeRecordT(sig3)
}
return nil
}
// Decode deserializes a serialized ClosingComplete message stored in the
// passed io.Reader.
func (c *ClosingComplete) Decode(r io.Reader, _ uint32) error {
// First, read out all the fields that are hard coded into the message.
err := ReadElements(r, &c.ChannelID, &c.FeeSatoshis, &c.Sequence)
if err != nil {
return err
}
// With the hard coded messages read, we'll now read out the TLV fields
// of the message.
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
if err := decodeClosingSigs(&c.ClosingSigs, tlvRecords); err != nil {
return err
}
if len(tlvRecords) != 0 {
c.ExtraData = tlvRecords
}
return nil
}
// closingSigRecords returns the set of records that encode the closing sigs,
// if present.
func closingSigRecords(c *ClosingSigs) []tlv.RecordProducer {
recordProducers := make([]tlv.RecordProducer, 0, 3)
c.CloserNoClosee.WhenSome(func(sig tlv.RecordT[tlv.TlvType1, Sig]) {
recordProducers = append(recordProducers, &sig)
})
c.NoCloserClosee.WhenSome(func(sig tlv.RecordT[tlv.TlvType2, Sig]) {
recordProducers = append(recordProducers, &sig)
})
c.CloserAndClosee.WhenSome(func(sig tlv.RecordT[tlv.TlvType3, Sig]) {
recordProducers = append(recordProducers, &sig)
})
return recordProducers
}
// Encode serializes the target ClosingComplete into the passed io.Writer.
func (c *ClosingComplete) Encode(w *bytes.Buffer, _ uint32) error {
if err := WriteChannelID(w, c.ChannelID); err != nil {
return err
}
if err := WriteSatoshi(w, c.FeeSatoshis); err != nil {
return err
}
if err := WriteUint32(w, c.Sequence); err != nil {
return err
}
recordProducers := closingSigRecords(&c.ClosingSigs)
err := EncodeMessageExtraData(&c.ExtraData, recordProducers...)
if err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the uint32 code which uniquely identifies this message as a
// ClosingComplete message on the wire.
//
// This is part of the lnwire.Message interface.
func (c *ClosingComplete) MsgType() MessageType {
return MsgClosingComplete
}
// A compile time check to ensure ClosingComplete implements the lnwire.Message
// interface.
var _ Message = (*ClosingComplete)(nil)

76
lnwire/closing_sig.go Normal file
View File

@@ -0,0 +1,76 @@
package lnwire
import (
"bytes"
"io"
)
// ClosingSig is sent in response to a ClosingComplete message. It carries the
// signatures of the closee to the closer.
type ClosingSig struct {
// ChannelID serves to identify which channel is to be closed.
ChannelID ChannelID
// ClosingSigs houses the 3 possible signatures that can be sent.
ClosingSigs
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// Decode deserializes a serialized ClosingSig message stored in the passed
// io.Reader.
func (c *ClosingSig) Decode(r io.Reader, _ uint32) error {
// First, read out all the fields that are hard coded into the message.
err := ReadElements(r, &c.ChannelID)
if err != nil {
return err
}
// With the hard coded messages read, we'll now read out the TLV fields
// of the message.
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
if err := decodeClosingSigs(&c.ClosingSigs, tlvRecords); err != nil {
return err
}
if len(tlvRecords) != 0 {
c.ExtraData = tlvRecords
}
return nil
}
// Encode serializes the target ClosingSig into the passed io.Writer.
func (c *ClosingSig) Encode(w *bytes.Buffer, _ uint32) error {
if err := WriteChannelID(w, c.ChannelID); err != nil {
return err
}
recordProducers := closingSigRecords(&c.ClosingSigs)
err := EncodeMessageExtraData(&c.ExtraData, recordProducers...)
if err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}
// MsgType returns the uint32 code which uniquely identifies this message as a
// ClosingSig message on the wire.
//
// This is part of the lnwire.Message interface.
func (c *ClosingSig) MsgType() MessageType {
return MsgClosingSig
}
// A compile time check to ensure ClosingSig implements the lnwire.Message
// interface.
var _ Message = (*ClosingSig)(nil)

View File

@@ -22,6 +22,7 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/tlv"
"github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/tor"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -1212,6 +1213,107 @@ func TestLightningWireProtocol(t *testing.T) {
PaddingBytes: paddingBytes, PaddingBytes: paddingBytes,
} }
v[0] = reflect.ValueOf(req)
},
MsgClosingComplete: func(v []reflect.Value, r *rand.Rand) {
var c [32]byte
_, err := r.Read(c[:])
if err != nil {
t.Fatalf("unable to generate chan id: %v",
err)
return
}
req := ClosingComplete{
ChannelID: ChannelID(c),
FeeSatoshis: btcutil.Amount(r.Int63()),
Sequence: uint32(r.Int63()),
ClosingSigs: ClosingSigs{},
}
if r.Intn(2) == 0 {
sig := tlv.ZeroRecordT[tlv.TlvType1, Sig]()
_, err := r.Read(sig.Val.bytes[:])
if err != nil {
t.Fatalf("unable to generate sig: %v",
err)
return
}
req.CloserNoClosee = tlv.SomeRecordT(sig)
}
if r.Intn(2) == 0 {
sig := tlv.ZeroRecordT[tlv.TlvType2, Sig]()
_, err := r.Read(sig.Val.bytes[:])
if err != nil {
t.Fatalf("unable to generate sig: %v",
err)
return
}
req.NoCloserClosee = tlv.SomeRecordT(sig)
}
if r.Intn(2) == 0 {
sig := tlv.ZeroRecordT[tlv.TlvType3, Sig]()
_, err := r.Read(sig.Val.bytes[:])
if err != nil {
t.Fatalf("unable to generate sig: %v",
err)
return
}
req.CloserAndClosee = tlv.SomeRecordT(sig)
}
v[0] = reflect.ValueOf(req)
},
MsgClosingSig: func(v []reflect.Value, r *rand.Rand) {
var c [32]byte
_, err := r.Read(c[:])
if err != nil {
t.Fatalf("unable to generate chan id: %v", err)
return
}
req := ClosingSig{
ChannelID: ChannelID(c),
ClosingSigs: ClosingSigs{},
}
if r.Intn(2) == 0 {
sig := tlv.ZeroRecordT[tlv.TlvType1, Sig]()
_, err := r.Read(sig.Val.bytes[:])
if err != nil {
t.Fatalf("unable to generate sig: %v",
err)
return
}
req.CloserNoClosee = tlv.SomeRecordT(sig)
}
if r.Intn(2) == 0 {
sig := tlv.ZeroRecordT[tlv.TlvType2, Sig]()
_, err := r.Read(sig.Val.bytes[:])
if err != nil {
t.Fatalf("unable to generate sig: %v",
err)
return
}
req.NoCloserClosee = tlv.SomeRecordT(sig)
}
if r.Intn(2) == 0 {
sig := tlv.ZeroRecordT[tlv.TlvType3, Sig]()
_, err := r.Read(sig.Val.bytes[:])
if err != nil {
t.Fatalf("unable to generate sig: %v",
err)
return
}
req.CloserAndClosee = tlv.SomeRecordT(sig)
}
v[0] = reflect.ValueOf(req) v[0] = reflect.ValueOf(req)
}, },
} }
@@ -1424,6 +1526,18 @@ func TestLightningWireProtocol(t *testing.T) {
return mainScenario(&m) return mainScenario(&m)
}, },
}, },
{
msgType: MsgClosingComplete,
scenario: func(m ClosingComplete) bool {
return mainScenario(&m)
},
},
{
msgType: MsgClosingSig,
scenario: func(m ClosingSig) bool {
return mainScenario(&m)
},
},
} }
for _, test := range tests { for _, test := range tests {
var config *quick.Config var config *quick.Config

View File

@@ -34,6 +34,8 @@ const (
MsgChannelReady = 36 MsgChannelReady = 36
MsgShutdown = 38 MsgShutdown = 38
MsgClosingSigned = 39 MsgClosingSigned = 39
MsgClosingComplete = 40
MsgClosingSig = 41
MsgDynPropose = 111 MsgDynPropose = 111
MsgDynAck = 113 MsgDynAck = 113
MsgDynReject = 115 MsgDynReject = 115
@@ -146,6 +148,10 @@ func (t MessageType) String() string {
return "ReplyChannelRange" return "ReplyChannelRange"
case MsgGossipTimestampRange: case MsgGossipTimestampRange:
return "GossipTimestampRange" return "GossipTimestampRange"
case MsgClosingComplete:
return "ClosingComplete"
case MsgClosingSig:
return "ClosingSig"
default: default:
return "<unknown>" return "<unknown>"
} }
@@ -256,6 +262,10 @@ func makeEmptyMessage(msgType MessageType) (Message, error) {
msg = &ReplyChannelRange{} msg = &ReplyChannelRange{}
case MsgGossipTimestampRange: case MsgGossipTimestampRange:
msg = &GossipTimestampRange{} msg = &GossipTimestampRange{}
case MsgClosingComplete:
msg = &ClosingComplete{}
case MsgClosingSig:
msg = &ClosingSig{}
default: default:
// If the message is not within our custom range and has not // If the message is not within our custom range and has not
// specifically been overridden, return an unknown message. // specifically been overridden, return an unknown message.