From f090a64142bcb4e39ade5d14245c08da33d4ca24 Mon Sep 17 00:00:00 2001 From: Carla Kirk-Cohen Date: Mon, 6 Nov 2023 15:36:31 -0500 Subject: [PATCH] multi: add blinding point to payment descriptor and persist This commit adds an optional blinding point to payment descriptors and persists them in our HTLC's extra data. A get/set pattern is used to populate the ExtraData on our disk representation of the HTLC so that callers do not need to worry about the underlying storage detail. --- channeldb/channel.go | 67 ++++++++++++++++++++++++++++++- channeldb/channel_test.go | 53 +++++++++++++----------- lnwallet/channel.go | 84 +++++++++++++++++++++++++++++++++------ lnwallet/channel_test.go | 69 ++++++++++++++++++++++++++++++-- lnwire/update_add_htlc.go | 13 ++++++ 5 files changed, 246 insertions(+), 40 deletions(-) diff --git a/channeldb/channel.go b/channeldb/channel.go index e0c6c630d..18db1d207 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -35,6 +35,10 @@ const ( // begins to be interpreted as an absolute block height, rather than a // relative one. AbsoluteThawHeightThreshold uint32 = 500000 + + // HTLCBlindingPointTLV is the tlv type used for storing blinding + // points with HTLCs. + HTLCBlindingPointTLV tlv.Type = 0 ) var ( @@ -2316,7 +2320,56 @@ type HTLC struct { // Note that this extra data is stored inline with the OnionBlob for // legacy reasons, see serialization/deserialization functions for // detail. - ExtraData []byte + ExtraData lnwire.ExtraOpaqueData + + // BlindingPoint is an optional blinding point included with the HTLC. + // + // Note: this field is not a part of on-disk representation of the + // HTLC. It is stored in the ExtraData field, which is used to store + // a TLV stream of additional information associated with the HTLC. + BlindingPoint lnwire.BlindingPointRecord +} + +// serializeExtraData encodes a TLV stream of extra data to be stored with a +// HTLC. It uses the update_add_htlc TLV types, because this is where extra +// data is passed with a HTLC. At present blinding points are the only extra +// data that we will store, and the function is a no-op if a nil blinding +// point is provided. +// +// This function MUST be called to persist all HTLC values when they are +// serialized. +func (h *HTLC) serializeExtraData() error { + var records []tlv.RecordProducer + h.BlindingPoint.WhenSome(func(b tlv.RecordT[lnwire.BlindingPointTlvType, + *btcec.PublicKey]) { + + records = append(records, &b) + }) + + return h.ExtraData.PackRecords(records...) +} + +// deserializeExtraData extracts TLVs from the extra data persisted for the +// htlc and populates values in the struct accordingly. +// +// This function MUST be called to populate the struct properly when HTLCs +// are deserialized. +func (h *HTLC) deserializeExtraData() error { + if len(h.ExtraData) == 0 { + return nil + } + + blindingPoint := h.BlindingPoint.Zero() + tlvMap, err := h.ExtraData.ExtractRecords(&blindingPoint) + if err != nil { + return err + } + + if val, ok := tlvMap[h.BlindingPoint.TlvType()]; ok && val == nil { + h.BlindingPoint = tlv.SomeRecordT(blindingPoint) + } + + return nil } // SerializeHtlcs writes out the passed set of HTLC's into the passed writer @@ -2340,6 +2393,12 @@ func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error { } for _, htlc := range htlcs { + // Populate TLV stream for any additional fields contained + // in the TLV. + if err := htlc.serializeExtraData(); err != nil { + return err + } + // The onion blob and hltc data are stored as a single var // bytes blob. onionAndExtraData := make( @@ -2425,6 +2484,12 @@ func DeserializeHtlcs(r io.Reader) ([]HTLC, error) { onionAndExtraData[lnwire.OnionPacketSize:], ) } + + // Finally, deserialize any TLVs contained in that extra data + // if they are present. + if err := htlcs[i].deserializeExtraData(); err != nil { + return nil, err + } } return htlcs, nil diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 8be0005dc..981ddf688 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -23,6 +23,7 @@ import ( "github.com/lightningnetwork/lnd/lntest/channels" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" + "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" ) @@ -1606,9 +1607,25 @@ func TestHTLCsExtraData(t *testing.T) { OnionBlob: lnmock.MockOnion(), } + // Add a blinding point to a htlc. + blindingPointHTLC := HTLC{ + Signature: testSig.Serialize(), + Incoming: false, + Amt: 10, + RHash: key, + RefundTimeout: 1, + OnionBlob: lnmock.MockOnion(), + BlindingPoint: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( + pubKey, + ), + ), + } + testCases := []struct { - name string - htlcs []HTLC + name string + htlcs []HTLC + blindingIdx int }{ { // Serialize multiple HLTCs with no extra data to @@ -1620,30 +1637,12 @@ func TestHTLCsExtraData(t *testing.T) { }, }, { + // Some HTLCs with extra data, some without. name: "mixed extra data", htlcs: []HTLC{ mockHtlc, - { - Signature: testSig.Serialize(), - Incoming: false, - Amt: 10, - RHash: key, - RefundTimeout: 1, - OnionBlob: lnmock.MockOnion(), - ExtraData: []byte{1, 2, 3}, - }, + blindingPointHTLC, mockHtlc, - { - Signature: testSig.Serialize(), - Incoming: false, - Amt: 10, - RHash: key, - RefundTimeout: 1, - OnionBlob: lnmock.MockOnion(), - ExtraData: bytes.Repeat( - []byte{9}, 999, - ), - }, }, }, } @@ -1661,7 +1660,15 @@ func TestHTLCsExtraData(t *testing.T) { r := bytes.NewReader(b.Bytes()) htlcs, err := DeserializeHtlcs(r) require.NoError(t, err) - require.Equal(t, testCase.htlcs, htlcs) + + require.EqualValues(t, len(testCase.htlcs), len(htlcs)) + for i, htlc := range htlcs { + // We use the extra data field when we + // serialize, so we set to nil to be able to + // assert on equal for the test. + htlc.ExtraData = nil + require.Equal(t, testCase.htlcs[i], htlc) + } }) } } diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 40bdbebb5..1b6e71ffd 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -31,6 +31,7 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" + "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -371,6 +372,12 @@ type PaymentDescriptor struct { // isForwarded denotes if an incoming HTLC has been forwarded to any // possible upstream peers in the route. isForwarded bool + + // BlindingPoint is an optional ephemeral key used in route blinding. + // This value is set for nodes that are relaying payments inside of a + // blinded route (ie, not the introduction node) from update_add_htlc's + // TLVs. + BlindingPoint *btcec.PublicKey } // PayDescsFromRemoteLogUpdates converts a slice of LogUpdates received from the @@ -411,6 +418,7 @@ func PayDescsFromRemoteLogUpdates(chanID lnwire.ShortChannelID, height uint64, Height: height, Index: uint16(i), }, + BlindingPoint: wireMsg.BlingingPointOrNil(), } pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) copy(pd.OnionBlob[:], wireMsg.OnionBlob[:]) @@ -736,6 +744,14 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { Incoming: false, } copy(h.OnionBlob[:], htlc.OnionBlob) + if htlc.BlindingPoint != nil { + h.BlindingPoint = tlv.SomeRecordT( + //nolint:lll + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( + htlc.BlindingPoint, + ), + ) + } if ourCommit && htlc.sig != nil { h.Signature = htlc.sig.Serialize() @@ -760,7 +776,14 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { Incoming: true, } copy(h.OnionBlob[:], htlc.OnionBlob) - + if htlc.BlindingPoint != nil { + h.BlindingPoint = tlv.SomeRecordT( + //nolint:lll + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( + htlc.BlindingPoint, + ), + ) + } if ourCommit && htlc.sig != nil { h.Signature = htlc.sig.Serialize() } @@ -859,6 +882,12 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, theirWitnessScript: theirWitnessScript, } + htlc.BlindingPoint.WhenSome(func(b tlv.RecordT[ + lnwire.BlindingPointTlvType, *btcec.PublicKey]) { + + pd.BlindingPoint = b.Val + }) + return pd, nil } @@ -1548,6 +1577,7 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, HtlcIndex: wireMsg.ID, LogIndex: logUpdate.LogIndex, addCommitHeightRemote: commitHeight, + BlindingPoint: wireMsg.BlingingPointOrNil(), } pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) copy(pd.OnionBlob[:], wireMsg.OnionBlob[:]) @@ -1745,6 +1775,7 @@ func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpd HtlcIndex: wireMsg.ID, LogIndex: logUpdate.LogIndex, addCommitHeightLocal: commitHeight, + BlindingPoint: wireMsg.BlingingPointOrNil(), } pd.OnionBlob = make([]byte, len(wireMsg.OnionBlob)) copy(pd.OnionBlob, wireMsg.OnionBlob[:]) @@ -3607,6 +3638,14 @@ func (lc *LightningChannel) createCommitDiff( PaymentHash: pd.RHash, } copy(htlc.OnionBlob[:], pd.OnionBlob) + if pd.BlindingPoint != nil { + htlc.BlindingPoint = tlv.SomeRecordT( + //nolint:lll + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( + pd.BlindingPoint, + ), + ) + } logUpdate.UpdateMsg = htlc // Gather any references for circuits opened by this Add @@ -3736,12 +3775,21 @@ func (lc *LightningChannel) getUnsignedAckedUpdates() []channeldb.LogUpdate { // four messages that it corresponds to. switch pd.EntryType { case Add: + var b lnwire.BlindingPointRecord + if pd.BlindingPoint != nil { + tlv.SomeRecordT( + //nolint:lll + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](pd.BlindingPoint), + ) + } + htlc := &lnwire.UpdateAddHTLC{ - ChanID: chanID, - ID: pd.HtlcIndex, - Amount: pd.Amount, - Expiry: pd.Timeout, - PaymentHash: pd.RHash, + ChanID: chanID, + ID: pd.HtlcIndex, + Amount: pd.Amount, + Expiry: pd.Timeout, + PaymentHash: pd.RHash, + BlindingPoint: b, } copy(htlc.OnionBlob[:], pd.OnionBlob) logUpdate.UpdateMsg = htlc @@ -5742,6 +5790,14 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) ( Expiry: pd.Timeout, PaymentHash: pd.RHash, } + if pd.BlindingPoint != nil { + htlc.BlindingPoint = tlv.SomeRecordT( + //nolint:lll + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType]( + pd.BlindingPoint, + ), + ) + } copy(htlc.OnionBlob[:], pd.OnionBlob) logUpdate.UpdateMsg = htlc addUpdates = append(addUpdates, logUpdate) @@ -6079,6 +6135,7 @@ func (lc *LightningChannel) htlcAddDescriptor(htlc *lnwire.UpdateAddHTLC, HtlcIndex: lc.localUpdateLog.htlcCounter, OnionBlob: htlc.OnionBlob[:], OpenCircuitKey: openKey, + BlindingPoint: htlc.BlingingPointOrNil(), } } @@ -6129,13 +6186,14 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, err } pd := &PaymentDescriptor{ - EntryType: Add, - RHash: PaymentHash(htlc.PaymentHash), - Timeout: htlc.Expiry, - Amount: htlc.Amount, - LogIndex: lc.remoteUpdateLog.logIndex, - HtlcIndex: lc.remoteUpdateLog.htlcCounter, - OnionBlob: htlc.OnionBlob[:], + EntryType: Add, + RHash: PaymentHash(htlc.PaymentHash), + Timeout: htlc.Expiry, + Amount: htlc.Amount, + LogIndex: lc.remoteUpdateLog.logIndex, + HtlcIndex: lc.remoteUpdateLog.htlcCounter, + OnionBlob: htlc.OnionBlob[:], + BlindingPoint: htlc.BlingingPointOrNil(), } localACKedIndex := lc.remoteCommitChain.tail().ourMessageIndex diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 7494433e6..d224b4598 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -25,6 +25,7 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" ) @@ -10419,8 +10420,9 @@ func createRandomHTLC(t *testing.T, incoming bool) channeldb.HTLC { _, err = rand.Read(sig) require.NoError(t, err) - extra := make([]byte, 1000) - _, err = rand.Read(extra) + blinding, err := pubkeyFromHex( + "0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d48236c39", //nolint:lll + ) require.NoError(t, err) return channeldb.HTLC{ @@ -10433,7 +10435,10 @@ func createRandomHTLC(t *testing.T, incoming bool) channeldb.HTLC { OnionBlob: onionBlob, HtlcIndex: rand.Uint64(), LogIndex: rand.Uint64(), - ExtraData: extra, + BlindingPoint: tlv.SomeRecordT( + //nolint:lll + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](blinding), + ), } } @@ -11000,3 +11005,61 @@ func TestEnforceFeeBuffer(t *testing.T) { require.Equal(t, aliceBalance, expectedAmt) } + +// TestBlindingPointPersistence tests persistence of blinding points attached +// to htlcs across restarts. +func TestBlindingPointPersistence(t *testing.T) { + // Create a test channel which will be used for the duration of this + // test. The channel will be funded evenly with Alice having 5 BTC, + // and Bob having 5 BTC. + aliceChannel, bobChannel, err := CreateTestChannels( + t, channeldb.SingleFunderTweaklessBit, + ) + require.NoError(t, err, "unable to create test channels") + + // Send a HTLC from Alice to Bob that has a blinding point populated. + htlc, _ := createHTLC(0, 100_000_000) + blinding, err := pubkeyFromHex( + "0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d48236c39", //nolint:lll + ) + require.NoError(t, err) + htlc.BlindingPoint = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](blinding), + ) + + _, err = aliceChannel.AddHTLC(htlc, nil) + + require.NoError(t, err) + _, err = bobChannel.ReceiveHTLC(htlc) + require.NoError(t, err) + + // Now, Alice will send a new commitment to Bob, which will persist our + // pending HTLC to disk. + aliceCommit, err := aliceChannel.SignNextCommitment() + require.NoError(t, err, "unable to sign commitment") + + // Restart alice to force fetching state from disk. + aliceChannel, err = restartChannel(aliceChannel) + require.NoError(t, err, "unable to restart alice") + + // Assert that the blinding point is restored from disk. + remoteCommit := aliceChannel.remoteCommitChain.tip() + require.Len(t, remoteCommit.outgoingHTLCs, 1) + require.Equal(t, blinding, remoteCommit.outgoingHTLCs[0].BlindingPoint) + + // Next, update bob's commitment and assert that we can still retrieve + // his incoming blinding point after restart. + err = bobChannel.ReceiveNewCommitment(aliceCommit.CommitSigs) + require.NoError(t, err, "bob unable to receive new commitment") + + _, _, _, err = bobChannel.RevokeCurrentCommitment() + require.NoError(t, err, "bob unable to revoke current commitment") + + bobChannel, err = restartChannel(bobChannel) + require.NoError(t, err, "unable to restart bob's channel") + + // Assert that Bob is able to recover the blinding point from disk. + bobCommit := bobChannel.localCommitChain.tip() + require.Len(t, bobCommit.incomingHTLCs, 1) + require.Equal(t, blinding, bobCommit.incomingHTLCs[0].BlindingPoint) +} diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index 8a40710e8..951dc7f54 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -78,6 +78,19 @@ type UpdateAddHTLC struct { ExtraData ExtraOpaqueData } +// BlingingPointOrNil returns the blinding point associated with the update, or +// nil. +func (c *UpdateAddHTLC) BlingingPointOrNil() *btcec.PublicKey { + var blindingPoint *btcec.PublicKey + c.BlindingPoint.WhenSome(func(b tlv.RecordT[BlindingPointTlvType, + *btcec.PublicKey]) { + + blindingPoint = b.Val + }) + + return blindingPoint +} + // NewUpdateAddHTLC returns a new empty UpdateAddHTLC message. func NewUpdateAddHTLC() *UpdateAddHTLC { return &UpdateAddHTLC{}