lntemp: introduce SyncMap to store type information

This commit replaces the usage of `sync.Map` with the new struct
`SyncMap` to explicitly express the type info used in the map.
This commit is contained in:
yyforyongyu 2022-10-11 22:26:28 +08:00
parent 30ebacb888
commit cf0e0820d6
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
3 changed files with 97 additions and 80 deletions

View File

@ -4,9 +4,9 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"math" "math"
"sync"
"time" "time"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnrpc/walletrpc" "github.com/lightningnetwork/lnd/lnrpc/walletrpc"
"github.com/lightningnetwork/lnd/lntemp/rpc" "github.com/lightningnetwork/lnd/lntemp/rpc"
@ -170,19 +170,18 @@ type State struct {
// openChans records each opened channel and how many times it has // openChans records each opened channel and how many times it has
// heard the announcements from its graph subscription. // heard the announcements from its graph subscription.
// openChans map[wire.OutPoint][]*OpenChannelUpdate openChans *SyncMap[wire.OutPoint, []*OpenChannelUpdate]
openChans *sync.Map
// closedChans records each closed channel and its close channel update // closedChans records each closed channel and its close channel update
// message received from its graph subscription. // message received from its graph subscription.
closedChans *sync.Map closedChans *SyncMap[wire.OutPoint, *lnrpc.ClosedChannelUpdate]
// numChanUpdates records the number of channel updates seen by each // numChanUpdates records the number of channel updates seen by each
// channel. // channel.
numChanUpdates *sync.Map numChanUpdates *SyncMap[wire.OutPoint, int]
// nodeUpdates records the node announcements seen by each node. // nodeUpdates records the node announcements seen by each node.
nodeUpdates *sync.Map nodeUpdates *SyncMap[string, []*lnrpc.NodeUpdate]
// policyUpdates defines a type to store channel policy updates. It has // policyUpdates defines a type to store channel policy updates. It has
// the format, // the format,
@ -197,19 +196,21 @@ type State struct {
// }, // },
// "chanPoint2": ... // "chanPoint2": ...
// } // }
policyUpdates *sync.Map policyUpdates *SyncMap[wire.OutPoint, PolicyUpdate]
} }
// newState initialize a new state with every field being set to its zero // newState initialize a new state with every field being set to its zero
// value. // value.
func newState(rpc *rpc.HarnessRPC) *State { func newState(rpc *rpc.HarnessRPC) *State {
return &State{ return &State{
rpc: rpc, rpc: rpc,
openChans: &sync.Map{}, openChans: &SyncMap[wire.OutPoint, []*OpenChannelUpdate]{},
closedChans: &sync.Map{}, closedChans: &SyncMap[
numChanUpdates: &sync.Map{}, wire.OutPoint, *lnrpc.ClosedChannelUpdate,
nodeUpdates: &sync.Map{}, ]{},
policyUpdates: &sync.Map{}, numChanUpdates: &SyncMap[wire.OutPoint, int]{},
nodeUpdates: &SyncMap[string, []*lnrpc.NodeUpdate]{},
policyUpdates: &SyncMap[wire.OutPoint, PolicyUpdate]{},
} }
} }

53
lntemp/node/sync_map.go Normal file
View File

@ -0,0 +1,53 @@
package node
import "sync"
// SyncMap wraps a sync.Map with type parameters such that it's easier to
// access the items stored in the map since no type assertion is needed. It
// also requires explicit type definition when declaring and initiating the
// variables, which helps us understanding what's stored in a given map.
type SyncMap[K comparable, V any] struct {
sync.Map
}
// Store puts an item in the map.
func (m *SyncMap[K, V]) Store(key K, value V) {
m.Map.Store(key, value)
}
// Load queries an item from the map using the specified key. If the item
// cannot be found, an empty value and false will be returned. If the stored
// item fails the type assertion, a nil value and false will be returned.
func (m *SyncMap[K, V]) Load(key K) (V, bool) {
result, ok := m.Map.Load(key)
if !ok {
return *new(V), false // nolint: gocritic
}
item, ok := result.(V)
return item, ok
}
// Delete removes an item from the map specified by the key.
func (m *SyncMap[K, V]) Delete(key K) {
m.Map.Delete(key)
}
// LoadAndDelete queries an item and deletes it from the map using the
// specified key.
func (m *SyncMap[K, V]) LoadAndDelete(key K) (V, bool) {
result, loaded := m.Map.LoadAndDelete(key)
if !loaded {
return *new(V), loaded // nolint: gocritic
}
item, ok := result.(V)
return item, ok
}
// Range iterates the map.
func (m *SyncMap[K, V]) Range(visitor func(K, V) bool) {
m.Map.Range(func(k any, v any) bool {
return visitor(k.(K), v.(V))
})
}

View File

@ -35,9 +35,9 @@ const (
DefaultTimeout = lntest.DefaultTimeout DefaultTimeout = lntest.DefaultTimeout
) )
// closeChanWatchRequest is a request to the lightningNetworkWatcher to be // chanWatchRequest is a request to the lightningNetworkWatcher to be notified
// notified once it's detected within the test Lightning Network, that a // once it's detected within the test Lightning Network, that a channel has
// channel has either been added or closed. // either been added or closed.
type chanWatchRequest struct { type chanWatchRequest struct {
chanPoint wire.OutPoint chanPoint wire.OutPoint
@ -68,11 +68,8 @@ type nodeWatcher struct {
// of edges seen for that channel within the network. When this number // of edges seen for that channel within the network. When this number
// reaches 2, then it means that both edge advertisements has // reaches 2, then it means that both edge advertisements has
// propagated through the network. // propagated through the network.
// openChanWatchers map[wire.OutPoint][]chan struct{} openChanWatchers *SyncMap[wire.OutPoint, []chan struct{}]
openChanWatchers *sync.Map closeChanWatchers *SyncMap[wire.OutPoint, []chan struct{}]
// closeChanWatchers map[wire.OutPoint][]chan struct{}
closeChanWatchers *sync.Map
wg sync.WaitGroup wg sync.WaitGroup
} }
@ -82,37 +79,28 @@ func newNodeWatcher(rpc *rpc.HarnessRPC, state *State) *nodeWatcher {
rpc: rpc, rpc: rpc,
state: state, state: state,
chanWatchRequests: make(chan *chanWatchRequest, 100), chanWatchRequests: make(chan *chanWatchRequest, 100),
openChanWatchers: &sync.Map{}, openChanWatchers: &SyncMap[wire.OutPoint, []chan struct{}]{},
closeChanWatchers: &sync.Map{}, closeChanWatchers: &SyncMap[wire.OutPoint, []chan struct{}]{},
} }
} }
// GetNumChannelUpdates reads the num of channel updates inside a lock and // GetNumChannelUpdates reads the num of channel updates inside a lock and
// returns the value. // returns the value.
func (nw *nodeWatcher) GetNumChannelUpdates(op wire.OutPoint) int { func (nw *nodeWatcher) GetNumChannelUpdates(op wire.OutPoint) int {
result, ok := nw.state.numChanUpdates.Load(op) result, _ := nw.state.numChanUpdates.Load(op)
if ok { return result
return result.(int)
}
return 0
} }
// GetPolicyUpdates returns the node's policyUpdates state. // GetPolicyUpdates returns the node's policyUpdates state.
func (nw *nodeWatcher) GetPolicyUpdates(op wire.OutPoint) PolicyUpdate { func (nw *nodeWatcher) GetPolicyUpdates(op wire.OutPoint) PolicyUpdate {
result, ok := nw.state.policyUpdates.Load(op) result, _ := nw.state.policyUpdates.Load(op)
if ok { return result
return result.(PolicyUpdate)
}
return nil
} }
// GetNodeUpdates reads the node updates inside a lock and returns the value. // GetNodeUpdates reads the node updates inside a lock and returns the value.
func (nw *nodeWatcher) GetNodeUpdates(pubkey string) []*lnrpc.NodeUpdate { func (nw *nodeWatcher) GetNodeUpdates(pubkey string) []*lnrpc.NodeUpdate {
result, ok := nw.state.nodeUpdates.Load(pubkey) result, _ := nw.state.nodeUpdates.Load(pubkey)
if ok { return result
return result.([]*lnrpc.NodeUpdate)
}
return nil
} }
// WaitForNumChannelUpdates will block until a given number of updates has been // WaitForNumChannelUpdates will block until a given number of updates has been
@ -170,7 +158,7 @@ func (nw *nodeWatcher) WaitForChannelOpen(chanPoint *lnrpc.ChannelPoint) error {
return nil return nil
case <-timer: case <-timer:
updates, err := syncMapToJSON(nw.state.openChans) updates, err := syncMapToJSON(&nw.state.openChans.Map)
if err != nil { if err != nil {
return err return err
} }
@ -204,7 +192,7 @@ func (nw *nodeWatcher) WaitForChannelClose(
"a closed channel in node's state:%s", op, "a closed channel in node's state:%s", op,
nw.state) nw.state)
} }
return closedChan.(*lnrpc.ClosedChannelUpdate), nil return closedChan, nil
case <-timer: case <-timer:
return nil, fmt.Errorf("channel:%s not closed before timeout: "+ return nil, fmt.Errorf("channel:%s not closed before timeout: "+
@ -329,23 +317,14 @@ func (nw *nodeWatcher) handleChannelEdgeUpdates(
// updateNodeStateNumChanUpdates updates the internal state of the node // updateNodeStateNumChanUpdates updates the internal state of the node
// regarding the num of channel update seen. // regarding the num of channel update seen.
func (nw *nodeWatcher) updateNodeStateNumChanUpdates(op wire.OutPoint) { func (nw *nodeWatcher) updateNodeStateNumChanUpdates(op wire.OutPoint) {
var oldNum int oldNum, _ := nw.state.numChanUpdates.Load(op)
result, ok := nw.state.numChanUpdates.Load(op)
if ok {
oldNum = result.(int)
}
nw.state.numChanUpdates.Store(op, oldNum+1) nw.state.numChanUpdates.Store(op, oldNum+1)
} }
// updateNodeStateNodeUpdates updates the internal state of the node regarding // updateNodeStateNodeUpdates updates the internal state of the node regarding
// the node updates seen. // the node updates seen.
func (nw *nodeWatcher) updateNodeStateNodeUpdates(update *lnrpc.NodeUpdate) { func (nw *nodeWatcher) updateNodeStateNodeUpdates(update *lnrpc.NodeUpdate) {
var oldUpdates []*lnrpc.NodeUpdate oldUpdates, _ := nw.state.nodeUpdates.Load(update.IdentityKey)
result, ok := nw.state.nodeUpdates.Load(update.IdentityKey)
if ok {
oldUpdates = result.([]*lnrpc.NodeUpdate)
}
nw.state.nodeUpdates.Store( nw.state.nodeUpdates.Store(
update.IdentityKey, append(oldUpdates, update), update.IdentityKey, append(oldUpdates, update),
) )
@ -357,11 +336,7 @@ func (nw *nodeWatcher) updateNodeStateOpenChannel(op wire.OutPoint,
newChan *lnrpc.ChannelEdgeUpdate) { newChan *lnrpc.ChannelEdgeUpdate) {
// Load the old updates the node has heard so far. // Load the old updates the node has heard so far.
updates := make([]*OpenChannelUpdate, 0) updates, _ := nw.state.openChans.Load(op)
result, ok := nw.state.openChans.Load(op)
if ok {
updates = result.([]*OpenChannelUpdate)
}
// Create a new update based on this newChan. // Create a new update based on this newChan.
newUpdate := &OpenChannelUpdate{ newUpdate := &OpenChannelUpdate{
@ -386,8 +361,8 @@ func (nw *nodeWatcher) updateNodeStateOpenChannel(op wire.OutPoint,
if !loaded { if !loaded {
return return
} }
events := watcherResult.([]chan struct{})
for _, eventChan := range events { for _, eventChan := range watcherResult {
close(eventChan) close(eventChan)
} }
} }
@ -399,10 +374,9 @@ func (nw *nodeWatcher) updateNodeStatePolicy(op wire.OutPoint,
// Init an empty policy map and overwrite it if the channel point can // Init an empty policy map and overwrite it if the channel point can
// be found in the node's policyUpdates. // be found in the node's policyUpdates.
policies := make(PolicyUpdate) policies, ok := nw.state.policyUpdates.Load(op)
result, ok := nw.state.policyUpdates.Load(op) if !ok {
if ok { policies = make(PolicyUpdate)
policies = result.(PolicyUpdate)
} }
node := newChan.AdvertisingNode node := newChan.AdvertisingNode
@ -426,21 +400,17 @@ func (nw *nodeWatcher) handleOpenChannelWatchRequest(req *chanWatchRequest) {
// If this is an open request, then it can be dispatched if the number // If this is an open request, then it can be dispatched if the number
// of edges seen for the channel is at least two. // of edges seen for the channel is at least two.
result, ok := nw.state.openChans.Load(targetChan) result, _ := nw.state.openChans.Load(targetChan)
if ok && len(result.([]*OpenChannelUpdate)) >= 2 { if len(result) >= 2 {
close(req.eventChan) close(req.eventChan)
return return
} }
// Otherwise, we'll add this to the list of open channel watchers for // Otherwise, we'll add this to the list of open channel watchers for
// this out point. // this out point.
oldWatchers := make([]chan struct{}, 0) watchers, _ := nw.openChanWatchers.Load(targetChan)
watchers, ok := nw.openChanWatchers.Load(targetChan)
if ok {
oldWatchers = watchers.([]chan struct{})
}
nw.openChanWatchers.Store( nw.openChanWatchers.Store(
targetChan, append(oldWatchers, req.eventChan), targetChan, append(watchers, req.eventChan),
) )
} }
@ -459,12 +429,11 @@ func (nw *nodeWatcher) handleClosedChannelUpdate(
// As the channel has been closed, we'll notify all register // As the channel has been closed, we'll notify all register
// watchers. // watchers.
result, loaded := nw.closeChanWatchers.LoadAndDelete(op) watchers, loaded := nw.closeChanWatchers.LoadAndDelete(op)
if !loaded { if !loaded {
continue continue
} }
watchers := result.([]chan struct{})
for _, eventChan := range watchers { for _, eventChan := range watchers {
close(eventChan) close(eventChan)
} }
@ -487,12 +456,7 @@ func (nw *nodeWatcher) handleCloseChannelWatchRequest(req *chanWatchRequest) {
// Otherwise, we'll add this to the list of close channel watchers for // Otherwise, we'll add this to the list of close channel watchers for
// this out point. // this out point.
oldWatchers := make([]chan struct{}, 0) oldWatchers, _ := nw.closeChanWatchers.Load(targetChan)
result, ok := nw.closeChanWatchers.Load(targetChan)
if ok {
oldWatchers = result.([]chan struct{})
}
nw.closeChanWatchers.Store( nw.closeChanWatchers.Store(
targetChan, append(oldWatchers, req.eventChan), targetChan, append(oldWatchers, req.eventChan),
) )
@ -509,9 +473,8 @@ func (nw *nodeWatcher) handlePolicyUpdateWatchRequest(req *chanWatchRequest) {
// Get a list of known policies for this chanPoint+advertisingNode // Get a list of known policies for this chanPoint+advertisingNode
// combination. Start searching in the node state first. // combination. Start searching in the node state first.
result, ok := nw.state.policyUpdates.Load(op) policyMap, ok := nw.state.policyUpdates.Load(op)
if ok { if ok {
policyMap := result.(PolicyUpdate)
policies, ok = policyMap[req.advertisingNode] policies, ok = policyMap[req.advertisingNode]
if !ok { if !ok {
return return