lnwallet: pack htlcView.{OurUpdates|TheirUpdates} into Dual.

This commit moves the collection of updates behind a Dual structure.
This allows us in a later commit to index into it via a ChannelParty
parameter which will simplify the loops in evaluateHTLCView.
This commit is contained in:
Keagan McClelland 2024-07-24 14:57:32 -07:00
parent 1b2cb14254
commit 4b2a4e36ad
No known key found for this signature in database
GPG Key ID: FA7E65C951F12439
3 changed files with 44 additions and 38 deletions

View File

@ -2675,11 +2675,8 @@ type HtlcView struct {
// created using this view.
NextHeight uint64
// OurUpdates are our outgoing HTLCs.
OurUpdates []*paymentDescriptor
// TheirUpdates are their incoming HTLCs.
TheirUpdates []*paymentDescriptor
// Updates is a Dual of the Local and Remote HTLCs.
Updates lntypes.Dual[[]*paymentDescriptor]
// FeePerKw is the fee rate in sat/kw of the commitment transaction.
FeePerKw chainfee.SatPerKWeight
@ -2688,13 +2685,13 @@ type HtlcView struct {
// AuxOurUpdates returns the outgoing HTLCs as a read-only copy of
// AuxHtlcDescriptors.
func (v *HtlcView) AuxOurUpdates() []AuxHtlcDescriptor {
return fn.Map(newAuxHtlcDescriptor, v.OurUpdates)
return fn.Map(newAuxHtlcDescriptor, v.Updates.Local)
}
// AuxTheirUpdates returns the incoming HTLCs as a read-only copy of
// AuxHtlcDescriptors.
func (v *HtlcView) AuxTheirUpdates() []AuxHtlcDescriptor {
return fn.Map(newAuxHtlcDescriptor, v.TheirUpdates)
return fn.Map(newAuxHtlcDescriptor, v.Updates.Remote)
}
// fetchHTLCView returns all the candidate HTLC updates which should be
@ -2728,8 +2725,10 @@ func (lc *LightningChannel) fetchHTLCView(theirLogIndex,
}
return &HtlcView{
OurUpdates: ourHTLCs,
TheirUpdates: theirHTLCs,
Updates: lntypes.Dual[[]*paymentDescriptor]{
Local: ourHTLCs,
Remote: theirHTLCs,
},
}
}
@ -2853,15 +2852,15 @@ func (lc *LightningChannel) fetchCommitmentView(
// commitment are mutated, we'll manually copy over each HTLC to its
// respective slice.
c.outgoingHTLCs = make(
[]paymentDescriptor, len(filteredHTLCView.OurUpdates),
[]paymentDescriptor, len(filteredHTLCView.Updates.Local),
)
for i, htlc := range filteredHTLCView.OurUpdates {
for i, htlc := range filteredHTLCView.Updates.Local {
c.outgoingHTLCs[i] = *htlc
}
c.incomingHTLCs = make(
[]paymentDescriptor, len(filteredHTLCView.TheirUpdates),
[]paymentDescriptor, len(filteredHTLCView.Updates.Remote),
)
for i, htlc := range filteredHTLCView.TheirUpdates {
for i, htlc := range filteredHTLCView.Updates.Remote {
c.incomingHTLCs[i] = *htlc
}
@ -2916,7 +2915,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance,
// First we run through non-add entries in both logs, populating the
// skip sets.
for _, entry := range view.OurUpdates {
for _, entry := range view.Updates.Local {
switch entry.EntryType {
// Skip adds for now. They will be processed below.
case Add:
@ -2961,7 +2960,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance,
}
// Do the same for our peer's updates.
for _, entry := range view.TheirUpdates {
for _, entry := range view.Updates.Remote {
switch entry.EntryType {
// Skip adds for now. They will be processed below.
case Add:
@ -3007,7 +3006,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance,
// Next we take a second pass through all the log entries, skipping any
// settled HTLCs, and debiting the chain state balance due to any newly
// added HTLCs.
for _, entry := range view.OurUpdates {
for _, entry := range view.Updates.Local {
isAdd := entry.EntryType == Add
if skipUs.Contains(entry.HtlcIndex) || !isAdd {
continue
@ -3024,11 +3023,11 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance,
)
}
newView.OurUpdates = append(newView.OurUpdates, entry)
newView.Updates.Local = append(newView.Updates.Local, entry)
}
// Again, we do the same for our peer's updates.
for _, entry := range view.TheirUpdates {
for _, entry := range view.Updates.Remote {
isAdd := entry.EntryType == Add
if skipThem.Contains(entry.HtlcIndex) || !isAdd {
continue
@ -3045,7 +3044,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance,
)
}
newView.TheirUpdates = append(newView.TheirUpdates, entry)
newView.Updates.Remote = append(newView.Updates.Remote, entry)
}
// Create a function that is capable of identifying whether or not the
@ -3075,10 +3074,12 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance,
// Collect all of the updates that haven't had their commit heights set
// for the commitment chain corresponding to whoseCommitmentChain.
uncommittedUpdates := lntypes.Dual[[]*paymentDescriptor]{
Local: fn.Filter(isUncommitted, view.OurUpdates),
Remote: fn.Filter(isUncommitted, view.TheirUpdates),
}
uncommittedUpdates := lntypes.MapDual(
view.Updates,
func(us []*paymentDescriptor) []*paymentDescriptor {
return fn.Filter(isUncommitted, us)
},
)
return newView, uncommittedUpdates, nil
}
@ -3746,10 +3747,12 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter,
// appropriate update log, in order to validate the sanity of the
// commitment resulting from _actually adding_ this HTLC to the state.
if predictOurAdd != nil {
view.OurUpdates = append(view.OurUpdates, predictOurAdd)
view.Updates.Local = append(view.Updates.Local, predictOurAdd)
}
if predictTheirAdd != nil {
view.TheirUpdates = append(view.TheirUpdates, predictTheirAdd)
view.Updates.Remote = append(
view.Updates.Remote, predictTheirAdd,
)
}
ourBalance, theirBalance, commitWeight, filteredView, err := lc.computeView(
@ -3904,7 +3907,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter,
// First check that the remote updates won't violate it's channel
// constraints.
err = validateUpdates(
filteredView.TheirUpdates, &lc.channelState.RemoteChanCfg,
filteredView.Updates.Remote, &lc.channelState.RemoteChanCfg,
)
if err != nil {
return err
@ -3913,7 +3916,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter,
// Secondly check that our updates won't violate our channel
// constraints.
err = validateUpdates(
filteredView.OurUpdates, &lc.channelState.LocalChanCfg,
filteredView.Updates.Local, &lc.channelState.LocalChanCfg,
)
if err != nil {
return err
@ -4700,7 +4703,7 @@ func (lc *LightningChannel) computeView(view *HtlcView,
// Now go through all HTLCs at this stage, to calculate the total
// weight, needed to calculate the transaction fee.
var totalHtlcWeight lntypes.WeightUnit
for _, htlc := range filteredHTLCView.OurUpdates {
for _, htlc := range filteredHTLCView.Updates.Local {
if HtlcIsDust(
lc.channelState.ChanType, false, whoseCommitChain,
feePerKw, htlc.Amount.ToSatoshis(), dustLimit,
@ -4711,7 +4714,7 @@ func (lc *LightningChannel) computeView(view *HtlcView,
totalHtlcWeight += input.HTLCWeight
}
for _, htlc := range filteredHTLCView.TheirUpdates {
for _, htlc := range filteredHTLCView.Updates.Remote {
if HtlcIsDust(
lc.channelState.ChanType, true, whoseCommitChain,
feePerKw, htlc.Amount.ToSatoshis(), dustLimit,

View File

@ -8941,9 +8941,11 @@ func TestEvaluateView(t *testing.T) {
}
view := &HtlcView{
OurUpdates: test.ourHtlcs,
TheirUpdates: test.theirHtlcs,
FeePerKw: feePerKw,
Updates: lntypes.Dual[[]*paymentDescriptor]{
Local: test.ourHtlcs,
Remote: test.theirHtlcs,
},
FeePerKw: feePerKw,
}
var (
@ -8996,11 +8998,12 @@ func TestEvaluateView(t *testing.T) {
}
checkExpectedHtlcs(
t, result.OurUpdates, test.ourExpectedHtlcs,
t, result.Updates.Local, test.ourExpectedHtlcs,
)
checkExpectedHtlcs(
t, result.TheirUpdates, test.theirExpectedHtlcs,
t, result.Updates.Remote,
test.theirExpectedHtlcs,
)
if lc.channelState.TotalMSatSent != test.expectSent {

View File

@ -702,7 +702,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance,
}
numHTLCs := int64(0)
for _, htlc := range filteredHTLCView.OurUpdates {
for _, htlc := range filteredHTLCView.Updates.Local {
if HtlcIsDust(
cb.chanState.ChanType, false, whoseCommit, feePerKw,
htlc.Amount.ToSatoshis(), dustLimit,
@ -713,7 +713,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance,
numHTLCs++
}
for _, htlc := range filteredHTLCView.TheirUpdates {
for _, htlc := range filteredHTLCView.Updates.Remote {
if HtlcIsDust(
cb.chanState.ChanType, true, whoseCommit, feePerKw,
htlc.Amount.ToSatoshis(), dustLimit,
@ -827,7 +827,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance,
// purposes of sorting.
cltvs := make([]uint32, len(commitTx.TxOut))
htlcIndexes := make([]input.HtlcIndex, len(commitTx.TxOut))
for _, htlc := range filteredHTLCView.OurUpdates {
for _, htlc := range filteredHTLCView.Updates.Local {
if HtlcIsDust(
cb.chanState.ChanType, false, whoseCommit, feePerKw,
htlc.Amount.ToSatoshis(), dustLimit,
@ -855,7 +855,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance,
cltvs = append(cltvs, htlc.Timeout) //nolint
htlcIndexes = append(htlcIndexes, htlc.HtlcIndex) //nolint
}
for _, htlc := range filteredHTLCView.TheirUpdates {
for _, htlc := range filteredHTLCView.Updates.Remote {
if HtlcIsDust(
cb.chanState.ChanType, true, whoseCommit, feePerKw,
htlc.Amount.ToSatoshis(), dustLimit,