lntemp: introduce SyncMap to store type information

This commit replaces the usage of `sync.Map` with the new struct
`SyncMap` to explicitly express the type info used in the map.
This commit is contained in:
yyforyongyu 2022-10-11 22:26:28 +08:00
parent 30ebacb888
commit cf0e0820d6
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
3 changed files with 97 additions and 80 deletions

View File

@ -4,9 +4,9 @@ import (
"encoding/json"
"fmt"
"math"
"sync"
"time"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnrpc/walletrpc"
"github.com/lightningnetwork/lnd/lntemp/rpc"
@ -170,19 +170,18 @@ type State struct {
// openChans records each opened channel and how many times it has
// heard the announcements from its graph subscription.
// openChans map[wire.OutPoint][]*OpenChannelUpdate
openChans *sync.Map
openChans *SyncMap[wire.OutPoint, []*OpenChannelUpdate]
// closedChans records each closed channel and its close channel update
// message received from its graph subscription.
closedChans *sync.Map
closedChans *SyncMap[wire.OutPoint, *lnrpc.ClosedChannelUpdate]
// numChanUpdates records the number of channel updates seen by each
// channel.
numChanUpdates *sync.Map
numChanUpdates *SyncMap[wire.OutPoint, int]
// nodeUpdates records the node announcements seen by each node.
nodeUpdates *sync.Map
nodeUpdates *SyncMap[string, []*lnrpc.NodeUpdate]
// policyUpdates defines a type to store channel policy updates. It has
// the format,
@ -197,19 +196,21 @@ type State struct {
// },
// "chanPoint2": ...
// }
policyUpdates *sync.Map
policyUpdates *SyncMap[wire.OutPoint, PolicyUpdate]
}
// newState initialize a new state with every field being set to its zero
// value.
func newState(rpc *rpc.HarnessRPC) *State {
return &State{
rpc: rpc,
openChans: &sync.Map{},
closedChans: &sync.Map{},
numChanUpdates: &sync.Map{},
nodeUpdates: &sync.Map{},
policyUpdates: &sync.Map{},
rpc: rpc,
openChans: &SyncMap[wire.OutPoint, []*OpenChannelUpdate]{},
closedChans: &SyncMap[
wire.OutPoint, *lnrpc.ClosedChannelUpdate,
]{},
numChanUpdates: &SyncMap[wire.OutPoint, int]{},
nodeUpdates: &SyncMap[string, []*lnrpc.NodeUpdate]{},
policyUpdates: &SyncMap[wire.OutPoint, PolicyUpdate]{},
}
}

53
lntemp/node/sync_map.go Normal file
View File

@ -0,0 +1,53 @@
package node
import "sync"
// SyncMap wraps a sync.Map with type parameters such that it's easier to
// access the items stored in the map since no type assertion is needed. It
// also requires explicit type definition when declaring and initiating the
// variables, which helps us understanding what's stored in a given map.
type SyncMap[K comparable, V any] struct {
sync.Map
}
// Store puts an item in the map.
func (m *SyncMap[K, V]) Store(key K, value V) {
m.Map.Store(key, value)
}
// Load queries an item from the map using the specified key. If the item
// cannot be found, an empty value and false will be returned. If the stored
// item fails the type assertion, a nil value and false will be returned.
func (m *SyncMap[K, V]) Load(key K) (V, bool) {
result, ok := m.Map.Load(key)
if !ok {
return *new(V), false // nolint: gocritic
}
item, ok := result.(V)
return item, ok
}
// Delete removes an item from the map specified by the key.
func (m *SyncMap[K, V]) Delete(key K) {
m.Map.Delete(key)
}
// LoadAndDelete queries an item and deletes it from the map using the
// specified key.
func (m *SyncMap[K, V]) LoadAndDelete(key K) (V, bool) {
result, loaded := m.Map.LoadAndDelete(key)
if !loaded {
return *new(V), loaded // nolint: gocritic
}
item, ok := result.(V)
return item, ok
}
// Range iterates the map.
func (m *SyncMap[K, V]) Range(visitor func(K, V) bool) {
m.Map.Range(func(k any, v any) bool {
return visitor(k.(K), v.(V))
})
}

View File

@ -35,9 +35,9 @@ const (
DefaultTimeout = lntest.DefaultTimeout
)
// closeChanWatchRequest is a request to the lightningNetworkWatcher to be
// notified once it's detected within the test Lightning Network, that a
// channel has either been added or closed.
// chanWatchRequest is a request to the lightningNetworkWatcher to be notified
// once it's detected within the test Lightning Network, that a channel has
// either been added or closed.
type chanWatchRequest struct {
chanPoint wire.OutPoint
@ -68,11 +68,8 @@ type nodeWatcher struct {
// of edges seen for that channel within the network. When this number
// reaches 2, then it means that both edge advertisements has
// propagated through the network.
// openChanWatchers map[wire.OutPoint][]chan struct{}
openChanWatchers *sync.Map
// closeChanWatchers map[wire.OutPoint][]chan struct{}
closeChanWatchers *sync.Map
openChanWatchers *SyncMap[wire.OutPoint, []chan struct{}]
closeChanWatchers *SyncMap[wire.OutPoint, []chan struct{}]
wg sync.WaitGroup
}
@ -82,37 +79,28 @@ func newNodeWatcher(rpc *rpc.HarnessRPC, state *State) *nodeWatcher {
rpc: rpc,
state: state,
chanWatchRequests: make(chan *chanWatchRequest, 100),
openChanWatchers: &sync.Map{},
closeChanWatchers: &sync.Map{},
openChanWatchers: &SyncMap[wire.OutPoint, []chan struct{}]{},
closeChanWatchers: &SyncMap[wire.OutPoint, []chan struct{}]{},
}
}
// GetNumChannelUpdates reads the num of channel updates inside a lock and
// returns the value.
func (nw *nodeWatcher) GetNumChannelUpdates(op wire.OutPoint) int {
result, ok := nw.state.numChanUpdates.Load(op)
if ok {
return result.(int)
}
return 0
result, _ := nw.state.numChanUpdates.Load(op)
return result
}
// GetPolicyUpdates returns the node's policyUpdates state.
func (nw *nodeWatcher) GetPolicyUpdates(op wire.OutPoint) PolicyUpdate {
result, ok := nw.state.policyUpdates.Load(op)
if ok {
return result.(PolicyUpdate)
}
return nil
result, _ := nw.state.policyUpdates.Load(op)
return result
}
// GetNodeUpdates reads the node updates inside a lock and returns the value.
func (nw *nodeWatcher) GetNodeUpdates(pubkey string) []*lnrpc.NodeUpdate {
result, ok := nw.state.nodeUpdates.Load(pubkey)
if ok {
return result.([]*lnrpc.NodeUpdate)
}
return nil
result, _ := nw.state.nodeUpdates.Load(pubkey)
return result
}
// WaitForNumChannelUpdates will block until a given number of updates has been
@ -170,7 +158,7 @@ func (nw *nodeWatcher) WaitForChannelOpen(chanPoint *lnrpc.ChannelPoint) error {
return nil
case <-timer:
updates, err := syncMapToJSON(nw.state.openChans)
updates, err := syncMapToJSON(&nw.state.openChans.Map)
if err != nil {
return err
}
@ -204,7 +192,7 @@ func (nw *nodeWatcher) WaitForChannelClose(
"a closed channel in node's state:%s", op,
nw.state)
}
return closedChan.(*lnrpc.ClosedChannelUpdate), nil
return closedChan, nil
case <-timer:
return nil, fmt.Errorf("channel:%s not closed before timeout: "+
@ -329,23 +317,14 @@ func (nw *nodeWatcher) handleChannelEdgeUpdates(
// updateNodeStateNumChanUpdates updates the internal state of the node
// regarding the num of channel update seen.
func (nw *nodeWatcher) updateNodeStateNumChanUpdates(op wire.OutPoint) {
var oldNum int
result, ok := nw.state.numChanUpdates.Load(op)
if ok {
oldNum = result.(int)
}
oldNum, _ := nw.state.numChanUpdates.Load(op)
nw.state.numChanUpdates.Store(op, oldNum+1)
}
// updateNodeStateNodeUpdates updates the internal state of the node regarding
// the node updates seen.
func (nw *nodeWatcher) updateNodeStateNodeUpdates(update *lnrpc.NodeUpdate) {
var oldUpdates []*lnrpc.NodeUpdate
result, ok := nw.state.nodeUpdates.Load(update.IdentityKey)
if ok {
oldUpdates = result.([]*lnrpc.NodeUpdate)
}
oldUpdates, _ := nw.state.nodeUpdates.Load(update.IdentityKey)
nw.state.nodeUpdates.Store(
update.IdentityKey, append(oldUpdates, update),
)
@ -357,11 +336,7 @@ func (nw *nodeWatcher) updateNodeStateOpenChannel(op wire.OutPoint,
newChan *lnrpc.ChannelEdgeUpdate) {
// Load the old updates the node has heard so far.
updates := make([]*OpenChannelUpdate, 0)
result, ok := nw.state.openChans.Load(op)
if ok {
updates = result.([]*OpenChannelUpdate)
}
updates, _ := nw.state.openChans.Load(op)
// Create a new update based on this newChan.
newUpdate := &OpenChannelUpdate{
@ -386,8 +361,8 @@ func (nw *nodeWatcher) updateNodeStateOpenChannel(op wire.OutPoint,
if !loaded {
return
}
events := watcherResult.([]chan struct{})
for _, eventChan := range events {
for _, eventChan := range watcherResult {
close(eventChan)
}
}
@ -399,10 +374,9 @@ func (nw *nodeWatcher) updateNodeStatePolicy(op wire.OutPoint,
// Init an empty policy map and overwrite it if the channel point can
// be found in the node's policyUpdates.
policies := make(PolicyUpdate)
result, ok := nw.state.policyUpdates.Load(op)
if ok {
policies = result.(PolicyUpdate)
policies, ok := nw.state.policyUpdates.Load(op)
if !ok {
policies = make(PolicyUpdate)
}
node := newChan.AdvertisingNode
@ -426,21 +400,17 @@ func (nw *nodeWatcher) handleOpenChannelWatchRequest(req *chanWatchRequest) {
// If this is an open request, then it can be dispatched if the number
// of edges seen for the channel is at least two.
result, ok := nw.state.openChans.Load(targetChan)
if ok && len(result.([]*OpenChannelUpdate)) >= 2 {
result, _ := nw.state.openChans.Load(targetChan)
if len(result) >= 2 {
close(req.eventChan)
return
}
// Otherwise, we'll add this to the list of open channel watchers for
// this out point.
oldWatchers := make([]chan struct{}, 0)
watchers, ok := nw.openChanWatchers.Load(targetChan)
if ok {
oldWatchers = watchers.([]chan struct{})
}
watchers, _ := nw.openChanWatchers.Load(targetChan)
nw.openChanWatchers.Store(
targetChan, append(oldWatchers, req.eventChan),
targetChan, append(watchers, req.eventChan),
)
}
@ -459,12 +429,11 @@ func (nw *nodeWatcher) handleClosedChannelUpdate(
// As the channel has been closed, we'll notify all register
// watchers.
result, loaded := nw.closeChanWatchers.LoadAndDelete(op)
watchers, loaded := nw.closeChanWatchers.LoadAndDelete(op)
if !loaded {
continue
}
watchers := result.([]chan struct{})
for _, eventChan := range watchers {
close(eventChan)
}
@ -487,12 +456,7 @@ func (nw *nodeWatcher) handleCloseChannelWatchRequest(req *chanWatchRequest) {
// Otherwise, we'll add this to the list of close channel watchers for
// this out point.
oldWatchers := make([]chan struct{}, 0)
result, ok := nw.closeChanWatchers.Load(targetChan)
if ok {
oldWatchers = result.([]chan struct{})
}
oldWatchers, _ := nw.closeChanWatchers.Load(targetChan)
nw.closeChanWatchers.Store(
targetChan, append(oldWatchers, req.eventChan),
)
@ -509,9 +473,8 @@ func (nw *nodeWatcher) handlePolicyUpdateWatchRequest(req *chanWatchRequest) {
// Get a list of known policies for this chanPoint+advertisingNode
// combination. Start searching in the node state first.
result, ok := nw.state.policyUpdates.Load(op)
policyMap, ok := nw.state.policyUpdates.Load(op)
if ok {
policyMap := result.(PolicyUpdate)
policies, ok = policyMap[req.advertisingNode]
if !ok {
return