diff --git a/lnwire/accept_channel.go b/lnwire/accept_channel.go index f5786b918..2f7957ff6 100644 --- a/lnwire/accept_channel.go +++ b/lnwire/accept_channel.go @@ -2,11 +2,11 @@ package lnwire import ( "bytes" - "fmt" "io" "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcutil" + "github.com/lightningnetwork/lnd/tlv" ) // AcceptChannel is the message Bob sends to Alice after she initiates the @@ -95,6 +95,10 @@ type AcceptChannel struct { // and its length followed by the script will be written if it is set. UpfrontShutdownScript DeliveryAddress + // ChannelType is the explicit channel type the initiator wishes to + // open. + ChannelType *ChannelType + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -117,11 +121,11 @@ var _ Message = (*AcceptChannel)(nil) // // This is part of the lnwire.Message interface. func (a *AcceptChannel) Encode(w *bytes.Buffer, pver uint32) error { - // Since the upfront script is encoded as a TLV record, concatenate it - // with the ExtraData, and write them as one. - tlvRecords, err := packShutdownScript( - a.UpfrontShutdownScript, a.ExtraData, - ) + recordProducers := []tlv.RecordProducer{&a.UpfrontShutdownScript} + if a.ChannelType != nil { + recordProducers = append(recordProducers, a.ChannelType) + } + err := EncodeMessageExtraData(&a.ExtraData, recordProducers...) if err != nil { return err } @@ -182,7 +186,7 @@ func (a *AcceptChannel) Encode(w *bytes.Buffer, pver uint32) error { return err } - return WriteBytes(w, tlvRecords) + return WriteBytes(w, a.ExtraData) } // Decode deserializes the serialized AcceptChannel stored in the passed @@ -220,74 +224,26 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { return err } - a.UpfrontShutdownScript, a.ExtraData, err = parseShutdownScript( - tlvRecords, + // Next we'll parse out the set of known records, keeping the raw tlv + // bytes untouched to ensure we don't drop any bytes erroneously. + var chanType ChannelType + typeMap, err := tlvRecords.ExtractRecords( + &a.UpfrontShutdownScript, &chanType, ) if err != nil { return err } + // Set the corresponding TLV types if they were included in the stream. + if val, ok := typeMap[ChannelTypeRecordType]; ok && val == nil { + a.ChannelType = &chanType + } + + a.ExtraData = tlvRecords + return nil } -// packShutdownScript takes an upfront shutdown script and an opaque data blob -// and concatenates them. -func packShutdownScript(addr DeliveryAddress, extraData ExtraOpaqueData) ( - ExtraOpaqueData, error) { - - // We'll always write the upfront shutdown script record, regardless of - // the script being empty. - var tlvRecords ExtraOpaqueData - - // Pack it into a data blob as a TLV record. - err := tlvRecords.PackRecords(addr.NewRecord()) - if err != nil { - return nil, fmt.Errorf("unable to pack upfront shutdown "+ - "script as TLV record: %v", err) - } - - // Concatenate the remaining blob with the shutdown script record. - tlvRecords = append(tlvRecords, extraData...) - return tlvRecords, nil -} - -// parseShutdownScript reads and extract the upfront shutdown script from the -// passe data blob. It returns the script, if any, and the remainder of the -// data blob. -// -// This can be used to parse extra data for the OpenChannel and AcceptChannel -// messages, where the shutdown script is mandatory if extra TLV data is -// present. -func parseShutdownScript(tlvRecords ExtraOpaqueData) (DeliveryAddress, - ExtraOpaqueData, error) { - - // If no TLV data is present there can't be any script available. - if len(tlvRecords) == 0 { - return nil, tlvRecords, nil - } - - // Otherwise the shutdown script MUST be present. - var addr DeliveryAddress - tlvs, err := tlvRecords.ExtractRecords(addr.NewRecord()) - if err != nil { - return nil, nil, err - } - - // Not among TLV records, this means the data was invalid. - if _, ok := tlvs[DeliveryAddrType]; !ok { - return nil, nil, fmt.Errorf("no shutdown script in non-empty " + - "data blob") - } - - // Now that we have retrieved the address (which can be zero-length), - // we'll remove the bytes encoding it from the TLV data before - // returning it. - addrLen := len(addr) - tlvRecords = tlvRecords[addrLen+2:] - - return addr, tlvRecords, nil -} - // MsgType returns the MessageType code which uniquely identifies this message // as an AcceptChannel on the wire. // diff --git a/lnwire/channel_type_test.go b/lnwire/channel_type_test.go index 0edef0779..dd8e02439 100644 --- a/lnwire/channel_type_test.go +++ b/lnwire/channel_type_test.go @@ -17,10 +17,10 @@ func TestChannelTypeEncodeDecode(t *testing.T) { )) var extraData ExtraOpaqueData - require.NoError(t, extraData.PackRecords(chanType.Record())) + require.NoError(t, extraData.PackRecords(&chanType)) var chanType2 ChannelType - tlvs, err := extraData.ExtractRecords(chanType2.Record()) + tlvs, err := extraData.ExtractRecords(&chanType2) require.NoError(t, err) require.Contains(t, tlvs, ChannelTypeRecordType) diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go index 70554f4f5..88b914c38 100644 --- a/lnwire/extra_bytes.go +++ b/lnwire/extra_bytes.go @@ -2,6 +2,7 @@ package lnwire import ( "bytes" + "fmt" "io" "io/ioutil" @@ -50,7 +51,17 @@ func (e *ExtraOpaqueData) Decode(r io.Reader) error { // PackRecords attempts to encode the set of tlv records into the target // ExtraOpaqueData instance. The records will be encoded as a raw TLV stream // and stored within the backing slice pointer. -func (e *ExtraOpaqueData) PackRecords(records ...tlv.Record) error { +func (e *ExtraOpaqueData) PackRecords(recordProducers ...tlv.RecordProducer) error { + // First, assemble all the records passed in in series. + records := make([]tlv.Record, 0, len(recordProducers)) + for _, producer := range recordProducers { + records = append(records, producer.Record()) + } + + // Ensure that the set of records are sorted before we encode them into + // the stream, to ensure they're canonical. + tlv.SortRecords(records) + tlvStream, err := tlv.NewStream(records...) if err != nil { return err @@ -70,9 +81,15 @@ func (e *ExtraOpaqueData) PackRecords(records ...tlv.Record) error { // it were a tlv stream. The set of raw parsed types is returned, and any // passed records (if found in the stream) will be parsed into the proper // tlv.Record. -func (e *ExtraOpaqueData) ExtractRecords(records ...tlv.Record) ( +func (e *ExtraOpaqueData) ExtractRecords(recordProducers ...tlv.RecordProducer) ( tlv.TypeMap, error) { + // First, assemble all the records passed in in series. + records := make([]tlv.Record, 0, len(recordProducers)) + for _, producer := range recordProducers { + records = append(records, producer.Record()) + } + extraBytesReader := bytes.NewReader(*e) tlvStream, err := tlv.NewStream(records...) @@ -82,3 +99,19 @@ func (e *ExtraOpaqueData) ExtractRecords(records ...tlv.Record) ( return tlvStream.DecodeWithParsedTypes(extraBytesReader) } + +// EncodeMessageExtraData encodes the given recordProducers into the given +// extraData. +func EncodeMessageExtraData(extraData *ExtraOpaqueData, + recordProducers ...tlv.RecordProducer) error { + + // Treat extraData as a mutable reference. + if extraData == nil { + return fmt.Errorf("extra data cannot be nil") + } + + // Pack in the series of TLV records into this message. The order we + // pass them in doesn't matter, as the method will ensure that things + // are all properly sorted. + return extraData.PackRecords(recordProducers...) +} diff --git a/lnwire/extra_bytes_test.go b/lnwire/extra_bytes_test.go index 39271d6aa..88ffcc307 100644 --- a/lnwire/extra_bytes_test.go +++ b/lnwire/extra_bytes_test.go @@ -86,6 +86,14 @@ func TestExtraOpaqueDataEncodeDecode(t *testing.T) { } } +type recordProducer struct { + record tlv.Record +} + +func (r *recordProducer) Record() tlv.Record { + return r.record +} + // TestExtraOpaqueDataPackUnpackRecords tests that we're able to pack a set of // tlv.Records into a stream, and unpack them on the other side to obtain the // same set of records. @@ -102,23 +110,23 @@ func TestExtraOpaqueDataPackUnpackRecords(t *testing.T) { hop1 uint32 = 99 hop2 uint32 ) - testRecords := []tlv.Record{ - tlv.MakePrimitiveRecord(type1, &channelType1), - tlv.MakePrimitiveRecord(type2, &hop1), + testRecordsProducers := []tlv.RecordProducer{ + &recordProducer{tlv.MakePrimitiveRecord(type1, &channelType1)}, + &recordProducer{tlv.MakePrimitiveRecord(type2, &hop1)}, } // Now that we have our set of sample records and types, we'll encode // them into the passed ExtraOpaqueData instance. var extraBytes ExtraOpaqueData - if err := extraBytes.PackRecords(testRecords...); err != nil { + if err := extraBytes.PackRecords(testRecordsProducers...); err != nil { t.Fatalf("unable to pack records: %v", err) } // We'll now simulate decoding these types _back_ into records on the // other side. - newRecords := []tlv.Record{ - tlv.MakePrimitiveRecord(type1, &channelType2), - tlv.MakePrimitiveRecord(type2, &hop2), + newRecords := []tlv.RecordProducer{ + &recordProducer{tlv.MakePrimitiveRecord(type1, &channelType2)}, + &recordProducer{tlv.MakePrimitiveRecord(type2, &hop2)}, } typeMap, err := extraBytes.ExtractRecords(newRecords...) if err != nil { diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 4475b3382..5143d74bd 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -18,8 +18,8 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" - "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/tor" + "github.com/stretchr/testify/assert" ) var ( @@ -284,9 +284,7 @@ func TestLightningWireProtocol(t *testing.T) { t.Fatalf("unable to read msg: %v", err) return false } - if !reflect.DeepEqual(msg, newMsg) { - t.Fatalf("messages don't match after re-encoding: %v "+ - "vs %v", spew.Sdump(msg), spew.Sdump(newMsg)) + if !assert.Equalf(t, msg, newMsg, "message mismatch") { return false } @@ -369,17 +367,16 @@ func TestLightningWireProtocol(t *testing.T) { t.Fatalf("unable to generate delivery address: %v", err) return } + + req.ChannelType = new(ChannelType) + *req.ChannelType = ChannelType(*randRawFeatureVector(r)) } else { req.UpfrontShutdownScript = []byte{} } - // 1/2 chance how having more TLV data after the - // shutdown script. + // 1/2 chance additional TLV data. if r.Intn(2) == 0 { - // TLV type 1 of length 2. - req.ExtraData = []byte{1, 2, 0xff, 0xff} - } else { - req.ExtraData = []byte{} + req.ExtraData = []byte{0xfd, 0x00, 0xff, 0x00} } v[0] = reflect.ValueOf(req) @@ -439,16 +436,16 @@ func TestLightningWireProtocol(t *testing.T) { t.Fatalf("unable to generate delivery address: %v", err) return } + + req.ChannelType = new(ChannelType) + *req.ChannelType = ChannelType(*randRawFeatureVector(r)) } else { req.UpfrontShutdownScript = []byte{} } - // 1/2 chance how having more TLV data after the - // shutdown script. + + // 1/2 chance additional TLV data. if r.Intn(2) == 0 { - // TLV type 1 of length 2. - req.ExtraData = []byte{1, 2, 0xff, 0xff} - } else { - req.ExtraData = []byte{} + req.ExtraData = []byte{0xfd, 0x00, 0xff, 0x00} } v[0] = reflect.ValueOf(req) diff --git a/lnwire/open_channel.go b/lnwire/open_channel.go index 035297ba2..534f17ee9 100644 --- a/lnwire/open_channel.go +++ b/lnwire/open_channel.go @@ -7,6 +7,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcutil" + "github.com/lightningnetwork/lnd/tlv" ) // FundingFlag represents the possible bit mask values for the ChannelFlags @@ -130,6 +131,10 @@ type OpenChannel struct { // and its length followed by the script will be written if it is set. UpfrontShutdownScript DeliveryAddress + // ChannelType is the explicit channel type the initiator wishes to + // open. + ChannelType *ChannelType + // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can // be used to specify optional data such as custom TLV fields. @@ -150,13 +155,12 @@ var _ Message = (*OpenChannel)(nil) // implementation. Serialization will observe the rules defined by the passed // protocol version. // -// This is part of the lnwire.Message interface. func (o *OpenChannel) Encode(w *bytes.Buffer, pver uint32) error { - // Since the upfront script is encoded as a TLV record, concatenate it - // with the ExtraData, and write them as one. - tlvRecords, err := packShutdownScript( - o.UpfrontShutdownScript, o.ExtraData, - ) + recordProducers := []tlv.RecordProducer{&o.UpfrontShutdownScript} + if o.ChannelType != nil { + recordProducers = append(recordProducers, o.ChannelType) + } + err := EncodeMessageExtraData(&o.ExtraData, recordProducers...) if err != nil { return err } @@ -234,7 +238,7 @@ func (o *OpenChannel) Encode(w *bytes.Buffer, pver uint32) error { return err } - return WriteBytes(w, tlvRecords) + return WriteBytes(w, o.ExtraData) } // Decode deserializes the serialized OpenChannel stored in the passed @@ -276,13 +280,23 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { return err } - o.UpfrontShutdownScript, o.ExtraData, err = parseShutdownScript( - tlvRecords, + // Next we'll parse out the set of known records, keeping the raw tlv + // bytes untouched to ensure we don't drop any bytes erroneously. + var chanType ChannelType + typeMap, err := tlvRecords.ExtractRecords( + &o.UpfrontShutdownScript, &chanType, ) if err != nil { return err } + // Set the corresponding TLV types if they were included in the stream. + if val, ok := typeMap[ChannelTypeRecordType]; ok && val == nil { + o.ChannelType = &chanType + } + + o.ExtraData = tlvRecords + return nil } diff --git a/lnwire/typed_delivery_addr.go b/lnwire/typed_delivery_addr.go index 9ad53b1ae..90a101b34 100644 --- a/lnwire/typed_delivery_addr.go +++ b/lnwire/typed_delivery_addr.go @@ -24,11 +24,11 @@ const ( // p2wpkh. type DeliveryAddress []byte -// NewRecord returns a TLV record that can be used to encode the delivery -// address within the ExtraData TLV stream. This was intorudced in order to +// Record returns a TLV record that can be used to encode the delivery +// address within the ExtraData TLV stream. This was introduced in order to // allow the OpenChannel/AcceptChannel messages to properly be extended with // TLV types. -func (d *DeliveryAddress) NewRecord() tlv.Record { +func (d *DeliveryAddress) Record() tlv.Record { addrBytes := (*[]byte)(d) return tlv.MakeDynamicRecord( diff --git a/lnwire/typed_delivery_addr_test.go b/lnwire/typed_delivery_addr_test.go index d5d9c703a..9d00bc8bd 100644 --- a/lnwire/typed_delivery_addr_test.go +++ b/lnwire/typed_delivery_addr_test.go @@ -15,13 +15,13 @@ func TestDeliveryAddressEncodeDecode(t *testing.T) { ) var extraData ExtraOpaqueData - err := extraData.PackRecords(addr.NewRecord()) + err := extraData.PackRecords(&addr) if err != nil { t.Fatal(err) } var addr2 DeliveryAddress - tlvs, err := extraData.ExtractRecords(addr2.NewRecord()) + tlvs, err := extraData.ExtractRecords(&addr2) if err != nil { t.Fatal(err) }