diff --git a/routing/router.go b/routing/router.go index 8ee2c09c4..07041845d 100644 --- a/routing/router.go +++ b/routing/router.go @@ -150,6 +150,97 @@ func newRouteTuple(amt lnwire.MilliSatoshi, dest []byte) routeTuple { return r } +// cntMutex is a struct that wraps a counter and a mutex, and is used +// to keep track of the number of goroutines waiting for access to the +// mutex, such that we can forget about it when the counter is zero. +type cntMutex struct { + cnt int + sync.Mutex +} + +// mutexForID is a struct that keeps track of a set of mutexes with +// a given ID. It can be used for making sure only one goroutine +// gets given the mutex per ID. Here it is currently used to making +// sure we only process one ChannelEdgePolicy per channelID at a +// given time. +type mutexForID struct { + // mutexes is a map of IDs to a cntMutex. The cntMutex for + // a given ID will hold the mutex to be used by all + // callers requesting access for the ID, in addition to + // the count of callers. + mutexes map[uint64]*cntMutex + + // mapMtx is used to give synchronize concurrent access + // to the mutexes map. + mapMtx sync.Mutex +} + +func newMutexForID() *mutexForID { + return &mutexForID{ + mutexes: make(map[uint64]*cntMutex), + } +} + +// Lock locks the mutex by the given ID. If the mutex is already +// locked by this ID, Lock blocks until the mutex is available. +func (c *mutexForID) Lock(id uint64) { + c.mapMtx.Lock() + mtx, ok := c.mutexes[id] + if ok { + // If the mutex already existed in the map, we + // increment its counter, to indicate that there + // now is one more goroutine waiting for it. + mtx.cnt++ + } else { + // If it was not in the map, it means no other + // goroutine has locked the mutex for this ID, + // and we can create a new mutex with count 1 + // and add it to the map. + mtx = &cntMutex{ + cnt: 1, + } + c.mutexes[id] = mtx + } + c.mapMtx.Unlock() + + // Acquire the mutex for this ID. + mtx.Lock() +} + +// Unlock unlocks the mutex by the given ID. It is a run-time +// error if the mutex is not locked by the ID on entry to Unlock. +func (c *mutexForID) Unlock(id uint64) { + // Since we are done with all the work for this + // update, we update the map to reflect that. + c.mapMtx.Lock() + + mtx, ok := c.mutexes[id] + if !ok { + // The mutex not existing in the map means + // an unlock for an ID not currently locked + // was attempted. + panic(fmt.Sprintf("double unlock for id %v", + id)) + } + + // Decrement the counter. If the count goes to + // zero, it means this caller was the last one + // to wait for the mutex, and we can delete it + // from the map. We can do this safely since we + // are under the mapMtx, meaning that all other + // goroutines waiting for the mutex already + // have incremented it, or will create a new + // mutex when they get the mapMtx. + mtx.cnt-- + if mtx.cnt == 0 { + delete(c.mutexes, id) + } + c.mapMtx.Unlock() + + // Unlock the mutex for this ID. + mtx.Unlock() +} + // ChannelRouter is the layer 3 router within the Lightning stack. Below the // ChannelRouter is the HtlcSwitch, and below that is the Bitcoin blockchain // itself. The primary role of the ChannelRouter is to respond to queries for @@ -219,6 +310,11 @@ type ChannelRouter struct { // gained to the next execution. missionControl *missionControl + // channelEdgeMtx is a mutex we use to make sure we process only one + // ChannelEdgePolicy at a time for a given channelID, to ensure + // consistency between the various database accesses. + channelEdgeMtx *mutexForID + sync.RWMutex quit chan struct{} @@ -247,6 +343,7 @@ func New(cfg Config) (*ChannelRouter, error) { topologyClients: make(map[uint64]*topologyClient), ntfnClientUpdates: make(chan *topologyClientUpdate), missionControl: newMissionControl(cfg.Graph, selfNode), + channelEdgeMtx: newMutexForID(), selfNode: selfNode, routeCache: make(map[routeTuple][]*Route), quit: make(chan struct{}), @@ -942,6 +1039,13 @@ func (r *ChannelRouter) processUpdate(msg interface{}) error { case *channeldb.ChannelEdgePolicy: channelID := lnwire.NewShortChanIDFromInt(msg.ChannelID) + + // We make sure to hold the mutex for this channel ID, + // such that no other goroutine is concurrently doing + // database accesses for the same channel ID. + r.channelEdgeMtx.Lock(msg.ChannelID) + defer r.channelEdgeMtx.Unlock(msg.ChannelID) + edge1Timestamp, edge2Timestamp, exists, err := r.cfg.Graph.HasChannelEdge( msg.ChannelID, )