From 1a5b5c5f62c37c6c68a2cd9ece2a7fabe90c02b0 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Tue, 23 Apr 2024 14:10:33 -0700 Subject: [PATCH 1/6] lntypes: Add a ChannelParty type. This commit introduces a ChannelParty type to LND. It is useful for consolidating all references to the duality between the local and remote nodes. This is currently handled by having named struct rows or named boolean parameters, named either "local" or "remote". This change alleviates the programmer from having to decide which node should be bound to `true` or `false`. In an upcoming commit we will change callsites to use this. --- lntypes/channel_party.go | 52 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 lntypes/channel_party.go diff --git a/lntypes/channel_party.go b/lntypes/channel_party.go new file mode 100644 index 000000000..be800541b --- /dev/null +++ b/lntypes/channel_party.go @@ -0,0 +1,52 @@ +package lntypes + +import "fmt" + +// ChannelParty is a type used to have an unambiguous description of which node +// is being referred to. This eliminates the need to describe as "local" or +// "remote" using bool. +type ChannelParty uint8 + +const ( + // Local is a ChannelParty constructor that is used to refer to the + // node that is running. + Local ChannelParty = iota + + // Remote is a ChannelParty constructor that is used to refer to the + // node on the other end of the peer connection. + Remote +) + +// String provides a string representation of ChannelParty (useful for logging). +func (p ChannelParty) String() string { + switch p { + case Local: + return "Local" + case Remote: + return "Remote" + default: + panic(fmt.Sprintf("invalid ChannelParty value: %d", p)) + } +} + +// CounterParty inverts the role of the ChannelParty. +func (p ChannelParty) CounterParty() ChannelParty { + switch p { + case Local: + return Remote + case Remote: + return Local + default: + panic(fmt.Sprintf("invalid ChannelParty value: %v", p)) + } +} + +// IsLocal returns true if the ChannelParty is Local. +func (p ChannelParty) IsLocal() bool { + return p == Local +} + +// IsRemote returns true if the ChannelParty is Remote. +func (p ChannelParty) IsRemote() bool { + return p == Remote +} From 3a1508501473f26fe79f63bd91d3d0c40f1fcd5c Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Tue, 30 Jul 2024 16:18:09 -0700 Subject: [PATCH 2/6] input+lnwallet: refactor select methods in input to use ChannelParty --- input/script_utils.go | 10 ++++++---- input/size_test.go | 18 ++++++++++-------- input/taproot_test.go | 4 ++-- lnwallet/commitment.go | 8 ++++---- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/input/script_utils.go b/input/script_utils.go index 80997eed4..104c24251 100644 --- a/input/script_utils.go +++ b/input/script_utils.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" "golang.org/x/crypto/ripemd160" ) @@ -789,10 +790,10 @@ func senderHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, // unilaterally spend the created output. func SenderHTLCScriptTaproot(senderHtlcKey, receiverHtlcKey, revokeKey *btcec.PublicKey, payHash []byte, - localCommit bool) (*HtlcScriptTree, error) { + whoseCommit lntypes.ChannelParty) (*HtlcScriptTree, error) { var hType htlcType - if localCommit { + if whoseCommit.IsLocal() { hType = htlcLocalOutgoing } else { hType = htlcRemoteIncoming @@ -1348,10 +1349,11 @@ func receiverHtlcTapScriptTree(senderHtlcKey, receiverHtlcKey, // the tap leaf are returned. func ReceiverHTLCScriptTaproot(cltvExpiry uint32, senderHtlcKey, receiverHtlcKey, revocationKey *btcec.PublicKey, - payHash []byte, ourCommit bool) (*HtlcScriptTree, error) { + payHash []byte, whoseCommit lntypes.ChannelParty, +) (*HtlcScriptTree, error) { var hType htlcType - if ourCommit { + if whoseCommit.IsLocal() { hType = htlcLocalIncoming } else { hType = htlcRemoteOutgoing diff --git a/input/size_test.go b/input/size_test.go index 9c3446afb..daa7053cc 100644 --- a/input/size_test.go +++ b/input/size_test.go @@ -13,6 +13,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/stretchr/testify/require" ) @@ -1073,7 +1074,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), - revokeKey.PubKey(), payHash[:], false, + revokeKey.PubKey(), payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1115,7 +1116,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1157,7 +1158,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1203,7 +1204,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), - revokeKey.PubKey(), payHash[:], false, + revokeKey.PubKey(), payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1263,7 +1264,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), - revokeKey.PubKey(), payHash[:], false, + revokeKey.PubKey(), payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1309,7 +1310,7 @@ var witnessSizeTests = []witnessSizeTest{ htlcScriptTree, err := input.ReceiverHTLCScriptTaproot( testCLTVExpiry, senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -1394,7 +1395,8 @@ func genTimeoutTx(t *testing.T, ) if chanType.IsTaproot() { tapscriptTree, err = input.SenderHTLCScriptTaproot( - testPubkey, testPubkey, testPubkey, testHash160, false, + testPubkey, testPubkey, testPubkey, testHash160, + lntypes.Remote, ) require.NoError(t, err) @@ -1463,7 +1465,7 @@ func genSuccessTx(t *testing.T, chanType channeldb.ChannelType) *wire.MsgTx { if chanType.IsTaproot() { tapscriptTree, err = input.ReceiverHTLCScriptTaproot( testCLTVExpiry, testPubkey, testPubkey, testPubkey, - testHash160, false, + testHash160, lntypes.Remote, ) require.NoError(t, err) diff --git a/input/taproot_test.go b/input/taproot_test.go index 801b0fef4..434be2dfd 100644 --- a/input/taproot_test.go +++ b/input/taproot_test.go @@ -48,7 +48,7 @@ func newTestSenderHtlcScriptTree(t *testing.T) *testSenderHtlcScriptTree { payHash := preImage.Hash() htlcScriptTree, err := SenderHTLCScriptTaproot( senderKey.PubKey(), receiverKey.PubKey(), revokeKey.PubKey(), - payHash[:], false, + payHash[:], lntypes.Remote, ) require.NoError(t, err) @@ -471,7 +471,7 @@ func newTestReceiverHtlcScriptTree(t *testing.T) *testReceiverHtlcScriptTree { payHash := preImage.Hash() htlcScriptTree, err := ReceiverHTLCScriptTaproot( cltvExpiry, senderKey.PubKey(), receiverKey.PubKey(), - revokeKey.PubKey(), payHash[:], false, + revokeKey.PubKey(), payHash[:], lntypes.Remote, ) require.NoError(t, err) diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index 96af8d7cf..1e1140fbc 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -1095,7 +1095,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, case isIncoming && ourCommit: htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, + keyRing.RevocationKey, rHash[:], lntypes.Local, ) // We're being paid via an HTLC by the remote party, and the HTLC is @@ -1104,7 +1104,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, case isIncoming && !ourCommit: htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, + keyRing.RevocationKey, rHash[:], lntypes.Remote, ) // We're sending an HTLC which is being added to our commitment @@ -1113,7 +1113,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, case !isIncoming && ourCommit: htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, + keyRing.RevocationKey, rHash[:], lntypes.Local, ) // Finally, we're paying the remote party via an HTLC, which is being @@ -1122,7 +1122,7 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, case !isIncoming && !ourCommit: htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, - keyRing.RevocationKey, rHash[:], ourCommit, + keyRing.RevocationKey, rHash[:], lntypes.Remote, ) } From 33934449ac8b99f0601a35311f3ea3bdfdaf4634 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Tue, 30 Jul 2024 16:25:40 -0700 Subject: [PATCH 3/6] multi: refactor select methods within channeldb to use ChannelParty Also in this commit is a small adjustment to the call-sites to get the boundaries stitched back together. --- channeldb/channel.go | 22 ++++++++++++++++------ channeldb/channel_test.go | 19 +++++++++++++------ channeldb/db_test.go | 9 +++++++-- contractcourt/chain_arbitrator_test.go | 7 +++++-- contractcourt/channel_arbitrator.go | 4 ++-- contractcourt/channel_arbitrator_test.go | 4 +++- lnwallet/channel.go | 14 ++++++++++++-- 7 files changed, 58 insertions(+), 21 deletions(-) diff --git a/channeldb/channel.go b/channeldb/channel.go index 046ef8806..ad0208467 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -25,6 +25,7 @@ import ( "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" "github.com/lightningnetwork/lnd/tlv" @@ -1690,11 +1691,11 @@ func (c *OpenChannel) isBorked(chanBucket kvdb.RBucket) (bool, error) { // republish this tx at startup to ensure propagation, and we should still // handle the case where a different tx actually hits the chain. func (c *OpenChannel) MarkCommitmentBroadcasted(closeTx *wire.MsgTx, - locallyInitiated bool) error { + closer lntypes.ChannelParty) error { return c.markBroadcasted( ChanStatusCommitBroadcasted, forceCloseTxKey, closeTx, - locallyInitiated, + closer, ) } @@ -1706,11 +1707,11 @@ func (c *OpenChannel) MarkCommitmentBroadcasted(closeTx *wire.MsgTx, // ensure propagation, and we should still handle the case where a different tx // actually hits the chain. func (c *OpenChannel) MarkCoopBroadcasted(closeTx *wire.MsgTx, - locallyInitiated bool) error { + closer lntypes.ChannelParty) error { return c.markBroadcasted( ChanStatusCoopBroadcasted, coopCloseTxKey, closeTx, - locallyInitiated, + closer, ) } @@ -1719,7 +1720,7 @@ func (c *OpenChannel) MarkCoopBroadcasted(closeTx *wire.MsgTx, // which should specify either a coop or force close. It adds a status which // indicates the party that initiated the channel close. func (c *OpenChannel) markBroadcasted(status ChannelStatus, key []byte, - closeTx *wire.MsgTx, locallyInitiated bool) error { + closeTx *wire.MsgTx, closer lntypes.ChannelParty) error { c.Lock() defer c.Unlock() @@ -1741,7 +1742,7 @@ func (c *OpenChannel) markBroadcasted(status ChannelStatus, key []byte, // Add the initiator status to the status provided. These statuses are // set in addition to the broadcast status so that we do not need to // migrate the original logic which does not store initiator. - if locallyInitiated { + if closer.IsLocal() { status |= ChanStatusLocalCloseInitiator } else { status |= ChanStatusRemoteCloseInitiator @@ -4486,6 +4487,15 @@ func NewShutdownInfo(deliveryScript lnwire.DeliveryAddress, } } +// Closer identifies the ChannelParty that initiated the coop-closure process. +func (s ShutdownInfo) Closer() lntypes.ChannelParty { + if s.LocalInitiator.Val { + return lntypes.Local + } + + return lntypes.Remote +} + // encode serialises the ShutdownInfo to the given io.Writer. func (s *ShutdownInfo) encode(w io.Writer) error { records := []tlv.Record{ diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 981ddf688..e630b1c48 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -21,6 +21,7 @@ import ( "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnmock" "github.com/lightningnetwork/lnd/lntest/channels" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" "github.com/lightningnetwork/lnd/tlv" @@ -1084,13 +1085,17 @@ func TestFetchWaitingCloseChannels(t *testing.T) { }, ) - if err := channel.MarkCommitmentBroadcasted(closeTx, true); err != nil { + if err := channel.MarkCommitmentBroadcasted( + closeTx, lntypes.Local, + ); err != nil { t.Fatalf("unable to mark commitment broadcast: %v", err) } // Now try to marking a coop close with a nil tx. This should // succeed, but it shouldn't exit when queried. - if err = channel.MarkCoopBroadcasted(nil, true); err != nil { + if err = channel.MarkCoopBroadcasted( + nil, lntypes.Local, + ); err != nil { t.Fatalf("unable to mark nil coop broadcast: %v", err) } _, err := channel.BroadcastedCooperative() @@ -1102,7 +1107,9 @@ func TestFetchWaitingCloseChannels(t *testing.T) { // it as coop closed. Later we will test that distinct // transactions are returned for both coop and force closes. closeTx.TxIn[0].PreviousOutPoint.Index ^= 1 - if err := channel.MarkCoopBroadcasted(closeTx, true); err != nil { + if err := channel.MarkCoopBroadcasted( + closeTx, lntypes.Local, + ); err != nil { t.Fatalf("unable to mark coop broadcast: %v", err) } } @@ -1324,7 +1331,7 @@ func TestCloseInitiator(t *testing.T) { // by the local party. updateChannel: func(c *OpenChannel) error { return c.MarkCoopBroadcasted( - &wire.MsgTx{}, true, + &wire.MsgTx{}, lntypes.Local, ) }, expectedStatuses: []ChannelStatus{ @@ -1338,7 +1345,7 @@ func TestCloseInitiator(t *testing.T) { // by the remote party. updateChannel: func(c *OpenChannel) error { return c.MarkCoopBroadcasted( - &wire.MsgTx{}, false, + &wire.MsgTx{}, lntypes.Remote, ) }, expectedStatuses: []ChannelStatus{ @@ -1352,7 +1359,7 @@ func TestCloseInitiator(t *testing.T) { // local initiator. updateChannel: func(c *OpenChannel) error { return c.MarkCommitmentBroadcasted( - &wire.MsgTx{}, true, + &wire.MsgTx{}, lntypes.Local, ) }, expectedStatuses: []ChannelStatus{ diff --git a/channeldb/db_test.go b/channeldb/db_test.go index a954f2828..025bf1261 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -14,6 +14,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" "github.com/stretchr/testify/require" @@ -606,7 +607,9 @@ func TestFetchChannels(t *testing.T) { channelIDOption(pendingWaitingChan), ) - err = pendingClosing.MarkCoopBroadcasted(nil, true) + err = pendingClosing.MarkCoopBroadcasted( + nil, lntypes.Local, + ) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -626,7 +629,9 @@ func TestFetchChannels(t *testing.T) { channelIDOption(openWaitingChan), openChannelOption(), ) - err = openClosing.MarkCoopBroadcasted(nil, true) + err = openClosing.MarkCoopBroadcasted( + nil, lntypes.Local, + ) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go index 36f6dad18..abaca5c2b 100644 --- a/contractcourt/chain_arbitrator_test.go +++ b/contractcourt/chain_arbitrator_test.go @@ -11,6 +11,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lntest/mock" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/stretchr/testify/require" ) @@ -61,12 +62,14 @@ func TestChainArbitratorRepublishCloses(t *testing.T) { for i := 0; i < numChans/2; i++ { closeTx := channels[i].FundingTxn.Copy() closeTx.TxIn[0].PreviousOutPoint = channels[i].FundingOutpoint - err := channels[i].MarkCommitmentBroadcasted(closeTx, true) + err := channels[i].MarkCommitmentBroadcasted( + closeTx, lntypes.Local, + ) if err != nil { t.Fatal(err) } - err = channels[i].MarkCoopBroadcasted(closeTx, true) + err = channels[i].MarkCoopBroadcasted(closeTx, lntypes.Local) if err != nil { t.Fatal(err) } diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 8add61ce6..cb5cee872 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -129,7 +129,7 @@ type ChannelArbitratorConfig struct { // MarkCommitmentBroadcasted should mark the channel as the commitment // being broadcast, and we are waiting for the commitment to confirm. - MarkCommitmentBroadcasted func(*wire.MsgTx, bool) error + MarkCommitmentBroadcasted func(*wire.MsgTx, lntypes.ChannelParty) error // MarkChannelClosed marks the channel closed in the database, with the // passed close summary. After this method successfully returns we can @@ -1084,7 +1084,7 @@ func (c *ChannelArbitrator) stateStep( // database, such that we can re-publish later in case it // didn't propagate. We initiated the force close, so we // mark broadcast with local initiator set to true. - err = c.cfg.MarkCommitmentBroadcasted(closeTx, true) + err = c.cfg.MarkCommitmentBroadcasted(closeTx, lntypes.Local) if err != nil { log.Errorf("ChannelArbitrator(%v): unable to "+ "mark commitment broadcasted: %v", diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 43238494e..916cd5f58 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -416,7 +416,9 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog, resolvedChan <- struct{}{} return nil }, - MarkCommitmentBroadcasted: func(_ *wire.MsgTx, _ bool) error { + MarkCommitmentBroadcasted: func(_ *wire.MsgTx, + _ lntypes.ChannelParty) error { + return nil }, MarkChannelClosed: func(*channeldb.ChannelCloseSummary, diff --git a/lnwallet/channel.go b/lnwallet/channel.go index abe8b6276..3e2e8cf39 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -8460,7 +8460,12 @@ func (lc *LightningChannel) MarkCommitmentBroadcasted(tx *wire.MsgTx, lc.Lock() defer lc.Unlock() - return lc.channelState.MarkCommitmentBroadcasted(tx, locallyInitiated) + party := lntypes.Remote + if locallyInitiated { + party = lntypes.Local + } + + return lc.channelState.MarkCommitmentBroadcasted(tx, party) } // MarkCoopBroadcasted marks the channel as a cooperative close transaction has @@ -8473,7 +8478,12 @@ func (lc *LightningChannel) MarkCoopBroadcasted(tx *wire.MsgTx, lc.Lock() defer lc.Unlock() - return lc.channelState.MarkCoopBroadcasted(tx, localInitiated) + party := lntypes.Remote + if localInitiated { + party = lntypes.Local + } + + return lc.channelState.MarkCoopBroadcasted(tx, party) } // MarkShutdownSent persists the given ShutdownInfo. The existence of the From 0996e4f1637bbc6bd6e00802b0381d2c7b424148 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Tue, 30 Jul 2024 16:44:18 -0700 Subject: [PATCH 4/6] multi: refactor lnwallet/channel.go to use ChannelParty in select places We also include changes to contractcourt, htlcswitch and peer to stitch the boundaries together. --- contractcourt/chain_watcher.go | 5 +- htlcswitch/link.go | 11 +- lnwallet/chancloser/chancloser.go | 16 +- lnwallet/chancloser/chancloser_test.go | 13 +- lnwallet/chancloser/interface.go | 3 +- lnwallet/channel.go | 376 +++++++++++++------------ lnwallet/channel_test.go | 254 ++++++++--------- lnwallet/commitment.go | 78 ++--- lnwallet/wallet.go | 7 +- peer/brontide.go | 7 +- 10 files changed, 409 insertions(+), 361 deletions(-) diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index 962a239e7..3cbc7422d 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -20,6 +20,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -418,7 +419,7 @@ func (c *chainWatcher) handleUnknownLocalState( // and remote keys for this state. We use our point as only we can // revoke our own commitment. commitKeyRing := lnwallet.DeriveCommitmentKeys( - commitPoint, true, c.cfg.chanState.ChanType, + commitPoint, lntypes.Local, c.cfg.chanState.ChanType, &c.cfg.chanState.LocalChanCfg, &c.cfg.chanState.RemoteChanCfg, ) @@ -891,7 +892,7 @@ func (c *chainWatcher) handlePossibleBreach(commitSpend *chainntnfs.SpendDetail, // Create an AnchorResolution for the breached state. anchorRes, err := lnwallet.NewAnchorResolution( c.cfg.chanState, commitSpend.SpendingTx, retribution.KeyRing, - false, + lntypes.Remote, ) if err != nil { return false, fmt.Errorf("unable to create anchor "+ diff --git a/htlcswitch/link.go b/htlcswitch/link.go index eee18ff59..99df2c2b9 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -2730,7 +2730,12 @@ func (l *channelLink) MayAddOutgoingHtlc(amt lnwire.MilliSatoshi) error { func (l *channelLink) getDustSum(remote bool, dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi { - return l.channel.GetDustSum(remote, dryRunFee) + party := lntypes.Local + if remote { + party = lntypes.Remote + } + + return l.channel.GetDustSum(party, dryRunFee) } // getFeeRate is a wrapper method that retrieves the underlying channel's @@ -2893,13 +2898,13 @@ func dustHelper(chantype channeldb.ChannelType, localDustLimit, if localCommit { return lnwallet.HtlcIsDust( - chantype, incoming, true, feerate, amt, + chantype, incoming, lntypes.Local, feerate, amt, localDustLimit, ) } return lnwallet.HtlcIsDust( - chantype, incoming, false, feerate, amt, + chantype, incoming, lntypes.Remote, feerate, amt, remoteDustLimit, ) } diff --git a/lnwallet/chancloser/chancloser.go b/lnwallet/chancloser/chancloser.go index 3f5e730c0..57033d4b3 100644 --- a/lnwallet/chancloser/chancloser.go +++ b/lnwallet/chancloser/chancloser.go @@ -15,6 +15,7 @@ import ( "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/labels" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -207,8 +208,8 @@ type ChanCloser struct { // settled channel funds to. remoteDeliveryScript []byte - // locallyInitiated is true if we initiated the channel close. - locallyInitiated bool + // closer is ChannelParty who initiated the coop close + closer lntypes.ChannelParty // cachedClosingSigned is a cached copy of a received ClosingSigned that // we use to handle a specific race condition caused by the independent @@ -267,7 +268,8 @@ func (d *SimpleCoopFeeEstimator) EstimateFee(chanType channeldb.ChannelType, // be populated iff, we're the initiator of this closing request. func NewChanCloser(cfg ChanCloseCfg, deliveryScript []byte, idealFeePerKw chainfee.SatPerKWeight, negotiationHeight uint32, - closeReq *htlcswitch.ChanClose, locallyInitiated bool) *ChanCloser { + closeReq *htlcswitch.ChanClose, + closer lntypes.ChannelParty) *ChanCloser { chanPoint := cfg.Channel.ChannelPoint() cid := lnwire.NewChanIDFromOutPoint(chanPoint) @@ -283,7 +285,7 @@ func NewChanCloser(cfg ChanCloseCfg, deliveryScript []byte, priorFeeOffers: make( map[btcutil.Amount]*lnwire.ClosingSigned, ), - locallyInitiated: locallyInitiated, + closer: closer, } } @@ -366,7 +368,7 @@ func (c *ChanCloser) initChanShutdown() (*lnwire.Shutdown, error) { // message we are about to send in order to ensure that if a // re-establish occurs then we will re-send the same Shutdown message. shutdownInfo := channeldb.NewShutdownInfo( - c.localDeliveryScript, c.locallyInitiated, + c.localDeliveryScript, c.closer.IsLocal(), ) err := c.cfg.Channel.MarkShutdownSent(shutdownInfo) if err != nil { @@ -650,7 +652,7 @@ func (c *ChanCloser) BeginNegotiation() (fn.Option[lnwire.ClosingSigned], // externally consistent, and reflect that the channel is being // shutdown by the time the closing request returns. err := c.cfg.Channel.MarkCoopBroadcasted( - nil, c.locallyInitiated, + nil, c.closer, ) if err != nil { return noClosingSigned, err @@ -861,7 +863,7 @@ func (c *ChanCloser) ReceiveClosingSigned( //nolint:funlen // database, such that it can be republished if something goes // wrong. err = c.cfg.Channel.MarkCoopBroadcasted( - closeTx, c.locallyInitiated, + closeTx, c.closer, ) if err != nil { return noClosing, err diff --git a/lnwallet/chancloser/chancloser_test.go b/lnwallet/chancloser/chancloser_test.go index 1956f0d2b..9a90d0ab2 100644 --- a/lnwallet/chancloser/chancloser_test.go +++ b/lnwallet/chancloser/chancloser_test.go @@ -16,6 +16,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -150,7 +151,9 @@ func (m *mockChannel) ChannelPoint() wire.OutPoint { return m.chanPoint } -func (m *mockChannel) MarkCoopBroadcasted(*wire.MsgTx, bool) error { +func (m *mockChannel) MarkCoopBroadcasted(*wire.MsgTx, + lntypes.ChannelParty) error { + return nil } @@ -338,7 +341,7 @@ func TestMaxFeeClamp(t *testing.T) { Channel: &channel, MaxFee: test.inputMaxFee, FeeEstimator: &SimpleCoopFeeEstimator{}, - }, nil, test.idealFee, 0, nil, false, + }, nil, test.idealFee, 0, nil, lntypes.Remote, ) // We'll call initFeeBaseline early here since we need @@ -379,7 +382,7 @@ func TestMaxFeeBailOut(t *testing.T) { MaxFee: idealFee * 2, } chanCloser := NewChanCloser( - closeCfg, nil, idealFee, 0, nil, false, + closeCfg, nil, idealFee, 0, nil, lntypes.Remote, ) // We'll now force the channel state into the @@ -503,7 +506,7 @@ func TestTaprootFastClose(t *testing.T) { DisableChannel: func(wire.OutPoint) error { return nil }, - }, nil, idealFee, 0, nil, true, + }, nil, idealFee, 0, nil, lntypes.Local, ) aliceCloser.initFeeBaseline() @@ -520,7 +523,7 @@ func TestTaprootFastClose(t *testing.T) { DisableChannel: func(wire.OutPoint) error { return nil }, - }, nil, idealFee, 0, nil, false, + }, nil, idealFee, 0, nil, lntypes.Remote, ) bobCloser.initFeeBaseline() diff --git a/lnwallet/chancloser/interface.go b/lnwallet/chancloser/interface.go index 40b81efb4..2e9fa98ae 100644 --- a/lnwallet/chancloser/interface.go +++ b/lnwallet/chancloser/interface.go @@ -7,6 +7,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" @@ -33,7 +34,7 @@ type Channel interface { //nolint:interfacebloat // MarkCoopBroadcasted persistently marks that the channel close // transaction has been broadcast. - MarkCoopBroadcasted(*wire.MsgTx, bool) error + MarkCoopBroadcasted(*wire.MsgTx, lntypes.ChannelParty) error // MarkShutdownSent persists the given ShutdownInfo. The existence of // the ShutdownInfo represents the fact that the Shutdown message has diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 3e2e8cf39..afe10b950 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -271,9 +271,9 @@ type commitment struct { // update number of this commitment. height uint64 - // isOurs indicates whether this is the local or remote node's version - // of the commitment. - isOurs bool + // whoseCommit indicates whether this is the local or remote node's + // version of the commitment. + whoseCommit lntypes.ChannelParty // [our|their]MessageIndex are indexes into the HTLC log, up to which // this commitment transaction includes. These indexes allow both sides @@ -352,8 +352,9 @@ type commitment struct { // massed in is to be retained for each output within the commitment // transition. This ensures that we don't assign multiple HTLCs to the same // index within the commitment transaction. -func locateOutputIndex(p *PaymentDescriptor, tx *wire.MsgTx, ourCommit bool, - dups map[PaymentHash][]int32, cltvs []uint32) (int32, error) { +func locateOutputIndex(p *PaymentDescriptor, tx *wire.MsgTx, + whoseCommit lntypes.ChannelParty, dups map[PaymentHash][]int32, + cltvs []uint32) (int32, error) { // Checks to see if element (e) exists in slice (s). contains := func(s []int32, e int32) bool { @@ -370,7 +371,7 @@ func locateOutputIndex(p *PaymentDescriptor, tx *wire.MsgTx, ourCommit bool, // required as the commitment states are asymmetric in order to ascribe // blame in the case of a contract breach. pkScript := p.theirPkScript - if ourCommit { + if whoseCommit.IsLocal() { pkScript = p.ourPkScript } @@ -418,7 +419,7 @@ func (c *commitment) populateHtlcIndexes(chanType channeldb.ChannelType, // indexes within the commitment view for a particular HTLC. populateIndex := func(htlc *PaymentDescriptor, incoming bool) error { isDust := HtlcIsDust( - chanType, incoming, c.isOurs, c.feePerKw, + chanType, incoming, c.whoseCommit, c.feePerKw, htlc.Amount.ToSatoshis(), c.dustLimit, ) @@ -427,21 +428,21 @@ func (c *commitment) populateHtlcIndexes(chanType channeldb.ChannelType, // If this is our commitment transaction, and this is a dust // output then we mark it as such using a -1 index. - case c.isOurs && isDust: + case c.whoseCommit.IsLocal() && isDust: htlc.localOutputIndex = -1 // If this is the commitment transaction of the remote party, // and this is a dust output then we mark it as such using a -1 // index. - case !c.isOurs && isDust: + case c.whoseCommit.IsRemote() && isDust: htlc.remoteOutputIndex = -1 // If this is our commitment transaction, then we'll need to // locate the output and the index so we can verify an HTLC // signatures. - case c.isOurs: + case c.whoseCommit.IsLocal(): htlc.localOutputIndex, err = locateOutputIndex( - htlc, c.txn, c.isOurs, dups, cltvs, + htlc, c.txn, c.whoseCommit, dups, cltvs, ) if err != nil { return err @@ -460,9 +461,9 @@ func (c *commitment) populateHtlcIndexes(chanType channeldb.ChannelType, // Otherwise, this is there remote party's commitment // transaction and we only need to populate the remote output // index within the HTLC index. - case !c.isOurs: + case c.whoseCommit.IsRemote(): htlc.remoteOutputIndex, err = locateOutputIndex( - htlc, c.txn, c.isOurs, dups, cltvs, + htlc, c.txn, c.whoseCommit, dups, cltvs, ) if err != nil { return err @@ -497,7 +498,9 @@ func (c *commitment) populateHtlcIndexes(chanType channeldb.ChannelType, // toDiskCommit converts the target commitment into a format suitable to be // written to disk after an accepted state transition. -func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { +func (c *commitment) toDiskCommit( + whoseCommit lntypes.ChannelParty) *channeldb.ChannelCommitment { + numHtlcs := len(c.outgoingHTLCs) + len(c.incomingHTLCs) commit := &channeldb.ChannelCommitment{ @@ -517,7 +520,7 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { for _, htlc := range c.outgoingHTLCs { outputIndex := htlc.localOutputIndex - if !ourCommit { + if whoseCommit.IsRemote() { outputIndex = htlc.remoteOutputIndex } @@ -533,7 +536,7 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { } copy(h.OnionBlob[:], htlc.OnionBlob) - if ourCommit && htlc.sig != nil { + if whoseCommit.IsLocal() && htlc.sig != nil { h.Signature = htlc.sig.Serialize() } @@ -542,7 +545,7 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { for _, htlc := range c.incomingHTLCs { outputIndex := htlc.localOutputIndex - if !ourCommit { + if whoseCommit.IsRemote() { outputIndex = htlc.remoteOutputIndex } @@ -557,7 +560,7 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { BlindingPoint: htlc.BlindingPoint, } copy(h.OnionBlob[:], htlc.OnionBlob) - if ourCommit && htlc.sig != nil { + if whoseCommit.IsLocal() && htlc.sig != nil { h.Signature = htlc.sig.Serialize() } @@ -574,8 +577,8 @@ func (c *commitment) toDiskCommit(ourCommit bool) *channeldb.ChannelCommitment { // restart a channel session. func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, commitHeight uint64, htlc *channeldb.HTLC, localCommitKeys, - remoteCommitKeys *CommitmentKeyRing, isLocal bool) (PaymentDescriptor, - error) { + remoteCommitKeys *CommitmentKeyRing, whoseCommit lntypes.ChannelParty, +) (PaymentDescriptor, error) { // The proper pkScripts for this PaymentDescriptor must be // generated so we can easily locate them within the commitment @@ -593,13 +596,13 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, // transaction. As we'll mark dust with a special output index in the // on-disk state snapshot. isDustLocal := HtlcIsDust( - chanType, htlc.Incoming, true, feeRate, + chanType, htlc.Incoming, lntypes.Local, feeRate, htlc.Amt.ToSatoshis(), lc.channelState.LocalChanCfg.DustLimit, ) if !isDustLocal && localCommitKeys != nil { scriptInfo, err := genHtlcScript( - chanType, htlc.Incoming, true, htlc.RefundTimeout, - htlc.RHash, localCommitKeys, + chanType, htlc.Incoming, lntypes.Local, + htlc.RefundTimeout, htlc.RHash, localCommitKeys, ) if err != nil { return pd, err @@ -608,13 +611,13 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, ourWitnessScript = scriptInfo.WitnessScriptToSign() } isDustRemote := HtlcIsDust( - chanType, htlc.Incoming, false, feeRate, + chanType, htlc.Incoming, lntypes.Remote, feeRate, htlc.Amt.ToSatoshis(), lc.channelState.RemoteChanCfg.DustLimit, ) if !isDustRemote && remoteCommitKeys != nil { scriptInfo, err := genHtlcScript( - chanType, htlc.Incoming, false, htlc.RefundTimeout, - htlc.RHash, remoteCommitKeys, + chanType, htlc.Incoming, lntypes.Remote, + htlc.RefundTimeout, htlc.RHash, remoteCommitKeys, ) if err != nil { return pd, err @@ -630,7 +633,7 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, localOutputIndex int32 remoteOutputIndex int32 ) - if isLocal { + if whoseCommit.IsLocal() { localOutputIndex = htlc.OutputIndex } else { remoteOutputIndex = htlc.OutputIndex @@ -663,8 +666,8 @@ func (lc *LightningChannel) diskHtlcToPayDesc(feeRate chainfee.SatPerKWeight, // for each side. func (lc *LightningChannel) extractPayDescs(commitHeight uint64, feeRate chainfee.SatPerKWeight, htlcs []channeldb.HTLC, localCommitKeys, - remoteCommitKeys *CommitmentKeyRing, isLocal bool) ([]PaymentDescriptor, - []PaymentDescriptor, error) { + remoteCommitKeys *CommitmentKeyRing, whoseCommit lntypes.ChannelParty, +) ([]PaymentDescriptor, []PaymentDescriptor, error) { var ( incomingHtlcs []PaymentDescriptor @@ -684,7 +687,7 @@ func (lc *LightningChannel) extractPayDescs(commitHeight uint64, payDesc, err := lc.diskHtlcToPayDesc( feeRate, commitHeight, &htlc, localCommitKeys, remoteCommitKeys, - isLocal, + whoseCommit, ) if err != nil { return incomingHtlcs, outgoingHtlcs, err @@ -703,7 +706,8 @@ func (lc *LightningChannel) extractPayDescs(commitHeight uint64, // diskCommitToMemCommit converts the on-disk commitment format to our // in-memory commitment format which is needed in order to properly resume // channel operations after a restart. -func (lc *LightningChannel) diskCommitToMemCommit(isLocal bool, +func (lc *LightningChannel) diskCommitToMemCommit( + whoseCommit lntypes.ChannelParty, diskCommit *channeldb.ChannelCommitment, localCommitPoint, remoteCommitPoint *btcec.PublicKey) (*commitment, error) { @@ -715,14 +719,16 @@ func (lc *LightningChannel) diskCommitToMemCommit(isLocal bool, var localCommitKeys, remoteCommitKeys *CommitmentKeyRing if localCommitPoint != nil { localCommitKeys = DeriveCommitmentKeys( - localCommitPoint, true, lc.channelState.ChanType, + localCommitPoint, lntypes.Local, + lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, ) } if remoteCommitPoint != nil { remoteCommitKeys = DeriveCommitmentKeys( - remoteCommitPoint, false, lc.channelState.ChanType, + remoteCommitPoint, lntypes.Remote, + lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, ) @@ -735,7 +741,7 @@ func (lc *LightningChannel) diskCommitToMemCommit(isLocal bool, diskCommit.CommitHeight, chainfee.SatPerKWeight(diskCommit.FeePerKw), diskCommit.Htlcs, localCommitKeys, remoteCommitKeys, - isLocal, + whoseCommit, ) if err != nil { return nil, err @@ -745,7 +751,7 @@ func (lc *LightningChannel) diskCommitToMemCommit(isLocal bool, // commitment state as it was originally present in memory. commit := &commitment{ height: diskCommit.CommitHeight, - isOurs: isLocal, + whoseCommit: whoseCommit, ourBalance: diskCommit.LocalBalance, theirBalance: diskCommit.RemoteBalance, ourMessageIndex: diskCommit.LocalLogIndex, @@ -759,7 +765,7 @@ func (lc *LightningChannel) diskCommitToMemCommit(isLocal bool, incomingHTLCs: incomingHtlcs, outgoingHTLCs: outgoingHtlcs, } - if isLocal { + if whoseCommit.IsLocal() { commit.dustLimit = lc.channelState.LocalChanCfg.DustLimit } else { commit.dustLimit = lc.channelState.RemoteChanCfg.DustLimit @@ -1102,12 +1108,12 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, copy(pd.OnionBlob[:], wireMsg.OnionBlob[:]) isDustRemote := HtlcIsDust( - lc.channelState.ChanType, false, false, feeRate, - wireMsg.Amount.ToSatoshis(), remoteDustLimit, + lc.channelState.ChanType, false, lntypes.Remote, + feeRate, wireMsg.Amount.ToSatoshis(), remoteDustLimit, ) if !isDustRemote { scriptInfo, err := genHtlcScript( - lc.channelState.ChanType, false, false, + lc.channelState.ChanType, false, lntypes.Remote, wireMsg.Expiry, wireMsg.PaymentHash, remoteCommitKeys, ) @@ -1400,7 +1406,7 @@ func (lc *LightningChannel) restoreCommitState( // commitment into our in-memory commitment format, inserting it into // the local commitment chain. localCommit, err := lc.diskCommitToMemCommit( - true, localCommitState, localCommitPoint, + lntypes.Local, localCommitState, localCommitPoint, remoteCommitPoint, ) if err != nil { @@ -1413,7 +1419,7 @@ func (lc *LightningChannel) restoreCommitState( // We'll also do the same for the remote commitment chain. remoteCommit, err := lc.diskCommitToMemCommit( - false, remoteCommitState, localCommitPoint, + lntypes.Remote, remoteCommitState, localCommitPoint, remoteCommitPoint, ) if err != nil { @@ -1445,7 +1451,7 @@ func (lc *LightningChannel) restoreCommitState( // corresponding state for the local commitment chain. pendingCommitPoint := lc.channelState.RemoteNextRevocation pendingRemoteCommit, err = lc.diskCommitToMemCommit( - false, &pendingRemoteCommitDiff.Commitment, + lntypes.Remote, &pendingRemoteCommitDiff.Commitment, nil, pendingCommitPoint, ) if err != nil { @@ -1459,8 +1465,10 @@ func (lc *LightningChannel) restoreCommitState( // We'll also re-create the set of commitment keys needed to // fully re-derive the state. pendingRemoteKeyChain = DeriveCommitmentKeys( - pendingCommitPoint, false, lc.channelState.ChanType, - &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, + pendingCommitPoint, lntypes.Remote, + lc.channelState.ChanType, + &lc.channelState.LocalChanCfg, + &lc.channelState.RemoteChanCfg, ) } @@ -1971,7 +1979,7 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, // With the commitment point generated, we can now generate the four // keys we'll need to reconstruct the commitment state, keyRing := DeriveCommitmentKeys( - commitmentPoint, false, chanState.ChanType, + commitmentPoint, lntypes.Remote, chanState.ChanType, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) @@ -2174,7 +2182,7 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // then from the PoV of the remote commitment state, they're the // receiver of this HTLC. scriptInfo, err := genHtlcScript( - chanState.ChanType, htlc.Incoming, false, + chanState.ChanType, htlc.Incoming, lntypes.Remote, htlc.RefundTimeout, htlc.RHash, keyRing, ) if err != nil { @@ -2377,7 +2385,7 @@ func createBreachRetributionLegacy(revokedLog *channeldb.ChannelCommitment, // If the HTLC is dust, then we'll skip it as it doesn't have // an output on the commitment transaction. if HtlcIsDust( - chanState.ChanType, htlc.Incoming, false, + chanState.ChanType, htlc.Incoming, lntypes.Remote, chainfee.SatPerKWeight(revokedLog.FeePerKw), htlc.Amt.ToSatoshis(), chanState.RemoteChanCfg.DustLimit, @@ -2424,8 +2432,9 @@ func createBreachRetributionLegacy(revokedLog *channeldb.ChannelCommitment, // covenants. Depending on the two bits, we'll either be using a timeout or // success transaction which have different weights. func HtlcIsDust(chanType channeldb.ChannelType, - incoming, ourCommit bool, feePerKw chainfee.SatPerKWeight, - htlcAmt, dustLimit btcutil.Amount) bool { + incoming bool, whoseCommit lntypes.ChannelParty, + feePerKw chainfee.SatPerKWeight, htlcAmt, dustLimit btcutil.Amount, +) bool { // First we'll determine the fee required for this HTLC based on if this is // an incoming HTLC or not, and also on whose commitment transaction it @@ -2435,25 +2444,25 @@ func HtlcIsDust(chanType channeldb.ChannelType, // If this is an incoming HTLC on our commitment transaction, then the // second-level transaction will be a success transaction. - case incoming && ourCommit: + case incoming && whoseCommit.IsLocal(): htlcFee = HtlcSuccessFee(chanType, feePerKw) // If this is an incoming HTLC on their commitment transaction, then // we'll be using a second-level timeout transaction as they've added // this HTLC. - case incoming && !ourCommit: + case incoming && whoseCommit.IsRemote(): htlcFee = HtlcTimeoutFee(chanType, feePerKw) // If this is an outgoing HTLC on our commitment transaction, then // we'll be using a timeout transaction as we're the sender of the // HTLC. - case !incoming && ourCommit: + case !incoming && whoseCommit.IsLocal(): htlcFee = HtlcTimeoutFee(chanType, feePerKw) // If this is an outgoing HTLC on their commitment transaction, then // we'll be using an HTLC success transaction as they're the receiver // of this HTLC. - case !incoming && !ourCommit: + case !incoming && whoseCommit.IsRemote(): htlcFee = HtlcSuccessFee(chanType, feePerKw) } @@ -2508,13 +2517,14 @@ func (lc *LightningChannel) fetchHTLCView(theirLogIndex, ourLogIndex uint64) *ht // both local and remote commitment transactions in order to sign or verify new // commitment updates. A fully populated commitment is returned which reflects // the proper balances for both sides at this point in the commitment chain. -func (lc *LightningChannel) fetchCommitmentView(remoteChain bool, +func (lc *LightningChannel) fetchCommitmentView( + whoseCommitChain lntypes.ChannelParty, ourLogIndex, ourHtlcIndex, theirLogIndex, theirHtlcIndex uint64, keyRing *CommitmentKeyRing) (*commitment, error) { commitChain := lc.localCommitChain dustLimit := lc.channelState.LocalChanCfg.DustLimit - if remoteChain { + if whoseCommitChain.IsRemote() { commitChain = lc.remoteCommitChain dustLimit = lc.channelState.RemoteChanCfg.DustLimit } @@ -2528,7 +2538,8 @@ func (lc *LightningChannel) fetchCommitmentView(remoteChain bool, // initiator. htlcView := lc.fetchHTLCView(theirLogIndex, ourLogIndex) ourBalance, theirBalance, _, filteredHTLCView, err := lc.computeView( - htlcView, remoteChain, true, fn.None[chainfee.SatPerKWeight](), + htlcView, whoseCommitChain, true, + fn.None[chainfee.SatPerKWeight](), ) if err != nil { return nil, err @@ -2537,8 +2548,8 @@ func (lc *LightningChannel) fetchCommitmentView(remoteChain bool, // Actually generate unsigned commitment transaction for this view. commitTx, err := lc.commitBuilder.createUnsignedCommitmentTx( - ourBalance, theirBalance, !remoteChain, feePerKw, nextHeight, - filteredHTLCView, keyRing, + ourBalance, theirBalance, whoseCommitChain, feePerKw, + nextHeight, filteredHTLCView, keyRing, ) if err != nil { return nil, err @@ -2587,7 +2598,7 @@ func (lc *LightningChannel) fetchCommitmentView(remoteChain bool, height: nextHeight, feePerKw: feePerKw, dustLimit: dustLimit, - isOurs: !remoteChain, + whoseCommit: whoseCommitChain, } // In order to ensure _none_ of the HTLC's associated with this new @@ -2635,7 +2646,8 @@ func fundingTxIn(chanState *channeldb.OpenChannel) wire.TxIn { // method. func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, theirBalance *lnwire.MilliSatoshi, nextHeight uint64, - remoteChain, mutateState bool) (*htlcView, error) { + whoseCommitChain lntypes.ChannelParty, mutateState bool, +) (*htlcView, error) { // We initialize the view's fee rate to the fee rate of the unfiltered // view. If any fee updates are found when evaluating the view, it will @@ -2663,8 +2675,8 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, // Process fee updates, updating the current feePerKw. case FeeUpdate: processFeeUpdate( - entry, nextHeight, remoteChain, mutateState, - newView, + entry, nextHeight, whoseCommitChain, + mutateState, newView, ) continue } @@ -2672,19 +2684,22 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, // If we're settling an inbound HTLC, and it hasn't been // processed yet, then increment our state tracking the total // number of satoshis we've received within the channel. - if mutateState && entry.EntryType == Settle && !remoteChain && + if mutateState && entry.EntryType == Settle && + whoseCommitChain.IsLocal() && entry.removeCommitHeightLocal == 0 { lc.channelState.TotalMSatReceived += entry.Amount } - addEntry, err := lc.fetchParent(entry, remoteChain, true) + addEntry, err := lc.fetchParent( + entry, whoseCommitChain, lntypes.Remote, + ) if err != nil { return nil, err } skipThem[addEntry.HtlcIndex] = struct{}{} processRemoveEntry(entry, ourBalance, theirBalance, - nextHeight, remoteChain, true, mutateState) + nextHeight, whoseCommitChain, true, mutateState) } for _, entry := range view.theirUpdates { switch entry.EntryType { @@ -2695,8 +2710,8 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, // Process fee updates, updating the current feePerKw. case FeeUpdate: processFeeUpdate( - entry, nextHeight, remoteChain, mutateState, - newView, + entry, nextHeight, whoseCommitChain, + mutateState, newView, ) continue } @@ -2705,19 +2720,23 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, // and it hasn't been processed, yet, the increment our state // tracking the total number of satoshis we've sent within the // channel. - if mutateState && entry.EntryType == Settle && !remoteChain && + if mutateState && entry.EntryType == Settle && + whoseCommitChain.IsLocal() && entry.removeCommitHeightLocal == 0 { + lc.channelState.TotalMSatSent += entry.Amount } - addEntry, err := lc.fetchParent(entry, remoteChain, false) + addEntry, err := lc.fetchParent( + entry, whoseCommitChain, lntypes.Local, + ) if err != nil { return nil, err } skipUs[addEntry.HtlcIndex] = struct{}{} processRemoveEntry(entry, ourBalance, theirBalance, - nextHeight, remoteChain, false, mutateState) + nextHeight, whoseCommitChain, false, mutateState) } // Next we take a second pass through all the log entries, skipping any @@ -2730,7 +2749,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, } processAddEntry(entry, ourBalance, theirBalance, nextHeight, - remoteChain, false, mutateState) + whoseCommitChain, false, mutateState) newView.ourUpdates = append(newView.ourUpdates, entry) } for _, entry := range view.theirUpdates { @@ -2740,7 +2759,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, } processAddEntry(entry, ourBalance, theirBalance, nextHeight, - remoteChain, true, mutateState) + whoseCommitChain, true, mutateState) newView.theirUpdates = append(newView.theirUpdates, entry) } @@ -2750,14 +2769,15 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, // fetchParent is a helper that looks up update log parent entries in the // appropriate log. func (lc *LightningChannel) fetchParent(entry *PaymentDescriptor, - remoteChain, remoteLog bool) (*PaymentDescriptor, error) { + whoseCommitChain, whoseUpdateLog lntypes.ChannelParty, +) (*PaymentDescriptor, error) { var ( updateLog *updateLog logName string ) - if remoteLog { + if whoseUpdateLog.IsRemote() { updateLog = lc.remoteUpdateLog logName = "remote" } else { @@ -2781,11 +2801,16 @@ func (lc *LightningChannel) fetchParent(entry *PaymentDescriptor, // The parent add height should never be zero at this point. If // that's the case we probably forgot to send a new commitment. - case remoteChain && addEntry.addCommitHeightRemote == 0: + case whoseCommitChain.IsRemote() && + addEntry.addCommitHeightRemote == 0: + return nil, fmt.Errorf("parent entry %d for update %d "+ "had zero remote add height", entry.ParentIndex, entry.LogIndex) - case !remoteChain && addEntry.addCommitHeightLocal == 0: + + case whoseCommitChain.IsLocal() && + addEntry.addCommitHeightLocal == 0: + return nil, fmt.Errorf("parent entry %d for update %d "+ "had zero local add height", entry.ParentIndex, entry.LogIndex) @@ -2798,15 +2823,16 @@ func (lc *LightningChannel) fetchParent(entry *PaymentDescriptor, // If the HTLC hasn't yet been committed in either chain, then the height it // was committed is updated. Keeping track of this inclusion height allows us to // later compact the log once the change is fully committed in both chains. -func processAddEntry(htlc *PaymentDescriptor, ourBalance, theirBalance *lnwire.MilliSatoshi, - nextHeight uint64, remoteChain bool, isIncoming, mutateState bool) { +func processAddEntry(htlc *PaymentDescriptor, ourBalance, + theirBalance *lnwire.MilliSatoshi, nextHeight uint64, + whoseCommitChain lntypes.ChannelParty, isIncoming, mutateState bool) { // If we're evaluating this entry for the remote chain (to create/view // a new commitment), then we'll may be updating the height this entry // was added to the chain. Otherwise, we may be updating the entry's // height w.r.t the local chain. var addHeight *uint64 - if remoteChain { + if whoseCommitChain.IsRemote() { addHeight = &htlc.addCommitHeightRemote } else { addHeight = &htlc.addCommitHeightLocal @@ -2837,10 +2863,10 @@ func processAddEntry(htlc *PaymentDescriptor, ourBalance, theirBalance *lnwire.M // is skipped. func processRemoveEntry(htlc *PaymentDescriptor, ourBalance, theirBalance *lnwire.MilliSatoshi, nextHeight uint64, - remoteChain bool, isIncoming, mutateState bool) { + whoseCommitChain lntypes.ChannelParty, isIncoming, mutateState bool) { var removeHeight *uint64 - if remoteChain { + if whoseCommitChain.IsRemote() { removeHeight = &htlc.removeCommitHeightRemote } else { removeHeight = &htlc.removeCommitHeightLocal @@ -2885,14 +2911,15 @@ func processRemoveEntry(htlc *PaymentDescriptor, ourBalance, // processFeeUpdate processes a log update that updates the current commitment // fee. func processFeeUpdate(feeUpdate *PaymentDescriptor, nextHeight uint64, - remoteChain bool, mutateState bool, view *htlcView) { + whoseCommitChain lntypes.ChannelParty, mutateState bool, view *htlcView, +) { // Fee updates are applied for all commitments after they are // sent/received, so we consider them being added and removed at the // same height. var addHeight *uint64 var removeHeight *uint64 - if remoteChain { + if whoseCommitChain.IsRemote() { addHeight = &feeUpdate.addCommitHeightRemote removeHeight = &feeUpdate.removeCommitHeightRemote } else { @@ -2945,7 +2972,7 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, // sigJob will be generated and appended to the current batch. for _, htlc := range remoteCommitView.incomingHTLCs { if HtlcIsDust( - chanType, true, false, feePerKw, + chanType, true, lntypes.Remote, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, ) { @@ -3014,7 +3041,7 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, } for _, htlc := range remoteCommitView.outgoingHTLCs { if HtlcIsDust( - chanType, false, false, feePerKw, + chanType, false, lntypes.Remote, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, ) { @@ -3212,7 +3239,7 @@ func (lc *LightningChannel) createCommitDiff( // With the set of log updates mapped into wire messages, we'll now // convert the in-memory commit into a format suitable for writing to // disk. - diskCommit := newCommit.toDiskCommit(false) + diskCommit := newCommit.toDiskCommit(lntypes.Remote) return &channeldb.CommitDiff{ Commitment: *diskCommit, @@ -3463,12 +3490,13 @@ func (lc *LightningChannel) applyCommitFee( // PaymentDescriptor if we are validating in the state when adding a new HTLC, // or nil otherwise. func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, - ourLogCounter uint64, remoteChain bool, buffer BufferType, - predictOurAdd, predictTheirAdd *PaymentDescriptor) error { + ourLogCounter uint64, whoseCommitChain lntypes.ChannelParty, + buffer BufferType, predictOurAdd, predictTheirAdd *PaymentDescriptor, +) error { // First fetch the initial balance before applying any updates. commitChain := lc.localCommitChain - if remoteChain { + if whoseCommitChain.IsRemote() { commitChain = lc.remoteCommitChain } ourInitialBalance := commitChain.tip().ourBalance @@ -3488,7 +3516,8 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, } ourBalance, theirBalance, commitWeight, filteredView, err := lc.computeView( - view, remoteChain, false, fn.None[chainfee.SatPerKWeight](), + view, whoseCommitChain, false, + fn.None[chainfee.SatPerKWeight](), ) if err != nil { return err @@ -3703,7 +3732,7 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) { // dare to fail hard here. We assume peers can deal with the empty sig // and continue channel operation. We log an error so that the bug // causing this can be tracked down. - if !lc.oweCommitment(true) { + if !lc.oweCommitment(lntypes.Local) { lc.log.Errorf("sending empty commit sig") } @@ -3737,8 +3766,8 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) { // point all updates will have to get locked-in so we enforce the // minimum requirement. err := lc.validateCommitmentSanity( - remoteACKedIndex, lc.localUpdateLog.logIndex, true, NoBuffer, - nil, nil, + remoteACKedIndex, lc.localUpdateLog.logIndex, lntypes.Remote, + NoBuffer, nil, nil, ) if err != nil { return nil, err @@ -3748,7 +3777,7 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) { // used within fetchCommitmentView to derive all the keys necessary to // construct the commitment state. keyRing := DeriveCommitmentKeys( - commitPoint, false, lc.channelState.ChanType, + commitPoint, lntypes.Remote, lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, ) @@ -3760,8 +3789,9 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) { // _all_ of our changes (pending or committed) but only the remote // node's changes up to the last change we've ACK'd. newCommitView, err := lc.fetchCommitmentView( - true, lc.localUpdateLog.logIndex, lc.localUpdateLog.htlcCounter, - remoteACKedIndex, remoteHtlcIndex, keyRing, + lntypes.Remote, lc.localUpdateLog.logIndex, + lc.localUpdateLog.htlcCounter, remoteACKedIndex, + remoteHtlcIndex, keyRing, ) if err != nil { return nil, err @@ -4255,14 +4285,14 @@ func (lc *LightningChannel) ProcessChanSyncMsg( // // If the updateState boolean is set true, the add and remove heights of the // HTLCs will be set to the next commitment height. -func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, - updateState bool, dryRunFee fn.Option[chainfee.SatPerKWeight]) ( - lnwire.MilliSatoshi, lnwire.MilliSatoshi, lntypes.WeightUnit, - *htlcView, error) { +func (lc *LightningChannel) computeView(view *htlcView, + whoseCommitChain lntypes.ChannelParty, updateState bool, + dryRunFee fn.Option[chainfee.SatPerKWeight]) (lnwire.MilliSatoshi, + lnwire.MilliSatoshi, lntypes.WeightUnit, *htlcView, error) { commitChain := lc.localCommitChain dustLimit := lc.channelState.LocalChanCfg.DustLimit - if remoteChain { + if whoseCommitChain.IsRemote() { commitChain = lc.remoteCommitChain dustLimit = lc.channelState.RemoteChanCfg.DustLimit } @@ -4298,7 +4328,7 @@ func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, // updates are found in the logs, the commitment fee rate should be // changed, so we'll also set the feePerKw to this new value. filteredHTLCView, err := lc.evaluateHTLCView(view, &ourBalance, - &theirBalance, nextHeight, remoteChain, updateState) + &theirBalance, nextHeight, whoseCommitChain, updateState) if err != nil { return 0, 0, 0, nil, err } @@ -4328,7 +4358,7 @@ func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, var totalHtlcWeight lntypes.WeightUnit for _, htlc := range filteredHTLCView.ourUpdates { if HtlcIsDust( - lc.channelState.ChanType, false, !remoteChain, + lc.channelState.ChanType, false, whoseCommitChain, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, ) { @@ -4339,7 +4369,7 @@ func (lc *LightningChannel) computeView(view *htlcView, remoteChain bool, } for _, htlc := range filteredHTLCView.theirUpdates { if HtlcIsDust( - lc.channelState.ChanType, true, !remoteChain, + lc.channelState.ChanType, true, whoseCommitChain, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, ) { @@ -4681,7 +4711,7 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error { // reliable, because it could be that we've sent out a new sig, but the // remote hasn't received it yet. We could then falsely assume that they // should add our updates to their remote commitment tx. - if !lc.oweCommitment(false) { + if !lc.oweCommitment(lntypes.Remote) { lc.log.Warnf("empty commit sig message received") } @@ -4698,8 +4728,8 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error { // the UpdateAddHTLC msg from our peer prior to receiving the // commit-sig). err := lc.validateCommitmentSanity( - lc.remoteUpdateLog.logIndex, localACKedIndex, false, NoBuffer, - nil, nil, + lc.remoteUpdateLog.logIndex, localACKedIndex, lntypes.Local, + NoBuffer, nil, nil, ) if err != nil { return err @@ -4716,7 +4746,7 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error { } commitPoint := input.ComputeCommitmentPoint(commitSecret[:]) keyRing := DeriveCommitmentKeys( - commitPoint, true, lc.channelState.ChanType, + commitPoint, lntypes.Local, lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, ) @@ -4725,7 +4755,7 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error { // we know of in the remote node's HTLC log, but only our local changes // up to the last change the remote node has ACK'd. localCommitmentView, err := lc.fetchCommitmentView( - false, localACKedIndex, localHtlcIndex, + lntypes.Local, localACKedIndex, localHtlcIndex, lc.remoteUpdateLog.logIndex, lc.remoteUpdateLog.htlcCounter, keyRing, ) @@ -4962,11 +4992,11 @@ func (lc *LightningChannel) IsChannelClean() bool { // Now check that both local and remote commitments are signing the // same updates. - if lc.oweCommitment(true) { + if lc.oweCommitment(lntypes.Local) { return false } - if lc.oweCommitment(false) { + if lc.oweCommitment(lntypes.Remote) { return false } @@ -4983,7 +5013,7 @@ func (lc *LightningChannel) OweCommitment() bool { lc.RLock() defer lc.RUnlock() - return lc.oweCommitment(true) + return lc.oweCommitment(lntypes.Local) } // NeedCommitment returns a boolean value reflecting whether we are waiting on @@ -4994,12 +5024,12 @@ func (lc *LightningChannel) NeedCommitment() bool { lc.RLock() defer lc.RUnlock() - return lc.oweCommitment(false) + return lc.oweCommitment(lntypes.Remote) } // oweCommitment is the internal version of OweCommitment. This function expects // to be executed with a lock held. -func (lc *LightningChannel) oweCommitment(local bool) bool { +func (lc *LightningChannel) oweCommitment(issuer lntypes.ChannelParty) bool { var ( remoteUpdatesPending, localUpdatesPending bool @@ -5009,7 +5039,7 @@ func (lc *LightningChannel) oweCommitment(local bool) bool { perspective string ) - if local { + if issuer.IsLocal() { perspective = "local" // There are local updates pending if our local update log is @@ -5091,7 +5121,7 @@ func (lc *LightningChannel) RevokeCurrentCommitment() (*lnwire.RevokeAndAck, // Additionally, generate a channel delta for this state transition for // persistent storage. chainTail := lc.localCommitChain.tail() - newCommitment := chainTail.toDiskCommit(true) + newCommitment := chainTail.toDiskCommit(lntypes.Local) // Get the unsigned acked remotes updates that are currently in memory. // We need them after a restart to sync our remote commitment with what @@ -5501,7 +5531,7 @@ func (lc *LightningChannel) addHTLC(htlc *lnwire.UpdateAddHTLC, // commitment tx. // // NOTE: This over-estimates the dust exposure. -func (lc *LightningChannel) GetDustSum(remote bool, +func (lc *LightningChannel) GetDustSum(whoseCommit lntypes.ChannelParty, dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi { lc.RLock() @@ -5511,7 +5541,7 @@ func (lc *LightningChannel) GetDustSum(remote bool, dustLimit := lc.channelState.LocalChanCfg.DustLimit commit := lc.channelState.LocalCommitment - if remote { + if whoseCommit.IsRemote() { // Calculate dust sum on the remote's commitment. dustLimit = lc.channelState.RemoteChanCfg.DustLimit commit = lc.channelState.RemoteCommitment @@ -5535,7 +5565,7 @@ func (lc *LightningChannel) GetDustSum(remote bool, // If the satoshi amount is under the dust limit, add the msat // amount to the dust sum. if HtlcIsDust( - chanType, false, !remote, feeRate, amt, dustLimit, + chanType, false, whoseCommit, feeRate, amt, dustLimit, ) { dustSum += pd.Amount @@ -5554,7 +5584,8 @@ func (lc *LightningChannel) GetDustSum(remote bool, // If the satoshi amount is under the dust limit, add the msat // amount to the dust sum. if HtlcIsDust( - chanType, true, !remote, feeRate, amt, dustLimit, + chanType, true, whoseCommit, feeRate, + amt, dustLimit, ) { dustSum += pd.Amount @@ -5641,7 +5672,7 @@ func (lc *LightningChannel) validateAddHtlc(pd *PaymentDescriptor, // First we'll check whether this HTLC can be added to the remote // commitment transaction without violation any of the constraints. err := lc.validateCommitmentSanity( - remoteACKedIndex, lc.localUpdateLog.logIndex, true, + remoteACKedIndex, lc.localUpdateLog.logIndex, lntypes.Remote, buffer, pd, nil, ) if err != nil { @@ -5655,7 +5686,7 @@ func (lc *LightningChannel) validateAddHtlc(pd *PaymentDescriptor, // possible for us to add the HTLC. err = lc.validateCommitmentSanity( lc.remoteUpdateLog.logIndex, lc.localUpdateLog.logIndex, - false, buffer, pd, nil, + lntypes.Local, buffer, pd, nil, ) if err != nil { return err @@ -5696,8 +5727,8 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.UpdateAddHTLC) (uint64, err // we use it here. The current lightning protocol does not allow to // reject ADDs already sent by the peer. err := lc.validateCommitmentSanity( - lc.remoteUpdateLog.logIndex, localACKedIndex, false, NoBuffer, - nil, pd, + lc.remoteUpdateLog.logIndex, localACKedIndex, lntypes.Local, + NoBuffer, nil, pd, ) if err != nil { return 0, err @@ -6195,9 +6226,9 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Si // First, we'll generate the commitment point and the revocation point // so we can re-construct the HTLC state and also our payment key. - isOurCommit := false + commitType := lntypes.Remote keyRing := DeriveCommitmentKeys( - commitPoint, isOurCommit, chanState.ChanType, + commitPoint, commitType, chanState.ChanType, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) @@ -6209,7 +6240,7 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Si } isRemoteInitiator := !chanState.IsInitiator htlcResolutions, err := extractHtlcResolutions( - chainfee.SatPerKWeight(remoteCommit.FeePerKw), isOurCommit, + chainfee.SatPerKWeight(remoteCommit.FeePerKw), commitType, signer, remoteCommit.Htlcs, keyRing, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, commitSpend.SpendingTx, chanState.ChanType, isRemoteInitiator, leaseExpiry, @@ -6328,7 +6359,7 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Si } anchorResolution, err := NewAnchorResolution( - chanState, commitTxBroadcast, keyRing, false, + chanState, commitTxBroadcast, keyRing, lntypes.Remote, ) if err != nil { return nil, err @@ -6465,7 +6496,7 @@ func newOutgoingHtlcResolution(signer input.Signer, localChanCfg *channeldb.ChannelConfig, commitTx *wire.MsgTx, htlc *channeldb.HTLC, keyRing *CommitmentKeyRing, feePerKw chainfee.SatPerKWeight, csvDelay, leaseExpiry uint32, - localCommit, isCommitFromInitiator bool, + whoseCommit lntypes.ChannelParty, isCommitFromInitiator bool, chanType channeldb.ChannelType) (*OutgoingHtlcResolution, error) { op := wire.OutPoint{ @@ -6476,7 +6507,7 @@ func newOutgoingHtlcResolution(signer input.Signer, // First, we'll re-generate the script used to send the HTLC to the // remote party within their commitment transaction. htlcScriptInfo, err := genHtlcScript( - chanType, false, localCommit, htlc.RefundTimeout, htlc.RHash, + chanType, false, whoseCommit, htlc.RefundTimeout, htlc.RHash, keyRing, ) if err != nil { @@ -6497,7 +6528,7 @@ func newOutgoingHtlcResolution(signer input.Signer, // If we're spending this HTLC output from the remote node's // commitment, then we won't need to go to the second level as our // outputs don't have a CSV delay. - if !localCommit { + if whoseCommit.IsRemote() { // With the script generated, we can completely populated the // SignDescriptor needed to sweep the output. prevFetcher := txscript.NewCannedPrevOutputFetcher( @@ -6717,7 +6748,8 @@ func newIncomingHtlcResolution(signer input.Signer, localChanCfg *channeldb.ChannelConfig, commitTx *wire.MsgTx, htlc *channeldb.HTLC, keyRing *CommitmentKeyRing, feePerKw chainfee.SatPerKWeight, csvDelay, leaseExpiry uint32, - localCommit, isCommitFromInitiator bool, chanType channeldb.ChannelType) ( + whoseCommit lntypes.ChannelParty, isCommitFromInitiator bool, + chanType channeldb.ChannelType) ( *IncomingHtlcResolution, error) { op := wire.OutPoint{ @@ -6728,7 +6760,7 @@ func newIncomingHtlcResolution(signer input.Signer, // First, we'll re-generate the script the remote party used to // send the HTLC to us in their commitment transaction. scriptInfo, err := genHtlcScript( - chanType, true, localCommit, htlc.RefundTimeout, htlc.RHash, + chanType, true, whoseCommit, htlc.RefundTimeout, htlc.RHash, keyRing, ) if err != nil { @@ -6749,7 +6781,7 @@ func newIncomingHtlcResolution(signer input.Signer, // If we're spending this output from the remote node's commitment, // then we can skip the second layer and spend the output directly. - if !localCommit { + if whoseCommit.IsRemote() { // With the script generated, we can completely populated the // SignDescriptor needed to sweep the output. prevFetcher := txscript.NewCannedPrevOutputFetcher( @@ -6976,8 +7008,9 @@ func (r *OutgoingHtlcResolution) HtlcPoint() wire.OutPoint { // extractHtlcResolutions creates a series of outgoing HTLC resolutions, and // the local key used when generating the HTLC scrips. This function is to be // used in two cases: force close, or a unilateral close. -func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ourCommit bool, - signer input.Signer, htlcs []channeldb.HTLC, keyRing *CommitmentKeyRing, +func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, + whoseCommit lntypes.ChannelParty, signer input.Signer, + htlcs []channeldb.HTLC, keyRing *CommitmentKeyRing, localChanCfg, remoteChanCfg *channeldb.ChannelConfig, commitTx *wire.MsgTx, chanType channeldb.ChannelType, isCommitFromInitiator bool, leaseExpiry uint32) (*HtlcResolutions, error) { @@ -6985,7 +7018,7 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ourCommit bool, // TODO(roasbeef): don't need to swap csv delay? dustLimit := remoteChanCfg.DustLimit csvDelay := remoteChanCfg.CsvDelay - if ourCommit { + if whoseCommit.IsLocal() { dustLimit = localChanCfg.DustLimit csvDelay = localChanCfg.CsvDelay } @@ -6999,7 +7032,7 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ourCommit bool, // transaction, as these don't have a corresponding output // within the commitment transaction. if HtlcIsDust( - chanType, htlc.Incoming, ourCommit, feePerKw, + chanType, htlc.Incoming, whoseCommit, feePerKw, htlc.Amt.ToSatoshis(), dustLimit, ) { @@ -7014,7 +7047,7 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ourCommit bool, ihr, err := newIncomingHtlcResolution( signer, localChanCfg, commitTx, &htlc, keyRing, feePerKw, uint32(csvDelay), leaseExpiry, - ourCommit, isCommitFromInitiator, chanType, + whoseCommit, isCommitFromInitiator, chanType, ) if err != nil { return nil, fmt.Errorf("incoming resolution "+ @@ -7027,7 +7060,7 @@ func extractHtlcResolutions(feePerKw chainfee.SatPerKWeight, ourCommit bool, ohr, err := newOutgoingHtlcResolution( signer, localChanCfg, commitTx, &htlc, keyRing, - feePerKw, uint32(csvDelay), leaseExpiry, ourCommit, + feePerKw, uint32(csvDelay), leaseExpiry, whoseCommit, isCommitFromInitiator, chanType, ) if err != nil { @@ -7163,7 +7196,7 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, } commitPoint := input.ComputeCommitmentPoint(revocation[:]) keyRing := DeriveCommitmentKeys( - commitPoint, true, chanState.ChanType, + commitPoint, lntypes.Local, chanState.ChanType, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) @@ -7261,8 +7294,8 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, // use what we have in our latest state when extracting resolutions. localCommit := chanState.LocalCommitment htlcResolutions, err := extractHtlcResolutions( - chainfee.SatPerKWeight(localCommit.FeePerKw), true, signer, - localCommit.Htlcs, keyRing, &chanState.LocalChanCfg, + chainfee.SatPerKWeight(localCommit.FeePerKw), lntypes.Local, + signer, localCommit.Htlcs, keyRing, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, commitTx, chanState.ChanType, chanState.IsInitiator, leaseExpiry, ) @@ -7271,7 +7304,7 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, } anchorResolution, err := NewAnchorResolution( - chanState, commitTx, keyRing, true, + chanState, commitTx, keyRing, lntypes.Local, ) if err != nil { return nil, fmt.Errorf("unable to gen anchor "+ @@ -7561,12 +7594,12 @@ func (lc *LightningChannel) NewAnchorResolutions() (*AnchorResolutions, } localCommitPoint := input.ComputeCommitmentPoint(revocation[:]) localKeyRing := DeriveCommitmentKeys( - localCommitPoint, true, lc.channelState.ChanType, + localCommitPoint, lntypes.Local, lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, ) localRes, err := NewAnchorResolution( lc.channelState, lc.channelState.LocalCommitment.CommitTx, - localKeyRing, true, + localKeyRing, lntypes.Local, ) if err != nil { return nil, err @@ -7575,13 +7608,13 @@ func (lc *LightningChannel) NewAnchorResolutions() (*AnchorResolutions, // Add anchor for remote commitment tx, if any. remoteKeyRing := DeriveCommitmentKeys( - lc.channelState.RemoteCurrentRevocation, false, + lc.channelState.RemoteCurrentRevocation, lntypes.Remote, lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, ) remoteRes, err := NewAnchorResolution( lc.channelState, lc.channelState.RemoteCommitment.CommitTx, - remoteKeyRing, false, + remoteKeyRing, lntypes.Remote, ) if err != nil { return nil, err @@ -7596,14 +7629,14 @@ func (lc *LightningChannel) NewAnchorResolutions() (*AnchorResolutions, if remotePendingCommit != nil { pendingRemoteKeyRing := DeriveCommitmentKeys( - lc.channelState.RemoteNextRevocation, false, + lc.channelState.RemoteNextRevocation, lntypes.Remote, lc.channelState.ChanType, &lc.channelState.LocalChanCfg, &lc.channelState.RemoteChanCfg, ) remotePendingRes, err := NewAnchorResolution( lc.channelState, remotePendingCommit.Commitment.CommitTx, - pendingRemoteKeyRing, false, + pendingRemoteKeyRing, lntypes.Remote, ) if err != nil { return nil, err @@ -7618,7 +7651,7 @@ func (lc *LightningChannel) NewAnchorResolutions() (*AnchorResolutions, // local anchor. func NewAnchorResolution(chanState *channeldb.OpenChannel, commitTx *wire.MsgTx, keyRing *CommitmentKeyRing, - isLocalCommit bool) (*AnchorResolution, error) { + whoseCommit lntypes.ChannelParty) (*AnchorResolution, error) { // Return nil resolution if the channel has no anchors. if !chanState.ChanType.HasAnchors() { @@ -7636,7 +7669,7 @@ func NewAnchorResolution(chanState *channeldb.OpenChannel, if err != nil { return nil, err } - if chanState.ChanType.IsTaproot() && !isLocalCommit { + if chanState.ChanType.IsTaproot() && whoseCommit.IsRemote() { //nolint:ineffassign localAnchor, remoteAnchor = remoteAnchor, localAnchor } @@ -7690,7 +7723,7 @@ func NewAnchorResolution(chanState *channeldb.OpenChannel, // For anchor outputs with taproot channels, the key desc is // also different: we'll just re-use our local delay base point // (which becomes our to local output). - if isLocalCommit { + if whoseCommit.IsLocal() { // In addition to the sign method, we'll also need to // ensure that the single tweak is set, as with the // current formulation, we'll need to use two levels of @@ -7777,12 +7810,12 @@ func (lc *LightningChannel) availableBalance( // add updates concurrently, causing our balance to go down if we're // the initiator, but this is a problem on the protocol level. ourLocalCommitBalance, commitWeight := lc.availableCommitmentBalance( - htlcView, false, buffer, + htlcView, lntypes.Local, buffer, ) // Do the same calculation from the remote commitment point of view. ourRemoteCommitBalance, _ := lc.availableCommitmentBalance( - htlcView, true, buffer, + htlcView, lntypes.Remote, buffer, ) // Return which ever balance is lowest. @@ -7800,15 +7833,16 @@ func (lc *LightningChannel) availableBalance( // commitment, increasing the commitment fee we must pay as an initiator, // eating into our balance. It will make sure we won't violate the channel // reserve constraints for this amount. -func (lc *LightningChannel) availableCommitmentBalance( - view *htlcView, remoteChain bool, - buffer BufferType) (lnwire.MilliSatoshi, lntypes.WeightUnit) { +func (lc *LightningChannel) availableCommitmentBalance(view *htlcView, + whoseCommitChain lntypes.ChannelParty, buffer BufferType) ( + lnwire.MilliSatoshi, lntypes.WeightUnit) { // Compute the current balances for this commitment. This will take // into account HTLCs to determine the commit weight, which the // initiator must pay the fee for. ourBalance, theirBalance, commitWeight, filteredView, err := lc.computeView( - view, remoteChain, false, fn.None[chainfee.SatPerKWeight](), + view, whoseCommitChain, false, + fn.None[chainfee.SatPerKWeight](), ) if err != nil { lc.log.Errorf("Unable to fetch available balance: %v", err) @@ -7894,7 +7928,7 @@ func (lc *LightningChannel) availableCommitmentBalance( // If we are looking at the remote commitment, we must use the remote // dust limit and the fee for adding an HTLC success transaction. - if remoteChain { + if whoseCommitChain.IsRemote() { dustlimit = lnwire.NewMSatFromSatoshis( lc.channelState.RemoteChanCfg.DustLimit, ) @@ -8031,7 +8065,7 @@ func (lc *LightningChannel) CommitFeeTotalAt( // Compute the local commitment's weight. _, _, localWeight, _, err := lc.computeView( - localHtlcView, false, false, dryRunFee, + localHtlcView, lntypes.Local, false, dryRunFee, ) if err != nil { return 0, 0, err @@ -8045,7 +8079,7 @@ func (lc *LightningChannel) CommitFeeTotalAt( // Compute the remote commitment's weight. _, _, remoteWeight, _, err := lc.computeView( - remoteHtlcView, true, false, dryRunFee, + remoteHtlcView, lntypes.Remote, false, dryRunFee, ) if err != nil { return 0, 0, err @@ -8455,17 +8489,12 @@ func (lc *LightningChannel) MarkBorked() error { // for it to confirm before taking any further action. It takes a boolean which // indicates whether we initiated the close. func (lc *LightningChannel) MarkCommitmentBroadcasted(tx *wire.MsgTx, - locallyInitiated bool) error { + closer lntypes.ChannelParty) error { lc.Lock() defer lc.Unlock() - party := lntypes.Remote - if locallyInitiated { - party = lntypes.Local - } - - return lc.channelState.MarkCommitmentBroadcasted(tx, party) + return lc.channelState.MarkCommitmentBroadcasted(tx, closer) } // MarkCoopBroadcasted marks the channel as a cooperative close transaction has @@ -8473,17 +8502,12 @@ func (lc *LightningChannel) MarkCommitmentBroadcasted(tx *wire.MsgTx, // taking any further action. It takes a locally initiated bool which is true // if we initiated the cooperative close. func (lc *LightningChannel) MarkCoopBroadcasted(tx *wire.MsgTx, - localInitiated bool) error { + closer lntypes.ChannelParty) error { lc.Lock() defer lc.Unlock() - party := lntypes.Remote - if localInitiated { - party = lntypes.Local - } - - return lc.channelState.MarkCoopBroadcasted(tx, party) + return lc.channelState.MarkCoopBroadcasted(tx, closer) } // MarkShutdownSent persists the given ShutdownInfo. The existence of the diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 185bdf87a..330d8d130 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -5196,7 +5196,7 @@ func TestChanCommitWeightDustHtlcs(t *testing.T) { lc.localUpdateLog.logIndex) _, w := lc.availableCommitmentBalance( - htlcView, true, FeeBuffer, + htlcView, lntypes.Remote, FeeBuffer, ) return w @@ -7985,11 +7985,11 @@ func TestChannelFeeRateFloor(t *testing.T) { // TestFetchParent tests lookup of an entry's parent in the appropriate log. func TestFetchParent(t *testing.T) { tests := []struct { - name string - remoteChain bool - remoteLog bool - localEntries []*PaymentDescriptor - remoteEntries []*PaymentDescriptor + name string + whoseCommitChain lntypes.ChannelParty + whoseUpdateLog lntypes.ChannelParty + localEntries []*PaymentDescriptor + remoteEntries []*PaymentDescriptor // parentIndex is the parent index of the entry that we will // lookup with fetch parent. @@ -8003,22 +8003,22 @@ func TestFetchParent(t *testing.T) { expectedIndex uint64 }{ { - name: "not found in remote log", - localEntries: nil, - remoteEntries: nil, - remoteChain: true, - remoteLog: true, - parentIndex: 0, - expectErr: true, + name: "not found in remote log", + localEntries: nil, + remoteEntries: nil, + whoseCommitChain: lntypes.Remote, + whoseUpdateLog: lntypes.Remote, + parentIndex: 0, + expectErr: true, }, { - name: "not found in local log", - localEntries: nil, - remoteEntries: nil, - remoteChain: false, - remoteLog: false, - parentIndex: 0, - expectErr: true, + name: "not found in local log", + localEntries: nil, + remoteEntries: nil, + whoseCommitChain: lntypes.Local, + whoseUpdateLog: lntypes.Local, + parentIndex: 0, + expectErr: true, }, { name: "remote log + chain, remote add height 0", @@ -8038,10 +8038,10 @@ func TestFetchParent(t *testing.T) { addCommitHeightRemote: 0, }, }, - remoteChain: true, - remoteLog: true, - parentIndex: 1, - expectErr: true, + whoseCommitChain: lntypes.Remote, + whoseUpdateLog: lntypes.Remote, + parentIndex: 1, + expectErr: true, }, { name: "remote log, local chain, local add height 0", @@ -8060,11 +8060,11 @@ func TestFetchParent(t *testing.T) { addCommitHeightRemote: 100, }, }, - localEntries: nil, - remoteChain: false, - remoteLog: true, - parentIndex: 1, - expectErr: true, + localEntries: nil, + whoseCommitChain: lntypes.Local, + whoseUpdateLog: lntypes.Remote, + parentIndex: 1, + expectErr: true, }, { name: "local log + chain, local add height 0", @@ -8083,11 +8083,11 @@ func TestFetchParent(t *testing.T) { addCommitHeightRemote: 100, }, }, - remoteEntries: nil, - remoteChain: false, - remoteLog: false, - parentIndex: 1, - expectErr: true, + remoteEntries: nil, + whoseCommitChain: lntypes.Local, + whoseUpdateLog: lntypes.Local, + parentIndex: 1, + expectErr: true, }, { @@ -8107,11 +8107,11 @@ func TestFetchParent(t *testing.T) { addCommitHeightRemote: 0, }, }, - remoteEntries: nil, - remoteChain: true, - remoteLog: false, - parentIndex: 1, - expectErr: true, + remoteEntries: nil, + whoseCommitChain: lntypes.Remote, + whoseUpdateLog: lntypes.Local, + parentIndex: 1, + expectErr: true, }, { name: "remote log found", @@ -8131,11 +8131,11 @@ func TestFetchParent(t *testing.T) { addCommitHeightRemote: 100, }, }, - remoteChain: true, - remoteLog: true, - parentIndex: 1, - expectErr: false, - expectedIndex: 2, + whoseCommitChain: lntypes.Remote, + whoseUpdateLog: lntypes.Remote, + parentIndex: 1, + expectErr: false, + expectedIndex: 2, }, { name: "local log found", @@ -8154,12 +8154,12 @@ func TestFetchParent(t *testing.T) { addCommitHeightRemote: 100, }, }, - remoteEntries: nil, - remoteChain: false, - remoteLog: false, - parentIndex: 1, - expectErr: false, - expectedIndex: 2, + remoteEntries: nil, + whoseCommitChain: lntypes.Local, + whoseUpdateLog: lntypes.Local, + parentIndex: 1, + expectErr: false, + expectedIndex: 2, }, } @@ -8186,8 +8186,8 @@ func TestFetchParent(t *testing.T) { &PaymentDescriptor{ ParentIndex: test.parentIndex, }, - test.remoteChain, - test.remoteLog, + test.whoseCommitChain, + test.whoseUpdateLog, ) gotErr := err != nil if test.expectErr != gotErr { @@ -8245,11 +8245,11 @@ func TestEvaluateView(t *testing.T) { ) tests := []struct { - name string - ourHtlcs []*PaymentDescriptor - theirHtlcs []*PaymentDescriptor - remoteChain bool - mutateState bool + name string + ourHtlcs []*PaymentDescriptor + theirHtlcs []*PaymentDescriptor + whoseCommitChain lntypes.ChannelParty + mutateState bool // ourExpectedHtlcs is the set of our htlcs that we expect in // the htlc view once it has been evaluated. We just store @@ -8276,9 +8276,9 @@ func TestEvaluateView(t *testing.T) { expectSent lnwire.MilliSatoshi }{ { - name: "our fee update is applied", - remoteChain: false, - mutateState: false, + name: "our fee update is applied", + whoseCommitChain: lntypes.Local, + mutateState: false, ourHtlcs: []*PaymentDescriptor{ { Amount: ourFeeUpdateAmt, @@ -8293,10 +8293,10 @@ func TestEvaluateView(t *testing.T) { expectSent: 0, }, { - name: "their fee update is applied", - remoteChain: false, - mutateState: false, - ourHtlcs: []*PaymentDescriptor{}, + name: "their fee update is applied", + whoseCommitChain: lntypes.Local, + mutateState: false, + ourHtlcs: []*PaymentDescriptor{}, theirHtlcs: []*PaymentDescriptor{ { Amount: theirFeeUpdateAmt, @@ -8311,9 +8311,9 @@ func TestEvaluateView(t *testing.T) { }, { // We expect unresolved htlcs to to remain in the view. - name: "htlcs adds without settles", - remoteChain: false, - mutateState: false, + name: "htlcs adds without settles", + whoseCommitChain: lntypes.Local, + mutateState: false, ourHtlcs: []*PaymentDescriptor{ { HtlcIndex: 0, @@ -8345,9 +8345,9 @@ func TestEvaluateView(t *testing.T) { expectSent: 0, }, { - name: "our htlc settled, state mutated", - remoteChain: false, - mutateState: true, + name: "our htlc settled, state mutated", + whoseCommitChain: lntypes.Local, + mutateState: true, ourHtlcs: []*PaymentDescriptor{ { HtlcIndex: 0, @@ -8380,9 +8380,9 @@ func TestEvaluateView(t *testing.T) { expectSent: htlcAddAmount, }, { - name: "our htlc settled, state not mutated", - remoteChain: false, - mutateState: false, + name: "our htlc settled, state not mutated", + whoseCommitChain: lntypes.Local, + mutateState: false, ourHtlcs: []*PaymentDescriptor{ { HtlcIndex: 0, @@ -8415,9 +8415,9 @@ func TestEvaluateView(t *testing.T) { expectSent: 0, }, { - name: "their htlc settled, state mutated", - remoteChain: false, - mutateState: true, + name: "their htlc settled, state mutated", + whoseCommitChain: lntypes.Local, + mutateState: true, ourHtlcs: []*PaymentDescriptor{ { HtlcIndex: 0, @@ -8458,9 +8458,10 @@ func TestEvaluateView(t *testing.T) { expectSent: 0, }, { - name: "their htlc settled, state not mutated", - remoteChain: false, - mutateState: false, + name: "their htlc settled, state not mutated", + + whoseCommitChain: lntypes.Local, + mutateState: false, ourHtlcs: []*PaymentDescriptor{ { HtlcIndex: 0, @@ -8543,7 +8544,7 @@ func TestEvaluateView(t *testing.T) { // Evaluate the htlc view, mutate as test expects. result, err := lc.evaluateHTLCView( view, &ourBalance, &theirBalance, nextHeight, - test.remoteChain, test.mutateState, + test.whoseCommitChain, test.mutateState, ) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -8631,12 +8632,12 @@ func TestProcessFeeUpdate(t *testing.T) { ) tests := []struct { - name string - startHeights heights - expectedHeights heights - remoteChain bool - mutate bool - expectedFee chainfee.SatPerKWeight + name string + startHeights heights + expectedHeights heights + whoseCommitChain lntypes.ChannelParty + mutate bool + expectedFee chainfee.SatPerKWeight }{ { // Looking at local chain, local add is non-zero so @@ -8654,9 +8655,9 @@ func TestProcessFeeUpdate(t *testing.T) { remoteAdd: 0, remoteRemove: height, }, - remoteChain: false, - mutate: false, - expectedFee: feePerKw, + whoseCommitChain: lntypes.Local, + mutate: false, + expectedFee: feePerKw, }, { // Looking at local chain, local add is zero so the @@ -8675,9 +8676,9 @@ func TestProcessFeeUpdate(t *testing.T) { remoteAdd: height, remoteRemove: 0, }, - remoteChain: false, - mutate: false, - expectedFee: ourFeeUpdatePerSat, + whoseCommitChain: lntypes.Local, + mutate: false, + expectedFee: ourFeeUpdatePerSat, }, { // Looking at remote chain, the remote add height is @@ -8696,9 +8697,9 @@ func TestProcessFeeUpdate(t *testing.T) { remoteAdd: 0, remoteRemove: 0, }, - remoteChain: true, - mutate: false, - expectedFee: ourFeeUpdatePerSat, + whoseCommitChain: lntypes.Remote, + mutate: false, + expectedFee: ourFeeUpdatePerSat, }, { // Looking at remote chain, the remote add height is @@ -8717,9 +8718,9 @@ func TestProcessFeeUpdate(t *testing.T) { remoteAdd: height, remoteRemove: 0, }, - remoteChain: true, - mutate: false, - expectedFee: feePerKw, + whoseCommitChain: lntypes.Remote, + mutate: false, + expectedFee: feePerKw, }, { // Local add height is non-zero, so the update has @@ -8738,9 +8739,9 @@ func TestProcessFeeUpdate(t *testing.T) { remoteAdd: 0, remoteRemove: height, }, - remoteChain: false, - mutate: true, - expectedFee: feePerKw, + whoseCommitChain: lntypes.Local, + mutate: true, + expectedFee: feePerKw, }, { // Local add is zero and we are looking at our local @@ -8760,9 +8761,9 @@ func TestProcessFeeUpdate(t *testing.T) { remoteAdd: 0, remoteRemove: 0, }, - remoteChain: false, - mutate: true, - expectedFee: ourFeeUpdatePerSat, + whoseCommitChain: lntypes.Local, + mutate: true, + expectedFee: ourFeeUpdatePerSat, }, } @@ -8786,7 +8787,7 @@ func TestProcessFeeUpdate(t *testing.T) { feePerKw: chainfee.SatPerKWeight(feePerKw), } processFeeUpdate( - update, nextHeight, test.remoteChain, + update, nextHeight, test.whoseCommitChain, test.mutate, view, ) @@ -8841,7 +8842,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { tests := []struct { name string startHeights heights - remoteChain bool + whoseCommitChain lntypes.ChannelParty isIncoming bool mutateState bool ourExpectedBalance lnwire.MilliSatoshi @@ -8857,7 +8858,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: false, mutateState: false, ourExpectedBalance: startBalance, @@ -8878,7 +8879,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: false, + whoseCommitChain: lntypes.Local, isIncoming: false, mutateState: false, ourExpectedBalance: startBalance, @@ -8899,7 +8900,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: false, + whoseCommitChain: lntypes.Local, isIncoming: true, mutateState: false, ourExpectedBalance: startBalance, @@ -8920,7 +8921,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: false, + whoseCommitChain: lntypes.Local, isIncoming: true, mutateState: true, ourExpectedBalance: startBalance, @@ -8942,7 +8943,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: false, mutateState: false, ourExpectedBalance: startBalance - updateAmount, @@ -8963,7 +8964,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: false, mutateState: true, ourExpectedBalance: startBalance - updateAmount, @@ -8984,7 +8985,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: removeHeight, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: false, mutateState: false, ourExpectedBalance: startBalance, @@ -9005,7 +9006,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: removeHeight, remoteRemove: 0, }, - remoteChain: false, + whoseCommitChain: lntypes.Local, isIncoming: false, mutateState: false, ourExpectedBalance: startBalance, @@ -9028,7 +9029,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: true, mutateState: false, ourExpectedBalance: startBalance + updateAmount, @@ -9051,7 +9052,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: false, mutateState: false, ourExpectedBalance: startBalance, @@ -9074,7 +9075,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: true, mutateState: false, ourExpectedBalance: startBalance, @@ -9097,7 +9098,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: false, mutateState: false, ourExpectedBalance: startBalance + updateAmount, @@ -9122,7 +9123,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: false, + whoseCommitChain: lntypes.Local, isIncoming: true, mutateState: true, ourExpectedBalance: startBalance + updateAmount, @@ -9147,7 +9148,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { localRemove: 0, remoteRemove: 0, }, - remoteChain: true, + whoseCommitChain: lntypes.Remote, isIncoming: true, mutateState: true, ourExpectedBalance: startBalance + updateAmount, @@ -9196,7 +9197,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { process( update, &ourBalance, &theirBalance, nextHeight, - test.remoteChain, test.isIncoming, + test.whoseCommitChain, test.isIncoming, test.mutateState, ) @@ -9752,11 +9753,11 @@ func testGetDustSum(t *testing.T, chantype channeldb.ChannelType) { expRemote lnwire.MilliSatoshi) { localDustSum := c.GetDustSum( - false, fn.None[chainfee.SatPerKWeight](), + lntypes.Local, fn.None[chainfee.SatPerKWeight](), ) require.Equal(t, expLocal, localDustSum) remoteDustSum := c.GetDustSum( - true, fn.None[chainfee.SatPerKWeight](), + lntypes.Remote, fn.None[chainfee.SatPerKWeight](), ) require.Equal(t, expRemote, remoteDustSum) } @@ -9910,8 +9911,9 @@ func deriveDummyRetributionParams(chanState *channeldb.OpenChannel) (uint32, config := chanState.RemoteChanCfg commitHash := chanState.RemoteCommitment.CommitTx.TxHash() keyRing := DeriveCommitmentKeys( - config.RevocationBasePoint.PubKey, false, chanState.ChanType, - &chanState.LocalChanCfg, &chanState.RemoteChanCfg, + config.RevocationBasePoint.PubKey, lntypes.Remote, + chanState.ChanType, &chanState.LocalChanCfg, + &chanState.RemoteChanCfg, ) leaseExpiry := chanState.ThawHeight return leaseExpiry, keyRing, commitHash @@ -10378,7 +10380,7 @@ func TestExtractPayDescs(t *testing.T) { // NOTE: we use nil commitment key rings to avoid checking the htlc // scripts(`genHtlcScript`) as it should be tested independently. incomingPDs, outgoingPDs, err := lnChan.extractPayDescs( - 0, 0, htlcs, nil, nil, true, + 0, 0, htlcs, nil, nil, lntypes.Local, ) require.NoError(t, err) diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index 1e1140fbc..2cf58f494 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -103,7 +103,7 @@ type CommitmentKeyRing struct { // of channel, and whether the commitment transaction is ours or the remote // peer's. func DeriveCommitmentKeys(commitPoint *btcec.PublicKey, - isOurCommit bool, chanType channeldb.ChannelType, + whoseCommit lntypes.ChannelParty, chanType channeldb.ChannelType, localChanCfg, remoteChanCfg *channeldb.ChannelConfig) *CommitmentKeyRing { tweaklessCommit := chanType.IsTweakless() @@ -111,7 +111,7 @@ func DeriveCommitmentKeys(commitPoint *btcec.PublicKey, // Depending on if this is our commit or not, we'll choose the correct // base point. localBasePoint := localChanCfg.PaymentBasePoint - if isOurCommit { + if whoseCommit.IsLocal() { localBasePoint = localChanCfg.DelayBasePoint } @@ -144,7 +144,7 @@ func DeriveCommitmentKeys(commitPoint *btcec.PublicKey, toRemoteBasePoint *btcec.PublicKey revocationBasePoint *btcec.PublicKey ) - if isOurCommit { + if whoseCommit.IsLocal() { toLocalBasePoint = localChanCfg.DelayBasePoint.PubKey toRemoteBasePoint = remoteChanCfg.PaymentBasePoint.PubKey revocationBasePoint = remoteChanCfg.RevocationBasePoint.PubKey @@ -169,7 +169,7 @@ func DeriveCommitmentKeys(commitPoint *btcec.PublicKey, // If this is not our commitment, the above ToRemoteKey will be // ours, and we blank out the local commitment tweak to // indicate that the key should not be tweaked when signing. - if !isOurCommit { + if whoseCommit.IsRemote() { keyRing.LocalCommitKeyTweak = nil } } else { @@ -686,20 +686,20 @@ type unsignedCommitmentTx struct { // passed in balances should be balances *before* subtracting any commitment // fees, but after anchor outputs. func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, - theirBalance lnwire.MilliSatoshi, isOurs bool, + theirBalance lnwire.MilliSatoshi, whoseCommit lntypes.ChannelParty, feePerKw chainfee.SatPerKWeight, height uint64, filteredHTLCView *htlcView, keyRing *CommitmentKeyRing) (*unsignedCommitmentTx, error) { dustLimit := cb.chanState.LocalChanCfg.DustLimit - if !isOurs { + if whoseCommit.IsRemote() { dustLimit = cb.chanState.RemoteChanCfg.DustLimit } numHTLCs := int64(0) for _, htlc := range filteredHTLCView.ourUpdates { if HtlcIsDust( - cb.chanState.ChanType, false, isOurs, feePerKw, + cb.chanState.ChanType, false, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, ) { @@ -710,7 +710,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } for _, htlc := range filteredHTLCView.theirUpdates { if HtlcIsDust( - cb.chanState.ChanType, true, isOurs, feePerKw, + cb.chanState.ChanType, true, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, ) { @@ -763,7 +763,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, if cb.chanState.ChanType.HasLeaseExpiration() { leaseExpiry = cb.chanState.ThawHeight } - if isOurs { + if whoseCommit.IsLocal() { commitTx, err = CreateCommitTx( cb.chanState.ChanType, fundingTxIn(cb.chanState), keyRing, &cb.chanState.LocalChanCfg, &cb.chanState.RemoteChanCfg, @@ -794,7 +794,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, cltvs := make([]uint32, len(commitTx.TxOut)) for _, htlc := range filteredHTLCView.ourUpdates { if HtlcIsDust( - cb.chanState.ChanType, false, isOurs, feePerKw, + cb.chanState.ChanType, false, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, ) { @@ -802,7 +802,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } err := addHTLC( - commitTx, isOurs, false, htlc, keyRing, + commitTx, whoseCommit, false, htlc, keyRing, cb.chanState.ChanType, ) if err != nil { @@ -812,7 +812,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } for _, htlc := range filteredHTLCView.theirUpdates { if HtlcIsDust( - cb.chanState.ChanType, true, isOurs, feePerKw, + cb.chanState.ChanType, true, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, ) { @@ -820,7 +820,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } err := addHTLC( - commitTx, isOurs, true, htlc, keyRing, + commitTx, whoseCommit, true, htlc, keyRing, cb.chanState.ChanType, ) if err != nil { @@ -1003,8 +1003,9 @@ func CoopCloseBalance(chanType channeldb.ChannelType, isInitiator bool, // genSegwitV0HtlcScript generates the HTLC scripts for a normal segwit v0 // channel. func genSegwitV0HtlcScript(chanType channeldb.ChannelType, - isIncoming, ourCommit bool, timeout uint32, rHash [32]byte, - keyRing *CommitmentKeyRing) (*WitnessScriptDesc, error) { + isIncoming bool, whoseCommit lntypes.ChannelParty, timeout uint32, + rHash [32]byte, keyRing *CommitmentKeyRing, +) (*WitnessScriptDesc, error) { var ( witnessScript []byte @@ -1024,7 +1025,7 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType, // The HTLC is paying to us, and being applied to our commitment // transaction. So we need to use the receiver's version of the HTLC // script. - case isIncoming && ourCommit: + case isIncoming && whoseCommit.IsLocal(): witnessScript, err = input.ReceiverHTLCScript( timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, keyRing.RevocationKey, rHash[:], confirmedHtlcSpends, @@ -1033,7 +1034,7 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType, // We're being paid via an HTLC by the remote party, and the HTLC is // being added to their commitment transaction, so we use the sender's // version of the HTLC script. - case isIncoming && !ourCommit: + case isIncoming && whoseCommit.IsRemote(): witnessScript, err = input.SenderHTLCScript( keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, keyRing.RevocationKey, rHash[:], confirmedHtlcSpends, @@ -1042,7 +1043,7 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType, // We're sending an HTLC which is being added to our commitment // transaction. Therefore, we need to use the sender's version of the // HTLC script. - case !isIncoming && ourCommit: + case !isIncoming && whoseCommit.IsLocal(): witnessScript, err = input.SenderHTLCScript( keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, keyRing.RevocationKey, rHash[:], confirmedHtlcSpends, @@ -1051,7 +1052,7 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType, // Finally, we're paying the remote party via an HTLC, which is being // added to their commitment transaction. Therefore, we use the // receiver's version of the HTLC script. - case !isIncoming && !ourCommit: + case !isIncoming && whoseCommit.IsRemote(): witnessScript, err = input.ReceiverHTLCScript( timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, keyRing.RevocationKey, rHash[:], confirmedHtlcSpends, @@ -1076,9 +1077,9 @@ func genSegwitV0HtlcScript(chanType channeldb.ChannelType, // genTaprootHtlcScript generates the HTLC scripts for a taproot+musig2 // channel. -func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, - rHash [32]byte, - keyRing *CommitmentKeyRing) (*input.HtlcScriptTree, error) { +func genTaprootHtlcScript(isIncoming bool, whoseCommit lntypes.ChannelParty, + timeout uint32, rHash [32]byte, keyRing *CommitmentKeyRing, +) (*input.HtlcScriptTree, error) { var ( htlcScriptTree *input.HtlcScriptTree @@ -1092,37 +1093,37 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, // The HTLC is paying to us, and being applied to our commitment // transaction. So we need to use the receiver's version of HTLC the // script. - case isIncoming && ourCommit: + case isIncoming && whoseCommit.IsLocal(): htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, - keyRing.RevocationKey, rHash[:], lntypes.Local, + keyRing.RevocationKey, rHash[:], whoseCommit, ) // We're being paid via an HTLC by the remote party, and the HTLC is // being added to their commitment transaction, so we use the sender's // version of the HTLC script. - case isIncoming && !ourCommit: + case isIncoming && whoseCommit.IsRemote(): htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.RemoteHtlcKey, keyRing.LocalHtlcKey, - keyRing.RevocationKey, rHash[:], lntypes.Remote, + keyRing.RevocationKey, rHash[:], whoseCommit, ) // We're sending an HTLC which is being added to our commitment // transaction. Therefore, we need to use the sender's version of the // HTLC script. - case !isIncoming && ourCommit: + case !isIncoming && whoseCommit.IsLocal(): htlcScriptTree, err = input.SenderHTLCScriptTaproot( keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, - keyRing.RevocationKey, rHash[:], lntypes.Local, + keyRing.RevocationKey, rHash[:], whoseCommit, ) // Finally, we're paying the remote party via an HTLC, which is being // added to their commitment transaction. Therefore, we use the // receiver's version of the HTLC script. - case !isIncoming && !ourCommit: + case !isIncoming && whoseCommit.IsRemote(): htlcScriptTree, err = input.ReceiverHTLCScriptTaproot( timeout, keyRing.LocalHtlcKey, keyRing.RemoteHtlcKey, - keyRing.RevocationKey, rHash[:], lntypes.Remote, + keyRing.RevocationKey, rHash[:], whoseCommit, ) } @@ -1135,19 +1136,20 @@ func genTaprootHtlcScript(isIncoming, ourCommit bool, timeout uint32, // multiplexer for the various spending paths is returned. The script path that // we need to sign for the remote party (2nd level HTLCs) is also returned // along side the multiplexer. -func genHtlcScript(chanType channeldb.ChannelType, isIncoming, ourCommit bool, - timeout uint32, rHash [32]byte, keyRing *CommitmentKeyRing, +func genHtlcScript(chanType channeldb.ChannelType, isIncoming bool, + whoseCommit lntypes.ChannelParty, timeout uint32, rHash [32]byte, + keyRing *CommitmentKeyRing, ) (input.ScriptDescriptor, error) { if !chanType.IsTaproot() { return genSegwitV0HtlcScript( - chanType, isIncoming, ourCommit, timeout, rHash, + chanType, isIncoming, whoseCommit, timeout, rHash, keyRing, ) } return genTaprootHtlcScript( - isIncoming, ourCommit, timeout, rHash, keyRing, + isIncoming, whoseCommit, timeout, rHash, keyRing, ) } @@ -1158,7 +1160,7 @@ func genHtlcScript(chanType channeldb.ChannelType, isIncoming, ourCommit bool, // locate the added HTLC on the commitment transaction from the // PaymentDescriptor that generated it, the generated script is stored within // the descriptor itself. -func addHTLC(commitTx *wire.MsgTx, ourCommit bool, +func addHTLC(commitTx *wire.MsgTx, whoseCommit lntypes.ChannelParty, isIncoming bool, paymentDesc *PaymentDescriptor, keyRing *CommitmentKeyRing, chanType channeldb.ChannelType) error { @@ -1166,7 +1168,7 @@ func addHTLC(commitTx *wire.MsgTx, ourCommit bool, rHash := paymentDesc.RHash scriptInfo, err := genHtlcScript( - chanType, isIncoming, ourCommit, timeout, rHash, keyRing, + chanType, isIncoming, whoseCommit, timeout, rHash, keyRing, ) if err != nil { return err @@ -1180,7 +1182,7 @@ func addHTLC(commitTx *wire.MsgTx, ourCommit bool, // Store the pkScript of this particular PaymentDescriptor so we can // quickly locate it within the commitment transaction later. - if ourCommit { + if whoseCommit.IsLocal() { paymentDesc.ourPkScript = pkScript paymentDesc.ourWitnessScript = scriptInfo.WitnessScriptToSign() @@ -1211,7 +1213,7 @@ func findOutputIndexesFromRemote(revocationPreimage *chainhash.Hash, // With the commitment point generated, we can now derive the king ring // which will be used to generate the output scripts. keyRing := DeriveCommitmentKeys( - commitmentPoint, false, chanState.ChanType, + commitmentPoint, lntypes.Remote, chanState.ChanType, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) diff --git a/lnwallet/wallet.go b/lnwallet/wallet.go index cf61606da..a56bf1c21 100644 --- a/lnwallet/wallet.go +++ b/lnwallet/wallet.go @@ -25,6 +25,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chanfunding" "github.com/lightningnetwork/lnd/lnwallet/chanvalidate" @@ -1475,10 +1476,12 @@ func CreateCommitmentTxns(localBalance, remoteBalance btcutil.Amount, leaseExpiry uint32) (*wire.MsgTx, *wire.MsgTx, error) { localCommitmentKeys := DeriveCommitmentKeys( - localCommitPoint, true, chanType, ourChanCfg, theirChanCfg, + localCommitPoint, lntypes.Local, chanType, ourChanCfg, + theirChanCfg, ) remoteCommitmentKeys := DeriveCommitmentKeys( - remoteCommitPoint, false, chanType, ourChanCfg, theirChanCfg, + remoteCommitPoint, lntypes.Remote, chanType, ourChanCfg, + theirChanCfg, ) ourCommitTx, err := CreateCommitTx( diff --git a/peer/brontide.go b/peer/brontide.go index 7a390cfd7..81a27bdd3 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -36,6 +36,7 @@ import ( "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lnpeer" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -3017,6 +3018,10 @@ func (p *Brontide) createChanCloser(channel *lnwallet.LightningChannel, maxFee = req.MaxFee } + closer := lntypes.Remote + if locallyInitiated { + closer = lntypes.Local + } chanCloser := chancloser.NewChanCloser( chancloser.ChanCloseCfg{ Channel: channel, @@ -3039,7 +3044,7 @@ func (p *Brontide) createChanCloser(channel *lnwallet.LightningChannel, fee, uint32(startingHeight), req, - locallyInitiated, + closer, ) return chanCloser, nil From 1d65f5bd120f7dea443a9a749cf322e9948baebb Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Tue, 30 Jul 2024 17:03:47 -0700 Subject: [PATCH 5/6] peer: refactor createChanCloser to use ChannelParty --- peer/brontide.go | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/peer/brontide.go b/peer/brontide.go index 81a27bdd3..27438daa6 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -1070,7 +1070,7 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( chanCloser, err := p.createChanCloser( lnChan, info.DeliveryScript.Val, feePerKw, nil, - info.LocalInitiator.Val, + info.Closer(), ) if err != nil { shutdownInfoErr = fmt.Errorf("unable to "+ @@ -2733,7 +2733,7 @@ func (p *Brontide) fetchActiveChanCloser(chanID lnwire.ChannelID) ( } chanCloser, err = p.createChanCloser( - channel, deliveryScript, feePerKw, nil, false, + channel, deliveryScript, feePerKw, nil, lntypes.Remote, ) if err != nil { p.log.Errorf("unable to create chan closer: %v", err) @@ -2970,12 +2970,13 @@ func (p *Brontide) restartCoopClose(lnChan *lnwallet.LightningChannel) ( // Determine whether we or the peer are the initiator of the coop // close attempt by looking at the channel's status. - locallyInitiated := c.HasChanStatus( - channeldb.ChanStatusLocalCloseInitiator, - ) + closingParty := lntypes.Remote + if c.HasChanStatus(channeldb.ChanStatusLocalCloseInitiator) { + closingParty = lntypes.Local + } chanCloser, err := p.createChanCloser( - lnChan, deliveryScript, feePerKw, nil, locallyInitiated, + lnChan, deliveryScript, feePerKw, nil, closingParty, ) if err != nil { p.log.Errorf("unable to create chan closer: %v", err) @@ -3004,7 +3005,7 @@ func (p *Brontide) restartCoopClose(lnChan *lnwallet.LightningChannel) ( func (p *Brontide) createChanCloser(channel *lnwallet.LightningChannel, deliveryScript lnwire.DeliveryAddress, fee chainfee.SatPerKWeight, req *htlcswitch.ChanClose, - locallyInitiated bool) (*chancloser.ChanCloser, error) { + closer lntypes.ChannelParty) (*chancloser.ChanCloser, error) { _, startingHeight, err := p.cfg.ChainIO.GetBestBlock() if err != nil { @@ -3018,10 +3019,6 @@ func (p *Brontide) createChanCloser(channel *lnwallet.LightningChannel, maxFee = req.MaxFee } - closer := lntypes.Remote - if locallyInitiated { - closer = lntypes.Local - } chanCloser := chancloser.NewChanCloser( chancloser.ChanCloseCfg{ Channel: channel, @@ -3101,7 +3098,8 @@ func (p *Brontide) handleLocalCloseReq(req *htlcswitch.ChanClose) { } chanCloser, err := p.createChanCloser( - channel, deliveryScript, req.TargetFeePerKw, req, true, + channel, deliveryScript, req.TargetFeePerKw, req, + lntypes.Local, ) if err != nil { p.log.Errorf(err.Error()) From 1f9cac5f809f8d6470d76dc3481ba17cd7b2b0d6 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Tue, 30 Jul 2024 17:05:04 -0700 Subject: [PATCH 6/6] htlcswitch: refactor dust handling to use ChannelParty --- htlcswitch/interfaces.go | 2 +- htlcswitch/link.go | 46 +++++++++++++++++++-------------------- htlcswitch/mailbox.go | 7 ++++-- htlcswitch/mock.go | 2 +- htlcswitch/switch.go | 12 ++++++---- htlcswitch/switch_test.go | 22 ++++++++++++++----- 6 files changed, 54 insertions(+), 37 deletions(-) diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index a55cd5d0b..1311373a1 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -63,7 +63,7 @@ type dustHandler interface { // getDustSum returns the dust sum on either the local or remote // commitment. An optional fee parameter can be passed in which is used // to calculate the dust sum. - getDustSum(remote bool, + getDustSum(whoseCommit lntypes.ChannelParty, fee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi // getFeeRate returns the current channel feerate. diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 99df2c2b9..eaeaf2e87 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -2727,15 +2727,10 @@ func (l *channelLink) MayAddOutgoingHtlc(amt lnwire.MilliSatoshi) error { // method. // // NOTE: Part of the dustHandler interface. -func (l *channelLink) getDustSum(remote bool, +func (l *channelLink) getDustSum(whoseCommit lntypes.ChannelParty, dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi { - party := lntypes.Local - if remote { - party = lntypes.Remote - } - - return l.channel.GetDustSum(party, dryRunFee) + return l.channel.GetDustSum(whoseCommit, dryRunFee) } // getFeeRate is a wrapper method that retrieves the underlying channel's @@ -2789,8 +2784,8 @@ func (l *channelLink) exceedsFeeExposureLimit( // Get the sum of dust for both the local and remote commitments using // this "dry-run" fee. - localDustSum := l.getDustSum(false, dryRunFee) - remoteDustSum := l.getDustSum(true, dryRunFee) + localDustSum := l.getDustSum(lntypes.Local, dryRunFee) + remoteDustSum := l.getDustSum(lntypes.Remote, dryRunFee) // Calculate the local and remote commitment fees using this dry-run // fee. @@ -2831,12 +2826,16 @@ func (l *channelLink) isOverexposedWithHtlc(htlc *lnwire.UpdateAddHTLC, amount := htlc.Amount.ToSatoshis() // See if this HTLC is dust on both the local and remote commitments. - isLocalDust := dustClosure(feeRate, incoming, true, amount) - isRemoteDust := dustClosure(feeRate, incoming, false, amount) + isLocalDust := dustClosure(feeRate, incoming, lntypes.Local, amount) + isRemoteDust := dustClosure(feeRate, incoming, lntypes.Remote, amount) // Calculate the dust sum for the local and remote commitments. - localDustSum := l.getDustSum(false, fn.None[chainfee.SatPerKWeight]()) - remoteDustSum := l.getDustSum(true, fn.None[chainfee.SatPerKWeight]()) + localDustSum := l.getDustSum( + lntypes.Local, fn.None[chainfee.SatPerKWeight](), + ) + remoteDustSum := l.getDustSum( + lntypes.Remote, fn.None[chainfee.SatPerKWeight](), + ) // Grab the larger of the local and remote commitment fees w/o dust. commitFee := l.getCommitFee(false) @@ -2887,25 +2886,26 @@ func (l *channelLink) isOverexposedWithHtlc(htlc *lnwire.UpdateAddHTLC, // the HTLC is incoming (i.e. one that the remote sent), a boolean denoting // whether to evaluate on the local or remote commit, and finally an HTLC // amount to test. -type dustClosure func(chainfee.SatPerKWeight, bool, bool, btcutil.Amount) bool +type dustClosure func(feerate chainfee.SatPerKWeight, incoming bool, + whoseCommit lntypes.ChannelParty, amt btcutil.Amount) bool // dustHelper is used to construct the dustClosure. func dustHelper(chantype channeldb.ChannelType, localDustLimit, remoteDustLimit btcutil.Amount) dustClosure { - isDust := func(feerate chainfee.SatPerKWeight, incoming, - localCommit bool, amt btcutil.Amount) bool { + isDust := func(feerate chainfee.SatPerKWeight, incoming bool, + whoseCommit lntypes.ChannelParty, amt btcutil.Amount) bool { - if localCommit { - return lnwallet.HtlcIsDust( - chantype, incoming, lntypes.Local, feerate, amt, - localDustLimit, - ) + var dustLimit btcutil.Amount + if whoseCommit.IsLocal() { + dustLimit = localDustLimit + } else { + dustLimit = remoteDustLimit } return lnwallet.HtlcIsDust( - chantype, incoming, lntypes.Remote, feerate, amt, - remoteDustLimit, + chantype, incoming, whoseCommit, feerate, amt, + dustLimit, ) } diff --git a/htlcswitch/mailbox.go b/htlcswitch/mailbox.go index a729e3ba5..9b82f8912 100644 --- a/htlcswitch/mailbox.go +++ b/htlcswitch/mailbox.go @@ -9,6 +9,7 @@ import ( "time" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" ) @@ -660,7 +661,8 @@ func (m *memoryMailBox) DustPackets() (lnwire.MilliSatoshi, // Evaluate whether this HTLC is dust on the local commitment. if m.isDust( - m.feeRate, false, true, addPkt.amount.ToSatoshis(), + m.feeRate, false, lntypes.Local, + addPkt.amount.ToSatoshis(), ) { localDustSum += addPkt.amount @@ -668,7 +670,8 @@ func (m *memoryMailBox) DustPackets() (lnwire.MilliSatoshi, // Evaluate whether this HTLC is dust on the remote commitment. if m.isDust( - m.feeRate, false, false, addPkt.amount.ToSatoshis(), + m.feeRate, false, lntypes.Remote, + addPkt.amount.ToSatoshis(), ) { remoteDustSum += addPkt.amount diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 07efd28a0..96417d9c0 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -814,7 +814,7 @@ func (f *mockChannelLink) handleSwitchPacket(pkt *htlcPacket) error { return nil } -func (f *mockChannelLink) getDustSum(remote bool, +func (f *mockChannelLink) getDustSum(whoseCommit lntypes.ChannelParty, dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi { return 0 diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index bfca92a3a..793da57db 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -2788,8 +2788,12 @@ func (s *Switch) dustExceedsFeeThreshold(link ChannelLink, isDust := link.getDustClosure() // Evaluate if the HTLC is dust on either sides' commitment. - isLocalDust := isDust(feeRate, incoming, true, amount.ToSatoshis()) - isRemoteDust := isDust(feeRate, incoming, false, amount.ToSatoshis()) + isLocalDust := isDust( + feeRate, incoming, lntypes.Local, amount.ToSatoshis(), + ) + isRemoteDust := isDust( + feeRate, incoming, lntypes.Remote, amount.ToSatoshis(), + ) if !(isLocalDust || isRemoteDust) { // If the HTLC is not dust on either commitment, it's fine to @@ -2807,7 +2811,7 @@ func (s *Switch) dustExceedsFeeThreshold(link ChannelLink, // sum for it. if isLocalDust { localSum := link.getDustSum( - false, fn.None[chainfee.SatPerKWeight](), + lntypes.Local, fn.None[chainfee.SatPerKWeight](), ) localSum += localMailDust @@ -2827,7 +2831,7 @@ func (s *Switch) dustExceedsFeeThreshold(link ChannelLink, // reached this point. if isRemoteDust { remoteSum := link.getDustSum( - true, fn.None[chainfee.SatPerKWeight](), + lntypes.Remote, fn.None[chainfee.SatPerKWeight](), ) remoteSum += remoteMailDust diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index ce00cd878..0bc0df2d4 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -4319,7 +4319,7 @@ func TestSwitchDustForwarding(t *testing.T) { } checkAlmostDust := func(link *channelLink, mbox MailBox, - remote bool) bool { + whoseCommit lntypes.ChannelParty) bool { timeout := time.After(15 * time.Second) pollInterval := 300 * time.Millisecond @@ -4335,12 +4335,12 @@ func TestSwitchDustForwarding(t *testing.T) { } linkDust := link.getDustSum( - remote, fn.None[chainfee.SatPerKWeight](), + whoseCommit, fn.None[chainfee.SatPerKWeight](), ) localMailDust, remoteMailDust := mbox.DustPackets() totalDust := linkDust - if remote { + if whoseCommit.IsRemote() { totalDust += remoteMailDust } else { totalDust += localMailDust @@ -4359,7 +4359,11 @@ func TestSwitchDustForwarding(t *testing.T) { n.firstBobChannelLink.ChanID(), n.firstBobChannelLink.ShortChanID(), ) - require.True(t, checkAlmostDust(n.firstBobChannelLink, bobMbox, false)) + require.True( + t, checkAlmostDust( + n.firstBobChannelLink, bobMbox, lntypes.Local, + ), + ) // Sending one more HTLC should fail. SendHTLC won't error, but the // HTLC should be failed backwards. @@ -4408,7 +4412,9 @@ func TestSwitchDustForwarding(t *testing.T) { aliceBobFirstHop, uint64(bobAttemptID), nondustHtlc, ) require.NoError(t, err) - require.True(t, checkAlmostDust(n.firstBobChannelLink, bobMbox, false)) + require.True(t, checkAlmostDust( + n.firstBobChannelLink, bobMbox, lntypes.Local, + )) // Check that the HTLC failed. bobResultChan, err = n.bobServer.htlcSwitch.GetAttemptResult( @@ -4486,7 +4492,11 @@ func TestSwitchDustForwarding(t *testing.T) { aliceMbox := aliceOrch.GetOrCreateMailBox( n.aliceChannelLink.ChanID(), n.aliceChannelLink.ShortChanID(), ) - require.True(t, checkAlmostDust(n.aliceChannelLink, aliceMbox, true)) + require.True( + t, checkAlmostDust( + n.aliceChannelLink, aliceMbox, lntypes.Remote, + ), + ) err = n.aliceServer.htlcSwitch.SendHTLC( n.aliceChannelLink.ShortChanID(), uint64(aliceAttemptID),