htlcswitch: refactor dust handling to use ChannelParty

This commit is contained in:
Keagan McClelland 2024-07-30 17:05:04 -07:00
parent 1d65f5bd12
commit 1f9cac5f80
No known key found for this signature in database
GPG Key ID: FA7E65C951F12439
6 changed files with 54 additions and 37 deletions

View File

@ -63,7 +63,7 @@ type dustHandler interface {
// getDustSum returns the dust sum on either the local or remote // getDustSum returns the dust sum on either the local or remote
// commitment. An optional fee parameter can be passed in which is used // commitment. An optional fee parameter can be passed in which is used
// to calculate the dust sum. // to calculate the dust sum.
getDustSum(remote bool, getDustSum(whoseCommit lntypes.ChannelParty,
fee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi fee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi
// getFeeRate returns the current channel feerate. // getFeeRate returns the current channel feerate.

View File

@ -2727,15 +2727,10 @@ func (l *channelLink) MayAddOutgoingHtlc(amt lnwire.MilliSatoshi) error {
// method. // method.
// //
// NOTE: Part of the dustHandler interface. // NOTE: Part of the dustHandler interface.
func (l *channelLink) getDustSum(remote bool, func (l *channelLink) getDustSum(whoseCommit lntypes.ChannelParty,
dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi { dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi {
party := lntypes.Local return l.channel.GetDustSum(whoseCommit, dryRunFee)
if remote {
party = lntypes.Remote
}
return l.channel.GetDustSum(party, dryRunFee)
} }
// getFeeRate is a wrapper method that retrieves the underlying channel's // getFeeRate is a wrapper method that retrieves the underlying channel's
@ -2789,8 +2784,8 @@ func (l *channelLink) exceedsFeeExposureLimit(
// Get the sum of dust for both the local and remote commitments using // Get the sum of dust for both the local and remote commitments using
// this "dry-run" fee. // this "dry-run" fee.
localDustSum := l.getDustSum(false, dryRunFee) localDustSum := l.getDustSum(lntypes.Local, dryRunFee)
remoteDustSum := l.getDustSum(true, dryRunFee) remoteDustSum := l.getDustSum(lntypes.Remote, dryRunFee)
// Calculate the local and remote commitment fees using this dry-run // Calculate the local and remote commitment fees using this dry-run
// fee. // fee.
@ -2831,12 +2826,16 @@ func (l *channelLink) isOverexposedWithHtlc(htlc *lnwire.UpdateAddHTLC,
amount := htlc.Amount.ToSatoshis() amount := htlc.Amount.ToSatoshis()
// See if this HTLC is dust on both the local and remote commitments. // See if this HTLC is dust on both the local and remote commitments.
isLocalDust := dustClosure(feeRate, incoming, true, amount) isLocalDust := dustClosure(feeRate, incoming, lntypes.Local, amount)
isRemoteDust := dustClosure(feeRate, incoming, false, amount) isRemoteDust := dustClosure(feeRate, incoming, lntypes.Remote, amount)
// Calculate the dust sum for the local and remote commitments. // Calculate the dust sum for the local and remote commitments.
localDustSum := l.getDustSum(false, fn.None[chainfee.SatPerKWeight]()) localDustSum := l.getDustSum(
remoteDustSum := l.getDustSum(true, fn.None[chainfee.SatPerKWeight]()) lntypes.Local, fn.None[chainfee.SatPerKWeight](),
)
remoteDustSum := l.getDustSum(
lntypes.Remote, fn.None[chainfee.SatPerKWeight](),
)
// Grab the larger of the local and remote commitment fees w/o dust. // Grab the larger of the local and remote commitment fees w/o dust.
commitFee := l.getCommitFee(false) commitFee := l.getCommitFee(false)
@ -2887,25 +2886,26 @@ func (l *channelLink) isOverexposedWithHtlc(htlc *lnwire.UpdateAddHTLC,
// the HTLC is incoming (i.e. one that the remote sent), a boolean denoting // the HTLC is incoming (i.e. one that the remote sent), a boolean denoting
// whether to evaluate on the local or remote commit, and finally an HTLC // whether to evaluate on the local or remote commit, and finally an HTLC
// amount to test. // amount to test.
type dustClosure func(chainfee.SatPerKWeight, bool, bool, btcutil.Amount) bool type dustClosure func(feerate chainfee.SatPerKWeight, incoming bool,
whoseCommit lntypes.ChannelParty, amt btcutil.Amount) bool
// dustHelper is used to construct the dustClosure. // dustHelper is used to construct the dustClosure.
func dustHelper(chantype channeldb.ChannelType, localDustLimit, func dustHelper(chantype channeldb.ChannelType, localDustLimit,
remoteDustLimit btcutil.Amount) dustClosure { remoteDustLimit btcutil.Amount) dustClosure {
isDust := func(feerate chainfee.SatPerKWeight, incoming, isDust := func(feerate chainfee.SatPerKWeight, incoming bool,
localCommit bool, amt btcutil.Amount) bool { whoseCommit lntypes.ChannelParty, amt btcutil.Amount) bool {
if localCommit { var dustLimit btcutil.Amount
return lnwallet.HtlcIsDust( if whoseCommit.IsLocal() {
chantype, incoming, lntypes.Local, feerate, amt, dustLimit = localDustLimit
localDustLimit, } else {
) dustLimit = remoteDustLimit
} }
return lnwallet.HtlcIsDust( return lnwallet.HtlcIsDust(
chantype, incoming, lntypes.Remote, feerate, amt, chantype, incoming, whoseCommit, feerate, amt,
remoteDustLimit, dustLimit,
) )
} }

View File

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
@ -660,7 +661,8 @@ func (m *memoryMailBox) DustPackets() (lnwire.MilliSatoshi,
// Evaluate whether this HTLC is dust on the local commitment. // Evaluate whether this HTLC is dust on the local commitment.
if m.isDust( if m.isDust(
m.feeRate, false, true, addPkt.amount.ToSatoshis(), m.feeRate, false, lntypes.Local,
addPkt.amount.ToSatoshis(),
) { ) {
localDustSum += addPkt.amount localDustSum += addPkt.amount
@ -668,7 +670,8 @@ func (m *memoryMailBox) DustPackets() (lnwire.MilliSatoshi,
// Evaluate whether this HTLC is dust on the remote commitment. // Evaluate whether this HTLC is dust on the remote commitment.
if m.isDust( if m.isDust(
m.feeRate, false, false, addPkt.amount.ToSatoshis(), m.feeRate, false, lntypes.Remote,
addPkt.amount.ToSatoshis(),
) { ) {
remoteDustSum += addPkt.amount remoteDustSum += addPkt.amount

View File

@ -814,7 +814,7 @@ func (f *mockChannelLink) handleSwitchPacket(pkt *htlcPacket) error {
return nil return nil
} }
func (f *mockChannelLink) getDustSum(remote bool, func (f *mockChannelLink) getDustSum(whoseCommit lntypes.ChannelParty,
dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi { dryRunFee fn.Option[chainfee.SatPerKWeight]) lnwire.MilliSatoshi {
return 0 return 0

View File

@ -2788,8 +2788,12 @@ func (s *Switch) dustExceedsFeeThreshold(link ChannelLink,
isDust := link.getDustClosure() isDust := link.getDustClosure()
// Evaluate if the HTLC is dust on either sides' commitment. // Evaluate if the HTLC is dust on either sides' commitment.
isLocalDust := isDust(feeRate, incoming, true, amount.ToSatoshis()) isLocalDust := isDust(
isRemoteDust := isDust(feeRate, incoming, false, amount.ToSatoshis()) feeRate, incoming, lntypes.Local, amount.ToSatoshis(),
)
isRemoteDust := isDust(
feeRate, incoming, lntypes.Remote, amount.ToSatoshis(),
)
if !(isLocalDust || isRemoteDust) { if !(isLocalDust || isRemoteDust) {
// If the HTLC is not dust on either commitment, it's fine to // If the HTLC is not dust on either commitment, it's fine to
@ -2807,7 +2811,7 @@ func (s *Switch) dustExceedsFeeThreshold(link ChannelLink,
// sum for it. // sum for it.
if isLocalDust { if isLocalDust {
localSum := link.getDustSum( localSum := link.getDustSum(
false, fn.None[chainfee.SatPerKWeight](), lntypes.Local, fn.None[chainfee.SatPerKWeight](),
) )
localSum += localMailDust localSum += localMailDust
@ -2827,7 +2831,7 @@ func (s *Switch) dustExceedsFeeThreshold(link ChannelLink,
// reached this point. // reached this point.
if isRemoteDust { if isRemoteDust {
remoteSum := link.getDustSum( remoteSum := link.getDustSum(
true, fn.None[chainfee.SatPerKWeight](), lntypes.Remote, fn.None[chainfee.SatPerKWeight](),
) )
remoteSum += remoteMailDust remoteSum += remoteMailDust

View File

@ -4319,7 +4319,7 @@ func TestSwitchDustForwarding(t *testing.T) {
} }
checkAlmostDust := func(link *channelLink, mbox MailBox, checkAlmostDust := func(link *channelLink, mbox MailBox,
remote bool) bool { whoseCommit lntypes.ChannelParty) bool {
timeout := time.After(15 * time.Second) timeout := time.After(15 * time.Second)
pollInterval := 300 * time.Millisecond pollInterval := 300 * time.Millisecond
@ -4335,12 +4335,12 @@ func TestSwitchDustForwarding(t *testing.T) {
} }
linkDust := link.getDustSum( linkDust := link.getDustSum(
remote, fn.None[chainfee.SatPerKWeight](), whoseCommit, fn.None[chainfee.SatPerKWeight](),
) )
localMailDust, remoteMailDust := mbox.DustPackets() localMailDust, remoteMailDust := mbox.DustPackets()
totalDust := linkDust totalDust := linkDust
if remote { if whoseCommit.IsRemote() {
totalDust += remoteMailDust totalDust += remoteMailDust
} else { } else {
totalDust += localMailDust totalDust += localMailDust
@ -4359,7 +4359,11 @@ func TestSwitchDustForwarding(t *testing.T) {
n.firstBobChannelLink.ChanID(), n.firstBobChannelLink.ChanID(),
n.firstBobChannelLink.ShortChanID(), n.firstBobChannelLink.ShortChanID(),
) )
require.True(t, checkAlmostDust(n.firstBobChannelLink, bobMbox, false)) require.True(
t, checkAlmostDust(
n.firstBobChannelLink, bobMbox, lntypes.Local,
),
)
// Sending one more HTLC should fail. SendHTLC won't error, but the // Sending one more HTLC should fail. SendHTLC won't error, but the
// HTLC should be failed backwards. // HTLC should be failed backwards.
@ -4408,7 +4412,9 @@ func TestSwitchDustForwarding(t *testing.T) {
aliceBobFirstHop, uint64(bobAttemptID), nondustHtlc, aliceBobFirstHop, uint64(bobAttemptID), nondustHtlc,
) )
require.NoError(t, err) require.NoError(t, err)
require.True(t, checkAlmostDust(n.firstBobChannelLink, bobMbox, false)) require.True(t, checkAlmostDust(
n.firstBobChannelLink, bobMbox, lntypes.Local,
))
// Check that the HTLC failed. // Check that the HTLC failed.
bobResultChan, err = n.bobServer.htlcSwitch.GetAttemptResult( bobResultChan, err = n.bobServer.htlcSwitch.GetAttemptResult(
@ -4486,7 +4492,11 @@ func TestSwitchDustForwarding(t *testing.T) {
aliceMbox := aliceOrch.GetOrCreateMailBox( aliceMbox := aliceOrch.GetOrCreateMailBox(
n.aliceChannelLink.ChanID(), n.aliceChannelLink.ShortChanID(), n.aliceChannelLink.ChanID(), n.aliceChannelLink.ShortChanID(),
) )
require.True(t, checkAlmostDust(n.aliceChannelLink, aliceMbox, true)) require.True(
t, checkAlmostDust(
n.aliceChannelLink, aliceMbox, lntypes.Remote,
),
)
err = n.aliceServer.htlcSwitch.SendHTLC( err = n.aliceServer.htlcSwitch.SendHTLC(
n.aliceChannelLink.ShortChanID(), uint64(aliceAttemptID), n.aliceChannelLink.ShortChanID(), uint64(aliceAttemptID),