From cf0e0820d689070cd1d8a0176ca2731cb16d42f1 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Tue, 11 Oct 2022 22:26:28 +0800 Subject: [PATCH] lntemp: introduce `SyncMap` to store type information This commit replaces the usage of `sync.Map` with the new struct `SyncMap` to explicitly express the type info used in the map. --- lntemp/node/state.go | 27 ++++++------ lntemp/node/sync_map.go | 53 ++++++++++++++++++++++ lntemp/node/watcher.go | 97 +++++++++++++---------------------------- 3 files changed, 97 insertions(+), 80 deletions(-) create mode 100644 lntemp/node/sync_map.go diff --git a/lntemp/node/state.go b/lntemp/node/state.go index 86dda03cb..9dca10f42 100644 --- a/lntemp/node/state.go +++ b/lntemp/node/state.go @@ -4,9 +4,9 @@ import ( "encoding/json" "fmt" "math" - "sync" "time" + "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc/walletrpc" "github.com/lightningnetwork/lnd/lntemp/rpc" @@ -170,19 +170,18 @@ type State struct { // openChans records each opened channel and how many times it has // heard the announcements from its graph subscription. - // openChans map[wire.OutPoint][]*OpenChannelUpdate - openChans *sync.Map + openChans *SyncMap[wire.OutPoint, []*OpenChannelUpdate] // closedChans records each closed channel and its close channel update // message received from its graph subscription. - closedChans *sync.Map + closedChans *SyncMap[wire.OutPoint, *lnrpc.ClosedChannelUpdate] // numChanUpdates records the number of channel updates seen by each // channel. - numChanUpdates *sync.Map + numChanUpdates *SyncMap[wire.OutPoint, int] // nodeUpdates records the node announcements seen by each node. - nodeUpdates *sync.Map + nodeUpdates *SyncMap[string, []*lnrpc.NodeUpdate] // policyUpdates defines a type to store channel policy updates. It has // the format, @@ -197,19 +196,21 @@ type State struct { // }, // "chanPoint2": ... // } - policyUpdates *sync.Map + policyUpdates *SyncMap[wire.OutPoint, PolicyUpdate] } // newState initialize a new state with every field being set to its zero // value. func newState(rpc *rpc.HarnessRPC) *State { return &State{ - rpc: rpc, - openChans: &sync.Map{}, - closedChans: &sync.Map{}, - numChanUpdates: &sync.Map{}, - nodeUpdates: &sync.Map{}, - policyUpdates: &sync.Map{}, + rpc: rpc, + openChans: &SyncMap[wire.OutPoint, []*OpenChannelUpdate]{}, + closedChans: &SyncMap[ + wire.OutPoint, *lnrpc.ClosedChannelUpdate, + ]{}, + numChanUpdates: &SyncMap[wire.OutPoint, int]{}, + nodeUpdates: &SyncMap[string, []*lnrpc.NodeUpdate]{}, + policyUpdates: &SyncMap[wire.OutPoint, PolicyUpdate]{}, } } diff --git a/lntemp/node/sync_map.go b/lntemp/node/sync_map.go new file mode 100644 index 000000000..1ef157331 --- /dev/null +++ b/lntemp/node/sync_map.go @@ -0,0 +1,53 @@ +package node + +import "sync" + +// SyncMap wraps a sync.Map with type parameters such that it's easier to +// access the items stored in the map since no type assertion is needed. It +// also requires explicit type definition when declaring and initiating the +// variables, which helps us understanding what's stored in a given map. +type SyncMap[K comparable, V any] struct { + sync.Map +} + +// Store puts an item in the map. +func (m *SyncMap[K, V]) Store(key K, value V) { + m.Map.Store(key, value) +} + +// Load queries an item from the map using the specified key. If the item +// cannot be found, an empty value and false will be returned. If the stored +// item fails the type assertion, a nil value and false will be returned. +func (m *SyncMap[K, V]) Load(key K) (V, bool) { + result, ok := m.Map.Load(key) + if !ok { + return *new(V), false // nolint: gocritic + } + + item, ok := result.(V) + return item, ok +} + +// Delete removes an item from the map specified by the key. +func (m *SyncMap[K, V]) Delete(key K) { + m.Map.Delete(key) +} + +// LoadAndDelete queries an item and deletes it from the map using the +// specified key. +func (m *SyncMap[K, V]) LoadAndDelete(key K) (V, bool) { + result, loaded := m.Map.LoadAndDelete(key) + if !loaded { + return *new(V), loaded // nolint: gocritic + } + + item, ok := result.(V) + return item, ok +} + +// Range iterates the map. +func (m *SyncMap[K, V]) Range(visitor func(K, V) bool) { + m.Map.Range(func(k any, v any) bool { + return visitor(k.(K), v.(V)) + }) +} diff --git a/lntemp/node/watcher.go b/lntemp/node/watcher.go index 84d90a6f9..3f6a05d95 100644 --- a/lntemp/node/watcher.go +++ b/lntemp/node/watcher.go @@ -35,9 +35,9 @@ const ( DefaultTimeout = lntest.DefaultTimeout ) -// closeChanWatchRequest is a request to the lightningNetworkWatcher to be -// notified once it's detected within the test Lightning Network, that a -// channel has either been added or closed. +// chanWatchRequest is a request to the lightningNetworkWatcher to be notified +// once it's detected within the test Lightning Network, that a channel has +// either been added or closed. type chanWatchRequest struct { chanPoint wire.OutPoint @@ -68,11 +68,8 @@ type nodeWatcher struct { // of edges seen for that channel within the network. When this number // reaches 2, then it means that both edge advertisements has // propagated through the network. - // openChanWatchers map[wire.OutPoint][]chan struct{} - openChanWatchers *sync.Map - - // closeChanWatchers map[wire.OutPoint][]chan struct{} - closeChanWatchers *sync.Map + openChanWatchers *SyncMap[wire.OutPoint, []chan struct{}] + closeChanWatchers *SyncMap[wire.OutPoint, []chan struct{}] wg sync.WaitGroup } @@ -82,37 +79,28 @@ func newNodeWatcher(rpc *rpc.HarnessRPC, state *State) *nodeWatcher { rpc: rpc, state: state, chanWatchRequests: make(chan *chanWatchRequest, 100), - openChanWatchers: &sync.Map{}, - closeChanWatchers: &sync.Map{}, + openChanWatchers: &SyncMap[wire.OutPoint, []chan struct{}]{}, + closeChanWatchers: &SyncMap[wire.OutPoint, []chan struct{}]{}, } } // GetNumChannelUpdates reads the num of channel updates inside a lock and // returns the value. func (nw *nodeWatcher) GetNumChannelUpdates(op wire.OutPoint) int { - result, ok := nw.state.numChanUpdates.Load(op) - if ok { - return result.(int) - } - return 0 + result, _ := nw.state.numChanUpdates.Load(op) + return result } // GetPolicyUpdates returns the node's policyUpdates state. func (nw *nodeWatcher) GetPolicyUpdates(op wire.OutPoint) PolicyUpdate { - result, ok := nw.state.policyUpdates.Load(op) - if ok { - return result.(PolicyUpdate) - } - return nil + result, _ := nw.state.policyUpdates.Load(op) + return result } // GetNodeUpdates reads the node updates inside a lock and returns the value. func (nw *nodeWatcher) GetNodeUpdates(pubkey string) []*lnrpc.NodeUpdate { - result, ok := nw.state.nodeUpdates.Load(pubkey) - if ok { - return result.([]*lnrpc.NodeUpdate) - } - return nil + result, _ := nw.state.nodeUpdates.Load(pubkey) + return result } // WaitForNumChannelUpdates will block until a given number of updates has been @@ -170,7 +158,7 @@ func (nw *nodeWatcher) WaitForChannelOpen(chanPoint *lnrpc.ChannelPoint) error { return nil case <-timer: - updates, err := syncMapToJSON(nw.state.openChans) + updates, err := syncMapToJSON(&nw.state.openChans.Map) if err != nil { return err } @@ -204,7 +192,7 @@ func (nw *nodeWatcher) WaitForChannelClose( "a closed channel in node's state:%s", op, nw.state) } - return closedChan.(*lnrpc.ClosedChannelUpdate), nil + return closedChan, nil case <-timer: return nil, fmt.Errorf("channel:%s not closed before timeout: "+ @@ -329,23 +317,14 @@ func (nw *nodeWatcher) handleChannelEdgeUpdates( // updateNodeStateNumChanUpdates updates the internal state of the node // regarding the num of channel update seen. func (nw *nodeWatcher) updateNodeStateNumChanUpdates(op wire.OutPoint) { - var oldNum int - result, ok := nw.state.numChanUpdates.Load(op) - if ok { - oldNum = result.(int) - } + oldNum, _ := nw.state.numChanUpdates.Load(op) nw.state.numChanUpdates.Store(op, oldNum+1) } // updateNodeStateNodeUpdates updates the internal state of the node regarding // the node updates seen. func (nw *nodeWatcher) updateNodeStateNodeUpdates(update *lnrpc.NodeUpdate) { - var oldUpdates []*lnrpc.NodeUpdate - - result, ok := nw.state.nodeUpdates.Load(update.IdentityKey) - if ok { - oldUpdates = result.([]*lnrpc.NodeUpdate) - } + oldUpdates, _ := nw.state.nodeUpdates.Load(update.IdentityKey) nw.state.nodeUpdates.Store( update.IdentityKey, append(oldUpdates, update), ) @@ -357,11 +336,7 @@ func (nw *nodeWatcher) updateNodeStateOpenChannel(op wire.OutPoint, newChan *lnrpc.ChannelEdgeUpdate) { // Load the old updates the node has heard so far. - updates := make([]*OpenChannelUpdate, 0) - result, ok := nw.state.openChans.Load(op) - if ok { - updates = result.([]*OpenChannelUpdate) - } + updates, _ := nw.state.openChans.Load(op) // Create a new update based on this newChan. newUpdate := &OpenChannelUpdate{ @@ -386,8 +361,8 @@ func (nw *nodeWatcher) updateNodeStateOpenChannel(op wire.OutPoint, if !loaded { return } - events := watcherResult.([]chan struct{}) - for _, eventChan := range events { + + for _, eventChan := range watcherResult { close(eventChan) } } @@ -399,10 +374,9 @@ func (nw *nodeWatcher) updateNodeStatePolicy(op wire.OutPoint, // Init an empty policy map and overwrite it if the channel point can // be found in the node's policyUpdates. - policies := make(PolicyUpdate) - result, ok := nw.state.policyUpdates.Load(op) - if ok { - policies = result.(PolicyUpdate) + policies, ok := nw.state.policyUpdates.Load(op) + if !ok { + policies = make(PolicyUpdate) } node := newChan.AdvertisingNode @@ -426,21 +400,17 @@ func (nw *nodeWatcher) handleOpenChannelWatchRequest(req *chanWatchRequest) { // If this is an open request, then it can be dispatched if the number // of edges seen for the channel is at least two. - result, ok := nw.state.openChans.Load(targetChan) - if ok && len(result.([]*OpenChannelUpdate)) >= 2 { + result, _ := nw.state.openChans.Load(targetChan) + if len(result) >= 2 { close(req.eventChan) return } // Otherwise, we'll add this to the list of open channel watchers for // this out point. - oldWatchers := make([]chan struct{}, 0) - watchers, ok := nw.openChanWatchers.Load(targetChan) - if ok { - oldWatchers = watchers.([]chan struct{}) - } + watchers, _ := nw.openChanWatchers.Load(targetChan) nw.openChanWatchers.Store( - targetChan, append(oldWatchers, req.eventChan), + targetChan, append(watchers, req.eventChan), ) } @@ -459,12 +429,11 @@ func (nw *nodeWatcher) handleClosedChannelUpdate( // As the channel has been closed, we'll notify all register // watchers. - result, loaded := nw.closeChanWatchers.LoadAndDelete(op) + watchers, loaded := nw.closeChanWatchers.LoadAndDelete(op) if !loaded { continue } - watchers := result.([]chan struct{}) for _, eventChan := range watchers { close(eventChan) } @@ -487,12 +456,7 @@ func (nw *nodeWatcher) handleCloseChannelWatchRequest(req *chanWatchRequest) { // Otherwise, we'll add this to the list of close channel watchers for // this out point. - oldWatchers := make([]chan struct{}, 0) - result, ok := nw.closeChanWatchers.Load(targetChan) - if ok { - oldWatchers = result.([]chan struct{}) - } - + oldWatchers, _ := nw.closeChanWatchers.Load(targetChan) nw.closeChanWatchers.Store( targetChan, append(oldWatchers, req.eventChan), ) @@ -509,9 +473,8 @@ func (nw *nodeWatcher) handlePolicyUpdateWatchRequest(req *chanWatchRequest) { // Get a list of known policies for this chanPoint+advertisingNode // combination. Start searching in the node state first. - result, ok := nw.state.policyUpdates.Load(op) + policyMap, ok := nw.state.policyUpdates.Load(op) if ok { - policyMap := result.(PolicyUpdate) policies, ok = policyMap[req.advertisingNode] if !ok { return