diff --git a/lnwire/channel_update_2.go b/lnwire/channel_update_2.go new file mode 100644 index 000000000..c69ef47da --- /dev/null +++ b/lnwire/channel_update_2.go @@ -0,0 +1,516 @@ +package lnwire + +import ( + "bytes" + "fmt" + "io" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + defaultCltvExpiryDelta = uint16(80) + defaultHtlcMinMsat = MilliSatoshi(1) + defaultFeeBaseMsat = uint32(1000) + defaultFeeProportionalMillionths = uint32(1) + + // chanUpdate2MsgName is a string representing the name of the + // ChannelUpdate2 message. This string will be used during the + // construction of the tagged hash message to be signed when producing + // the signature for the ChannelUpdate2 message. + chanUpdate2MsgName = "channel_update_2" + + // chanUpdate2SigField is the name of the signature field of the + // ChannelUpdate2 message. This string will be used during the + // construction of the tagged hash message to be signed when producing + // the signature for the ChannelUpdate2 message. + chanUpdate2SigField = "signature" +) + +// ChannelUpdate2 message is used after taproot channel has been initially +// announced. Each side independently announces its fees and minimum expiry for +// HTLCs and other parameters. This message is also used to redeclare initially +// set channel parameters. +type ChannelUpdate2 struct { + // Signature is used to validate the announced data and prove the + // ownership of node id. + Signature Sig + + // ChainHash denotes the target chain that this channel was opened + // within. This value should be the genesis hash of the target chain. + // Along with the short channel ID, this uniquely identifies the + // channel globally in a blockchain. + ChainHash tlv.RecordT[tlv.TlvType0, chainhash.Hash] + + // ShortChannelID is the unique description of the funding transaction. + ShortChannelID tlv.RecordT[tlv.TlvType2, ShortChannelID] + + // BlockHeight allows ordering in the case of multiple announcements. We + // should ignore the message if block height is not greater than the + // last-received. The block height must always be greater or equal to + // the block height that the channel funding transaction was confirmed + // in. + BlockHeight tlv.RecordT[tlv.TlvType4, uint32] + + // DisabledFlags is an optional bitfield that describes various reasons + // that the node is communicating that the channel should be considered + // disabled. + DisabledFlags tlv.RecordT[tlv.TlvType6, ChanUpdateDisableFlags] + + // SecondPeer is used to indicate which node the channel node has + // created and signed this message. If this field is absent, it was + // node 2 otherwise it was node 1. + SecondPeer tlv.OptionalRecordT[tlv.TlvType8, TrueBoolean] + + // CLTVExpiryDelta is the minimum number of blocks this node requires to + // be added to the expiry of HTLCs. This is a security parameter + // determined by the node operator. This value represents the required + // gap between the time locks of the incoming and outgoing HTLC's set + // to this node. + CLTVExpiryDelta tlv.RecordT[tlv.TlvType10, uint16] + + // HTLCMinimumMsat is the minimum HTLC value which will be accepted. + HTLCMinimumMsat tlv.RecordT[tlv.TlvType12, MilliSatoshi] + + // HtlcMaximumMsat is the maximum HTLC value which will be accepted. + HTLCMaximumMsat tlv.RecordT[tlv.TlvType14, MilliSatoshi] + + // FeeBaseMsat is the base fee that must be used for incoming HTLC's to + // this particular channel. This value will be tacked onto the required + // for a payment independent of the size of the payment. + FeeBaseMsat tlv.RecordT[tlv.TlvType16, uint32] + + // FeeProportionalMillionths is the fee rate that will be charged per + // millionth of a satoshi. + FeeProportionalMillionths tlv.RecordT[tlv.TlvType18, uint32] + + // ExtraOpaqueData is the set of data that was appended to this message + // to fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraOpaqueData ExtraOpaqueData +} + +// Decode deserializes a serialized ChannelUpdate2 stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *ChannelUpdate2) Decode(r io.Reader, _ uint32) error { + err := ReadElement(r, &c.Signature) + if err != nil { + return err + } + c.Signature.ForceSchnorr() + + return c.DecodeTLVRecords(r) +} + +// DecodeTLVRecords decodes only the TLV section of the message. +func (c *ChannelUpdate2) DecodeTLVRecords(r io.Reader) error { + // First extract into extra opaque data. + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + var ( + chainHash = tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() + secondPeer = tlv.ZeroRecordT[tlv.TlvType8, TrueBoolean]() + ) + typeMap, err := tlvRecords.ExtractRecords( + &chainHash, &c.ShortChannelID, &c.BlockHeight, &c.DisabledFlags, + &secondPeer, &c.CLTVExpiryDelta, &c.HTLCMinimumMsat, + &c.HTLCMaximumMsat, &c.FeeBaseMsat, + &c.FeeProportionalMillionths, + ) + if err != nil { + return err + } + + // By default, the chain-hash is the bitcoin mainnet genesis block hash. + c.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash + if _, ok := typeMap[c.ChainHash.TlvType()]; ok { + c.ChainHash.Val = chainHash.Val + } + + // The presence of the second_peer tlv type indicates "true". + if _, ok := typeMap[c.SecondPeer.TlvType()]; ok { + c.SecondPeer = tlv.SomeRecordT(secondPeer) + } + + // If the CLTV expiry delta was not encoded, then set it to the default + // value. + if _, ok := typeMap[c.CLTVExpiryDelta.TlvType()]; !ok { + c.CLTVExpiryDelta.Val = defaultCltvExpiryDelta + } + + // If the HTLC Minimum msat was not encoded, then set it to the default + // value. + if _, ok := typeMap[c.HTLCMinimumMsat.TlvType()]; !ok { + c.HTLCMinimumMsat.Val = defaultHtlcMinMsat + } + + // If the base fee was not encoded, then set it to the default value. + if _, ok := typeMap[c.FeeBaseMsat.TlvType()]; !ok { + c.FeeBaseMsat.Val = defaultFeeBaseMsat + } + + // If the proportional fee was not encoded, then set it to the default + // value. + if _, ok := typeMap[c.FeeProportionalMillionths.TlvType()]; !ok { + c.FeeProportionalMillionths.Val = defaultFeeProportionalMillionths //nolint:lll + } + + if len(tlvRecords) != 0 { + c.ExtraOpaqueData = tlvRecords + } + + return nil +} + +// Encode serializes the target ChannelUpdate2 into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *ChannelUpdate2) Encode(w *bytes.Buffer, _ uint32) error { + _, err := w.Write(c.Signature.RawBytes()) + if err != nil { + return err + } + + _, err = c.DataToSign() + if err != nil { + return err + } + + return WriteBytes(w, c.ExtraOpaqueData) +} + +// DigestTag returns the tag to be used when signing the digest. +func (c *ChannelUpdate2) DigestTag() []byte { + return MsgTag(chanUpdate2MsgName, chanUpdate2SigField) +} + +// DigestToSign computes the digest of the message to be signed. +func (c *ChannelUpdate2) DigestToSign() ([]byte, error) { + data, err := c.DataToSign() + if err != nil { + return nil, err + } + + hash := MsgHash(chanUpdate2MsgName, chanUpdate2SigField, data) + + return hash[:], nil +} + +// DataToSign is used to retrieve part of the announcement message which should +// be signed. For the ChannelUpdate2 message, this includes the serialised TLV +// records. +func (c *ChannelUpdate2) DataToSign() ([]byte, error) { + // The chain-hash record is only included if it is _not_ equal to the + // bitcoin mainnet genisis block hash. + var recordProducers []tlv.RecordProducer + if !c.ChainHash.Val.IsEqual(chaincfg.MainNetParams.GenesisHash) { + hash := tlv.ZeroRecordT[tlv.TlvType0, [32]byte]() + hash.Val = c.ChainHash.Val + + recordProducers = append(recordProducers, &hash) + } + + recordProducers = append(recordProducers, + &c.ShortChannelID, &c.BlockHeight, + ) + + // Only include the disable flags if any bit is set. + if !c.DisabledFlags.Val.IsEnabled() { + recordProducers = append(recordProducers, &c.DisabledFlags) + } + + // We only need to encode the second peer boolean if it is true + c.SecondPeer.WhenSome(func(r tlv.RecordT[tlv.TlvType8, TrueBoolean]) { + recordProducers = append(recordProducers, &r) + }) + + // We only encode the cltv expiry delta if it is not equal to the + // default. + if c.CLTVExpiryDelta.Val != defaultCltvExpiryDelta { + recordProducers = append(recordProducers, &c.CLTVExpiryDelta) + } + + if c.HTLCMinimumMsat.Val != defaultHtlcMinMsat { + recordProducers = append(recordProducers, &c.HTLCMinimumMsat) + } + + recordProducers = append(recordProducers, &c.HTLCMaximumMsat) + + if c.FeeBaseMsat.Val != defaultFeeBaseMsat { + recordProducers = append(recordProducers, &c.FeeBaseMsat) + } + + if c.FeeProportionalMillionths.Val != defaultFeeProportionalMillionths { + recordProducers = append( + recordProducers, &c.FeeProportionalMillionths, + ) + } + + err := EncodeMessageExtraData(&c.ExtraOpaqueData, recordProducers...) + if err != nil { + return nil, err + } + + return c.ExtraOpaqueData, nil +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *ChannelUpdate2) MsgType() MessageType { + return MsgChannelUpdate2 +} + +func (c *ChannelUpdate2) ExtraData() ExtraOpaqueData { + return c.ExtraOpaqueData +} + +// A compile time check to ensure ChannelUpdate2 implements the +// lnwire.Message interface. +var _ Message = (*ChannelUpdate2)(nil) + +// SCID returns the ShortChannelID of the channel that the update applies to. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) SCID() ShortChannelID { + return c.ShortChannelID.Val +} + +// IsNode1 is true if the update was produced by node 1 of the channel peers. +// Node 1 is the node with the lexicographically smaller public key. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) IsNode1() bool { + return c.SecondPeer.IsNone() +} + +// IsDisabled is true if the update is announcing that the channel should be +// considered disabled. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) IsDisabled() bool { + return !c.DisabledFlags.Val.IsEnabled() +} + +// GetChainHash returns the hash of the chain that the message is referring to. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) GetChainHash() chainhash.Hash { + return c.ChainHash.Val +} + +// ForwardingPolicy returns the set of forwarding constraints of the update. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) ForwardingPolicy() *ForwardingPolicy { + return &ForwardingPolicy{ + TimeLockDelta: c.CLTVExpiryDelta.Val, + BaseFee: MilliSatoshi(c.FeeBaseMsat.Val), + FeeRate: MilliSatoshi(c.FeeProportionalMillionths.Val), + MinHTLC: c.HTLCMinimumMsat.Val, + HasMaxHTLC: true, + MaxHTLC: c.HTLCMaximumMsat.Val, + } +} + +// CmpAge can be used to determine if the update is older or newer than the +// passed update. It returns 1 if this update is newer, -1 if it is older, and +// 0 if they are the same age. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) CmpAge(update ChannelUpdate) (CompareResult, error) { + other, ok := update.(*ChannelUpdate2) + if !ok { + return 0, fmt.Errorf("expected *ChannelUpdate2, got: %T", + update) + } + + switch { + case c.BlockHeight.Val > other.BlockHeight.Val: + return GreaterThan, nil + case c.BlockHeight.Val < other.BlockHeight.Val: + return LessThan, nil + default: + return EqualTo, nil + } +} + +// SetDisabledFlag can be used to adjust the disabled flag of an update. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) SetDisabledFlag(disabled bool) { + if disabled { + c.DisabledFlags.Val |= ChanUpdateDisableIncoming + c.DisabledFlags.Val |= ChanUpdateDisableOutgoing + } else { + c.DisabledFlags.Val &^= ChanUpdateDisableIncoming + c.DisabledFlags.Val &^= ChanUpdateDisableOutgoing + } +} + +// SetSCID can be used to overwrite the SCID of the update. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) SetSCID(scid ShortChannelID) { + c.ShortChannelID.Val = scid +} + +// Validate validates the sanity of the channel update message +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) Validate(capacity btcutil.Amount) error { + maxHtlc := c.HTLCMaximumMsat.Val + if maxHtlc == 0 || maxHtlc < c.HTLCMinimumMsat.Val { + return fmt.Errorf("invalid max htlc for channel update %v", + spew.Sdump(c)) + } + + // Checking whether the MaxHTLC value respects the channel's + // capacity. + capacityMsat := NewMSatFromSatoshis(capacity) + if maxHtlc > capacityMsat { + return fmt.Errorf("max_htlc (%v) for channel update greater "+ + "than capacity (%v)", maxHtlc, capacityMsat) + } + + return nil +} + +// VerifySig verifies that the message was signed by the given pub key. +// +// NOTE: this is part of the ChannelUpdate interface. +func (c *ChannelUpdate2) VerifySig(pubKey *btcec.PublicKey) error { + digest, err := c.DigestToSign() + if err != nil { + return fmt.Errorf("unable to reconstruct message data: %w", err) + } + + nodeSig, err := c.Signature.ToSignature() + if err != nil { + return err + } + + if !nodeSig.Verify(digest, pubKey) { + return fmt.Errorf("invalid signature for channel update %v", + spew.Sdump(c)) + } + + return nil +} + +// A compile time check to ensure ChannelUpdate2 implements the +// lnwire.ChannelUpdate interface. +var _ ChannelUpdate = (*ChannelUpdate2)(nil) + +// ChanUpdateDisableFlags is a bit vector that can be used to indicate various +// reasons for the channel being marked as disabled. +type ChanUpdateDisableFlags uint8 + +const ( + // ChanUpdateDisableIncoming is a bit indicates that a channel is + // disabled in the inbound direction meaning that the node broadcasting + // the update is communicating that they cannot receive funds. + ChanUpdateDisableIncoming ChanUpdateDisableFlags = 1 << iota + + // ChanUpdateDisableOutgoing is a bit indicates that a channel is + // disabled in the outbound direction meaning that the node broadcasting + // the update is communicating that they cannot send or route funds. + ChanUpdateDisableOutgoing = 2 +) + +// IncomingDisabled returns true if the ChanUpdateDisableIncoming bit is set. +func (c ChanUpdateDisableFlags) IncomingDisabled() bool { + return c&ChanUpdateDisableIncoming == ChanUpdateDisableIncoming +} + +// OutgoingDisabled returns true if the ChanUpdateDisableOutgoing bit is set. +func (c ChanUpdateDisableFlags) OutgoingDisabled() bool { + return c&ChanUpdateDisableOutgoing == ChanUpdateDisableOutgoing +} + +// IsEnabled returns true if none of the disable bits are set. +func (c ChanUpdateDisableFlags) IsEnabled() bool { + return c == 0 +} + +// String returns the bitfield flags as a string. +func (c ChanUpdateDisableFlags) String() string { + return fmt.Sprintf("%08b", c) +} + +// Record returns the tlv record for the disable flags. +func (c *ChanUpdateDisableFlags) Record() tlv.Record { + return tlv.MakeStaticRecord(0, c, 1, encodeDisableFlags, + decodeDisableFlags) +} + +func encodeDisableFlags(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*ChanUpdateDisableFlags); ok { + flagsInt := uint8(*v) + + return tlv.EUint8(w, &flagsInt, buf) + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.ChanUpdateDisableFlags") +} + +func decodeDisableFlags(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if v, ok := val.(*ChanUpdateDisableFlags); ok { + var flagsInt uint8 + err := tlv.DUint8(r, &flagsInt, buf, l) + if err != nil { + return err + } + + *v = ChanUpdateDisableFlags(flagsInt) + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "lnwire.ChanUpdateDisableFlags", + l, l) +} + +// TrueBoolean is a record that indicates true or false using the presence of +// the record. If the record is absent, it indicates false. If it is presence, +// it indicates true. +type TrueBoolean struct{} + +// Record returns the tlv record for the boolean entry. +func (b *TrueBoolean) Record() tlv.Record { + return tlv.MakeStaticRecord( + 0, b, 0, booleanEncoder, booleanDecoder, + ) +} + +func booleanEncoder(_ io.Writer, val interface{}, _ *[8]byte) error { + if _, ok := val.(*TrueBoolean); ok { + return nil + } + + return tlv.NewTypeForEncodingErr(val, "TrueBoolean") +} + +func booleanDecoder(_ io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if _, ok := val.(*TrueBoolean); ok && (l == 0 || l == 1) { + return nil + } + + return tlv.NewTypeForEncodingErr(val, "TrueBoolean") +} diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 2f09becc0..19c584448 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -1512,6 +1512,83 @@ func TestLightningWireProtocol(t *testing.T) { require.NoError(t, err) } + v[0] = reflect.ValueOf(req) + }, + MsgChannelUpdate2: func(v []reflect.Value, r *rand.Rand) { + req := ChannelUpdate2{ + Signature: testSchnorrSig, + ExtraOpaqueData: make([]byte, 0), + } + + req.ShortChannelID.Val = NewShortChanIDFromInt( + uint64(r.Int63()), + ) + req.BlockHeight.Val = r.Uint32() + req.HTLCMaximumMsat.Val = MilliSatoshi(r.Uint64()) + + // Sometimes set chain hash to bitcoin mainnet genesis + // hash. + req.ChainHash.Val = *chaincfg.MainNetParams.GenesisHash + if r.Int31()%2 == 0 { + _, err := r.Read(req.ChainHash.Val[:]) + require.NoError(t, err) + } + + // Sometimes use default htlc min msat. + req.HTLCMinimumMsat.Val = defaultHtlcMinMsat + if r.Int31()%2 == 0 { + req.HTLCMinimumMsat.Val = MilliSatoshi( + r.Uint64(), + ) + } + + // Sometimes set the cltv expiry delta to the default. + req.CLTVExpiryDelta.Val = defaultCltvExpiryDelta + if r.Int31()%2 == 0 { + req.CLTVExpiryDelta.Val = uint16(r.Int31()) + } + + // Sometimes use default fee base. + req.FeeBaseMsat.Val = defaultFeeBaseMsat + if r.Int31()%2 == 0 { + req.FeeBaseMsat.Val = r.Uint32() + } + + // Sometimes use default proportional fee. + req.FeeProportionalMillionths.Val = + defaultFeeProportionalMillionths + if r.Int31()%2 == 0 { + req.FeeProportionalMillionths.Val = r.Uint32() + } + + // Alternate between the two direction possibilities. + if r.Int31()%2 == 0 { + req.SecondPeer = tlv.SomeRecordT( + tlv.ZeroRecordT[tlv.TlvType8, TrueBoolean](), //nolint:lll + ) + } + + // Sometimes set the incoming disabled flag. + if r.Int31()%2 == 0 { + req.DisabledFlags.Val |= + ChanUpdateDisableIncoming + } + + // Sometimes set the outgoing disabled flag. + if r.Int31()%2 == 0 { + req.DisabledFlags.Val |= + ChanUpdateDisableOutgoing + } + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + req.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read(req.ExtraOpaqueData[:]) + require.NoError(t, err) + } + v[0] = reflect.ValueOf(req) }, } @@ -1754,6 +1831,12 @@ func TestLightningWireProtocol(t *testing.T) { return mainScenario(&m) }, }, + { + msgType: MsgChannelUpdate2, + scenario: func(m ChannelUpdate2) bool { + return mainScenario(&m) + }, + }, } for _, test := range tests { var config *quick.Config diff --git a/lnwire/message.go b/lnwire/message.go index b91db0679..a758db000 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -59,6 +59,7 @@ const ( MsgReplyChannelRange = 264 MsgGossipTimestampRange = 265 MsgChannelAnnouncement2 = 267 + MsgChannelUpdate2 = 271 MsgKickoffSig = 777 ) @@ -161,6 +162,8 @@ func (t MessageType) String() string { return "MsgAnnounceSignatures2" case MsgChannelAnnouncement2: return "ChannelAnnouncement2" + case MsgChannelUpdate2: + return "ChannelUpdate2" default: return "" } @@ -294,6 +297,8 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { msg = &AnnounceSignatures2{} case MsgChannelAnnouncement2: msg = &ChannelAnnouncement2{} + case MsgChannelUpdate2: + msg = &ChannelUpdate2{} default: // If the message is not within our custom range and has not // specifically been overridden, return an unknown message. diff --git a/lnwire/msat.go b/lnwire/msat.go index 7473d72c8..2966e5ddb 100644 --- a/lnwire/msat.go +++ b/lnwire/msat.go @@ -2,8 +2,10 @@ package lnwire import ( "fmt" + "io" "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/tlv" ) const ( @@ -49,3 +51,40 @@ func (m MilliSatoshi) String() string { } // TODO(roasbeef): extend with arithmetic operations? + +// Record returns a TLV record that can be used to encode/decode a MilliSatoshi +// to/from a TLV stream. +func (m *MilliSatoshi) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 0, m, tlv.SizeBigSize(m), encodeMilliSatoshis, + decodeMilliSatoshis, + ) +} + +func encodeMilliSatoshis(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*MilliSatoshi); ok { + bigSize := uint64(*v) + + return tlv.EBigSize(w, &bigSize, buf) + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.MilliSatoshi") +} + +func decodeMilliSatoshis(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if v, ok := val.(*MilliSatoshi); ok { + var bigSize uint64 + err := tlv.DBigSize(r, &bigSize, buf, l) + if err != nil { + return err + } + + *v = MilliSatoshi(bigSize) + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "lnwire.MilliSatoshi", l, l) +}