From 04c37344ae37a38060a605dcd924e5130ca01a81 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Tue, 30 Apr 2024 15:23:50 -0700 Subject: [PATCH] lnwallet: refactor channel to use new typed List --- lnwallet/channel.go | 16 ++++++++-------- lnwallet/channel_test.go | 21 ++++++++++----------- lnwallet/commitment_chain.go | 14 +++++++------- lnwallet/update_log.go | 24 ++++++++++++------------ 4 files changed, 37 insertions(+), 38 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index abe8b6276..1605d04a1 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2474,7 +2474,7 @@ type htlcView struct { func (lc *LightningChannel) fetchHTLCView(theirLogIndex, ourLogIndex uint64) *htlcView { var ourHTLCs []*PaymentDescriptor for e := lc.localUpdateLog.Front(); e != nil; e = e.Next() { - htlc := e.Value.(*PaymentDescriptor) + htlc := e.Value // This HTLC is active from this point-of-view iff the log // index of the state update is below the specified index in @@ -2486,7 +2486,7 @@ func (lc *LightningChannel) fetchHTLCView(theirLogIndex, ourLogIndex uint64) *ht var theirHTLCs []*PaymentDescriptor for e := lc.remoteUpdateLog.Front(); e != nil; e = e.Next() { - htlc := e.Value.(*PaymentDescriptor) + htlc := e.Value // If this is an incoming HTLC, then it is only active from // this point-of-view if the index of the HTLC addition in @@ -3112,7 +3112,7 @@ func (lc *LightningChannel) createCommitDiff( // set of items we need to retransmit if we reconnect and find that // they didn't process this new state fully. for e := lc.localUpdateLog.Front(); e != nil; e = e.Next() { - pd := e.Value.(*PaymentDescriptor) + pd := e.Value // If this entry wasn't committed at the exact height of this // remote commitment, then we'll skip it as it was already @@ -3250,7 +3250,7 @@ func (lc *LightningChannel) getUnsignedAckedUpdates() []channeldb.LogUpdate { // remote party expects. var logUpdates []channeldb.LogUpdate for e := lc.remoteUpdateLog.Front(); e != nil; e = e.Next() { - pd := e.Value.(*PaymentDescriptor) + pd := e.Value // Skip all remote updates that we have already included in our // commit chain. @@ -5195,7 +5195,7 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) ( var addIndex, settleFailIndex uint16 for e := lc.remoteUpdateLog.Front(); e != nil; e = e.Next() { - pd := e.Value.(*PaymentDescriptor) + pd := e.Value // Fee updates are local to this particular channel, and should // never be forwarded. @@ -5525,7 +5525,7 @@ func (lc *LightningChannel) GetDustSum(remote bool, // Grab all of our HTLCs and evaluate against the dust limit. for e := lc.localUpdateLog.Front(); e != nil; e = e.Next() { - pd := e.Value.(*PaymentDescriptor) + pd := e.Value if pd.EntryType != Add { continue } @@ -5544,7 +5544,7 @@ func (lc *LightningChannel) GetDustSum(remote bool, // Grab all of their HTLCs and evaluate against the dust limit. for e := lc.remoteUpdateLog.Front(); e != nil; e = e.Next() { - pd := e.Value.(*PaymentDescriptor) + pd := e.Value if pd.EntryType != Add { continue } @@ -8545,7 +8545,7 @@ func (lc *LightningChannel) unsignedLocalUpdates(remoteMessageIndex, var localPeerUpdates []channeldb.LogUpdate for e := lc.localUpdateLog.Front(); e != nil; e = e.Next() { - pd := e.Value.(*PaymentDescriptor) + pd := e.Value // We don't save add updates as they are restored from the // remote commitment in restoreStateLogs. diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 185bdf87a..ef996ea41 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -2,7 +2,6 @@ package lnwallet import ( "bytes" - "container/list" "crypto/sha256" "fmt" "math/rand" @@ -1906,7 +1905,7 @@ func TestStateUpdatePersistence(t *testing.T) { // Newly generated pkScripts for HTLCs should be the same as in the old channel. for _, entry := range aliceChannel.localUpdateLog.htlcIndex { - htlc := entry.Value.(*PaymentDescriptor) + htlc := entry.Value restoredHtlc := aliceChannelNew.localUpdateLog.lookupHtlc(htlc.HtlcIndex) if !bytes.Equal(htlc.ourPkScript, restoredHtlc.ourPkScript) { t.Fatalf("alice ourPkScript in ourLog: expected %X, got %X", @@ -1918,7 +1917,7 @@ func TestStateUpdatePersistence(t *testing.T) { } } for _, entry := range aliceChannel.remoteUpdateLog.htlcIndex { - htlc := entry.Value.(*PaymentDescriptor) + htlc := entry.Value restoredHtlc := aliceChannelNew.remoteUpdateLog.lookupHtlc(htlc.HtlcIndex) if !bytes.Equal(htlc.ourPkScript, restoredHtlc.ourPkScript) { t.Fatalf("alice ourPkScript in theirLog: expected %X, got %X", @@ -1930,7 +1929,7 @@ func TestStateUpdatePersistence(t *testing.T) { } } for _, entry := range bobChannel.localUpdateLog.htlcIndex { - htlc := entry.Value.(*PaymentDescriptor) + htlc := entry.Value restoredHtlc := bobChannelNew.localUpdateLog.lookupHtlc(htlc.HtlcIndex) if !bytes.Equal(htlc.ourPkScript, restoredHtlc.ourPkScript) { t.Fatalf("bob ourPkScript in ourLog: expected %X, got %X", @@ -1942,7 +1941,7 @@ func TestStateUpdatePersistence(t *testing.T) { } } for _, entry := range bobChannel.remoteUpdateLog.htlcIndex { - htlc := entry.Value.(*PaymentDescriptor) + htlc := entry.Value restoredHtlc := bobChannelNew.remoteUpdateLog.lookupHtlc(htlc.HtlcIndex) if !bytes.Equal(htlc.ourPkScript, restoredHtlc.ourPkScript) { t.Fatalf("bob ourPkScript in theirLog: expected %X, got %X", @@ -4472,7 +4471,7 @@ func TestFeeUpdateOldDiskFormat(t *testing.T) { countLog := func(log *updateLog) (int, int) { var numUpdates, numFee int for e := log.Front(); e != nil; e = e.Next() { - htlc := e.Value.(*PaymentDescriptor) + htlc := e.Value if htlc.EntryType == FeeUpdate { numFee++ } @@ -6755,14 +6754,14 @@ func compareHtlcs(htlc1, htlc2 *PaymentDescriptor) error { } // compareIndexes is a helper method to compare two index maps. -func compareIndexes(a, b map[uint64]*list.Element) error { +func compareIndexes(a, b map[uint64]*fn.Node[*PaymentDescriptor]) error { for k1, e1 := range a { e2, ok := b[k1] if !ok { return fmt.Errorf("element with key %d "+ "not found in b", k1) } - htlc1, htlc2 := e1.Value.(*PaymentDescriptor), e2.Value.(*PaymentDescriptor) + htlc1, htlc2 := e1.Value, e2.Value if err := compareHtlcs(htlc1, htlc2); err != nil { return err } @@ -6774,7 +6773,7 @@ func compareIndexes(a, b map[uint64]*list.Element) error { return fmt.Errorf("element with key %d not "+ "found in a", k1) } - htlc1, htlc2 := e1.Value.(*PaymentDescriptor), e2.Value.(*PaymentDescriptor) + htlc1, htlc2 := e1.Value, e2.Value if err := compareHtlcs(htlc1, htlc2); err != nil { return err } @@ -6809,7 +6808,7 @@ func compareLogs(a, b *updateLog) error { e1, e2 := a.Front(), b.Front() for ; e1 != nil; e1, e2 = e1.Next(), e2.Next() { - htlc1, htlc2 := e1.Value.(*PaymentDescriptor), e2.Value.(*PaymentDescriptor) + htlc1, htlc2 := e1.Value, e2.Value if err := compareHtlcs(htlc1, htlc2); err != nil { return err } @@ -6917,7 +6916,7 @@ func TestChannelRestoreUpdateLogs(t *testing.T) { func fetchNumUpdates(t updateType, log *updateLog) int { num := 0 for e := log.Front(); e != nil; e = e.Next() { - htlc := e.Value.(*PaymentDescriptor) + htlc := e.Value if htlc.EntryType == t { num++ } diff --git a/lnwallet/commitment_chain.go b/lnwallet/commitment_chain.go index 7894e9e3f..fa2abe0aa 100644 --- a/lnwallet/commitment_chain.go +++ b/lnwallet/commitment_chain.go @@ -1,6 +1,8 @@ package lnwallet -import "container/list" +import ( + "github.com/lightningnetwork/lnd/fn" +) // commitmentChain represents a chain of unrevoked commitments. The tail of the // chain is the latest fully signed, yet unrevoked commitment. Two chains are @@ -15,13 +17,13 @@ type commitmentChain struct { // commitments are added to the end of the chain with increase height. // Once a commitment transaction is revoked, the tail is incremented, // freeing up the revocation window for new commitments. - commitments *list.List + commitments *fn.List[*commitment] } // newCommitmentChain creates a new commitment chain. func newCommitmentChain() *commitmentChain { return &commitmentChain{ - commitments: list.New(), + commitments: fn.NewList[*commitment](), } } @@ -42,14 +44,12 @@ func (s *commitmentChain) advanceTail() { // tip returns the latest commitment added to the chain. func (s *commitmentChain) tip() *commitment { - //nolint:forcetypeassert - return s.commitments.Back().Value.(*commitment) + return s.commitments.Back().Value } // tail returns the lowest unrevoked commitment transaction in the chain. func (s *commitmentChain) tail() *commitment { - //nolint:forcetypeassert - return s.commitments.Front().Value.(*commitment) + return s.commitments.Front().Value } // hasUnackedCommitment returns true if the commitment chain has more than one diff --git a/lnwallet/update_log.go b/lnwallet/update_log.go index 5cb39ef1e..80332cfc5 100644 --- a/lnwallet/update_log.go +++ b/lnwallet/update_log.go @@ -1,6 +1,8 @@ package lnwallet -import "container/list" +import ( + "github.com/lightningnetwork/lnd/fn" +) // updateLog is an append-only log that stores updates to a node's commitment // chain. This structure can be seen as the "mempool" within Lightning where @@ -27,16 +29,16 @@ type updateLog struct { // List is the updatelog itself, we embed this value so updateLog has // access to all the method of a list.List. - *list.List + *fn.List[*PaymentDescriptor] // updateIndex maps a `logIndex` to a particular update entry. It // deals with the four update types: // `Fail|MalformedFail|Settle|FeeUpdate` - updateIndex map[uint64]*list.Element + updateIndex map[uint64]*fn.Node[*PaymentDescriptor] // htlcIndex maps a `htlcCounter` to an offered HTLC entry, hence the // `Add` update. - htlcIndex map[uint64]*list.Element + htlcIndex map[uint64]*fn.Node[*PaymentDescriptor] // modifiedHtlcs is a set that keeps track of all the current modified // htlcs, hence update types `Fail|MalformedFail|Settle`. A modified @@ -48,9 +50,9 @@ type updateLog struct { // newUpdateLog creates a new updateLog instance. func newUpdateLog(logIndex, htlcCounter uint64) *updateLog { return &updateLog{ - List: list.New(), - updateIndex: make(map[uint64]*list.Element), - htlcIndex: make(map[uint64]*list.Element), + List: fn.NewList[*PaymentDescriptor](), + updateIndex: make(map[uint64]*fn.Node[*PaymentDescriptor]), + htlcIndex: make(map[uint64]*fn.Node[*PaymentDescriptor]), logIndex: logIndex, htlcCounter: htlcCounter, modifiedHtlcs: make(map[uint64]struct{}), @@ -101,8 +103,7 @@ func (u *updateLog) lookupHtlc(i uint64) *PaymentDescriptor { return nil } - //nolint:forcetypeassert - return htlc.Value.(*PaymentDescriptor) + return htlc.Value } // remove attempts to remove an entry from the update log. If the entry is @@ -145,15 +146,14 @@ func compactLogs(ourLog, theirLog *updateLog, localChainTail, remoteChainTail uint64) { compactLog := func(logA, logB *updateLog) { - var nextA *list.Element + var nextA *fn.Node[*PaymentDescriptor] for e := logA.Front(); e != nil; e = nextA { // Assign next iteration element at top of loop because // we may remove the current element from the list, // which can change the iterated sequence. nextA = e.Next() - //nolint:forcetypeassert - htlc := e.Value.(*PaymentDescriptor) + htlc := e.Value // We skip Adds, as they will be removed along with the // fail/settles below.