diff --git a/lnwallet/channel.go b/lnwallet/channel.go index f09d84813..b66b14780 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2675,11 +2675,8 @@ type HtlcView struct { // created using this view. NextHeight uint64 - // OurUpdates are our outgoing HTLCs. - OurUpdates []*paymentDescriptor - - // TheirUpdates are their incoming HTLCs. - TheirUpdates []*paymentDescriptor + // Updates is a Dual of the Local and Remote HTLCs. + Updates lntypes.Dual[[]*paymentDescriptor] // FeePerKw is the fee rate in sat/kw of the commitment transaction. FeePerKw chainfee.SatPerKWeight @@ -2688,13 +2685,13 @@ type HtlcView struct { // AuxOurUpdates returns the outgoing HTLCs as a read-only copy of // AuxHtlcDescriptors. func (v *HtlcView) AuxOurUpdates() []AuxHtlcDescriptor { - return fn.Map(newAuxHtlcDescriptor, v.OurUpdates) + return fn.Map(newAuxHtlcDescriptor, v.Updates.Local) } // AuxTheirUpdates returns the incoming HTLCs as a read-only copy of // AuxHtlcDescriptors. func (v *HtlcView) AuxTheirUpdates() []AuxHtlcDescriptor { - return fn.Map(newAuxHtlcDescriptor, v.TheirUpdates) + return fn.Map(newAuxHtlcDescriptor, v.Updates.Remote) } // fetchHTLCView returns all the candidate HTLC updates which should be @@ -2728,8 +2725,10 @@ func (lc *LightningChannel) fetchHTLCView(theirLogIndex, } return &HtlcView{ - OurUpdates: ourHTLCs, - TheirUpdates: theirHTLCs, + Updates: lntypes.Dual[[]*paymentDescriptor]{ + Local: ourHTLCs, + Remote: theirHTLCs, + }, } } @@ -2853,15 +2852,15 @@ func (lc *LightningChannel) fetchCommitmentView( // commitment are mutated, we'll manually copy over each HTLC to its // respective slice. c.outgoingHTLCs = make( - []paymentDescriptor, len(filteredHTLCView.OurUpdates), + []paymentDescriptor, len(filteredHTLCView.Updates.Local), ) - for i, htlc := range filteredHTLCView.OurUpdates { + for i, htlc := range filteredHTLCView.Updates.Local { c.outgoingHTLCs[i] = *htlc } c.incomingHTLCs = make( - []paymentDescriptor, len(filteredHTLCView.TheirUpdates), + []paymentDescriptor, len(filteredHTLCView.Updates.Remote), ) - for i, htlc := range filteredHTLCView.TheirUpdates { + for i, htlc := range filteredHTLCView.Updates.Remote { c.incomingHTLCs[i] = *htlc } @@ -2916,7 +2915,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // First we run through non-add entries in both logs, populating the // skip sets. - for _, entry := range view.OurUpdates { + for _, entry := range view.Updates.Local { switch entry.EntryType { // Skip adds for now. They will be processed below. case Add: @@ -2961,7 +2960,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, } // Do the same for our peer's updates. - for _, entry := range view.TheirUpdates { + for _, entry := range view.Updates.Remote { switch entry.EntryType { // Skip adds for now. They will be processed below. case Add: @@ -3007,7 +3006,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // Next we take a second pass through all the log entries, skipping any // settled HTLCs, and debiting the chain state balance due to any newly // added HTLCs. - for _, entry := range view.OurUpdates { + for _, entry := range view.Updates.Local { isAdd := entry.EntryType == Add if skipUs.Contains(entry.HtlcIndex) || !isAdd { continue @@ -3024,11 +3023,11 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, ) } - newView.OurUpdates = append(newView.OurUpdates, entry) + newView.Updates.Local = append(newView.Updates.Local, entry) } // Again, we do the same for our peer's updates. - for _, entry := range view.TheirUpdates { + for _, entry := range view.Updates.Remote { isAdd := entry.EntryType == Add if skipThem.Contains(entry.HtlcIndex) || !isAdd { continue @@ -3045,7 +3044,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, ) } - newView.TheirUpdates = append(newView.TheirUpdates, entry) + newView.Updates.Remote = append(newView.Updates.Remote, entry) } // Create a function that is capable of identifying whether or not the @@ -3075,10 +3074,12 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // Collect all of the updates that haven't had their commit heights set // for the commitment chain corresponding to whoseCommitmentChain. - uncommittedUpdates := lntypes.Dual[[]*paymentDescriptor]{ - Local: fn.Filter(isUncommitted, view.OurUpdates), - Remote: fn.Filter(isUncommitted, view.TheirUpdates), - } + uncommittedUpdates := lntypes.MapDual( + view.Updates, + func(us []*paymentDescriptor) []*paymentDescriptor { + return fn.Filter(isUncommitted, us) + }, + ) return newView, uncommittedUpdates, nil } @@ -3746,10 +3747,12 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // appropriate update log, in order to validate the sanity of the // commitment resulting from _actually adding_ this HTLC to the state. if predictOurAdd != nil { - view.OurUpdates = append(view.OurUpdates, predictOurAdd) + view.Updates.Local = append(view.Updates.Local, predictOurAdd) } if predictTheirAdd != nil { - view.TheirUpdates = append(view.TheirUpdates, predictTheirAdd) + view.Updates.Remote = append( + view.Updates.Remote, predictTheirAdd, + ) } ourBalance, theirBalance, commitWeight, filteredView, err := lc.computeView( @@ -3904,7 +3907,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // First check that the remote updates won't violate it's channel // constraints. err = validateUpdates( - filteredView.TheirUpdates, &lc.channelState.RemoteChanCfg, + filteredView.Updates.Remote, &lc.channelState.RemoteChanCfg, ) if err != nil { return err @@ -3913,7 +3916,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // Secondly check that our updates won't violate our channel // constraints. err = validateUpdates( - filteredView.OurUpdates, &lc.channelState.LocalChanCfg, + filteredView.Updates.Local, &lc.channelState.LocalChanCfg, ) if err != nil { return err @@ -4700,7 +4703,7 @@ func (lc *LightningChannel) computeView(view *HtlcView, // Now go through all HTLCs at this stage, to calculate the total // weight, needed to calculate the transaction fee. var totalHtlcWeight lntypes.WeightUnit - for _, htlc := range filteredHTLCView.OurUpdates { + for _, htlc := range filteredHTLCView.Updates.Local { if HtlcIsDust( lc.channelState.ChanType, false, whoseCommitChain, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -4711,7 +4714,7 @@ func (lc *LightningChannel) computeView(view *HtlcView, totalHtlcWeight += input.HTLCWeight } - for _, htlc := range filteredHTLCView.TheirUpdates { + for _, htlc := range filteredHTLCView.Updates.Remote { if HtlcIsDust( lc.channelState.ChanType, true, whoseCommitChain, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 0bac824a2..dd352dab1 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -8941,9 +8941,11 @@ func TestEvaluateView(t *testing.T) { } view := &HtlcView{ - OurUpdates: test.ourHtlcs, - TheirUpdates: test.theirHtlcs, - FeePerKw: feePerKw, + Updates: lntypes.Dual[[]*paymentDescriptor]{ + Local: test.ourHtlcs, + Remote: test.theirHtlcs, + }, + FeePerKw: feePerKw, } var ( @@ -8996,11 +8998,12 @@ func TestEvaluateView(t *testing.T) { } checkExpectedHtlcs( - t, result.OurUpdates, test.ourExpectedHtlcs, + t, result.Updates.Local, test.ourExpectedHtlcs, ) checkExpectedHtlcs( - t, result.TheirUpdates, test.theirExpectedHtlcs, + t, result.Updates.Remote, + test.theirExpectedHtlcs, ) if lc.channelState.TotalMSatSent != test.expectSent { diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index 170efece1..6d61729a4 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -702,7 +702,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } numHTLCs := int64(0) - for _, htlc := range filteredHTLCView.OurUpdates { + for _, htlc := range filteredHTLCView.Updates.Local { if HtlcIsDust( cb.chanState.ChanType, false, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -713,7 +713,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, numHTLCs++ } - for _, htlc := range filteredHTLCView.TheirUpdates { + for _, htlc := range filteredHTLCView.Updates.Remote { if HtlcIsDust( cb.chanState.ChanType, true, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -827,7 +827,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, // purposes of sorting. cltvs := make([]uint32, len(commitTx.TxOut)) htlcIndexes := make([]input.HtlcIndex, len(commitTx.TxOut)) - for _, htlc := range filteredHTLCView.OurUpdates { + for _, htlc := range filteredHTLCView.Updates.Local { if HtlcIsDust( cb.chanState.ChanType, false, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -855,7 +855,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, cltvs = append(cltvs, htlc.Timeout) //nolint htlcIndexes = append(htlcIndexes, htlc.HtlcIndex) //nolint } - for _, htlc := range filteredHTLCView.TheirUpdates { + for _, htlc := range filteredHTLCView.Updates.Remote { if HtlcIsDust( cb.chanState.ChanType, true, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit,