diff --git a/discovery/gossiper.go b/discovery/gossiper.go
index f2577fe7f..87ad835b9 100644
--- a/discovery/gossiper.go
+++ b/discovery/gossiper.go
@@ -435,7 +435,7 @@ type channelUpdateID struct {
 }
 
 // msgWithSenders is a wrapper struct around a message, and the set of peers
-// that oreignally sent ius this message. Using this struct, we can ensure that
+// that originally sent us this message. Using this struct, we can ensure that
 // we don't re-send a message to the peer that sent it to us in the first
 // place.
 type msgWithSenders struct {
@@ -450,7 +450,9 @@ type msgWithSenders struct {
 // batch. Internally, announcements are stored in three maps
 // (one each for channel announcements, channel updates, and node
 // announcements). These maps keep track of unique announcements and ensure no
-// announcements are duplicated.
+// announcements are duplicated. We keep the three message types separate, such
+// that we can send channel announcements first, then channel updates, and
+// finally node announcements when it's time to broadcast them.
 type deDupedAnnouncements struct {
 	// channelAnnouncements are identified by the short channel id field.
 	channelAnnouncements map[lnwire.ShortChannelID]msgWithSenders
@@ -527,12 +529,31 @@ func (d *deDupedAnnouncements) addMsg(message networkMsg) {
 			msg.Flags,
 		}
 
+		oldTimestamp := uint32(0)
 		mws, ok := d.channelUpdates[deDupKey]
-		if !ok {
+		if ok {
+			// If we have already seen this message, record its
+			// timestamp.
+			oldTimestamp = mws.msg.(*lnwire.ChannelUpdate).Timestamp
+		}
+
+		// If we already had this message with a strictly newer
+		// timestamp, then we'll just discard the message we got.
+		if oldTimestamp > msg.Timestamp {
+			return
+		}
+
+		// If the message we just got is newer than what we have
+		// previously seen, or this is the first time we see it, then
+		// we'll add it to our map of announcements.
+		if oldTimestamp < msg.Timestamp {
 			mws = msgWithSenders{
 				msg:     msg,
 				senders: make(map[routing.Vertex]struct{}),
 			}
+
+			// We'll mark the sender of the message in the
+			// senders map.
 			mws.senders[sender] = struct{}{}
 
 			d.channelUpdates[deDupKey] = mws
@@ -540,6 +561,10 @@ func (d *deDupedAnnouncements) addMsg(message networkMsg) {
 			return
 		}
 
+		// Lastly, if we had seen this exact message before, with the
+		// same timestamp, we'll add the sender to the map of senders,
+		// such that we can skip sending this message back in the next
+		// batch.
 		mws.msg = msg
 		mws.senders[sender] = struct{}{}
 		d.channelUpdates[deDupKey] = mws
@@ -550,12 +575,26 @@ func (d *deDupedAnnouncements) addMsg(message networkMsg) {
 		sender := routing.NewVertex(message.peer)
 		deDupKey := routing.NewVertex(msg.NodeID)
 
+		// We do the same for node announcements as we did for channel
+		// updates, as they also carry a timestamp.
+		oldTimestamp := uint32(0)
 		mws, ok := d.nodeAnnouncements[deDupKey]
-		if !ok {
+		if ok {
+			oldTimestamp = mws.msg.(*lnwire.NodeAnnouncement).Timestamp
+		}
+
+		// Discard the message if it's old.
+		if oldTimestamp > msg.Timestamp {
+			return
+		}
+
+		// Replace if it's newer.
+		if oldTimestamp < msg.Timestamp {
 			mws = msgWithSenders{
 				msg:     msg,
 				senders: make(map[routing.Vertex]struct{}),
 			}
+
 			mws.senders[sender] = struct{}{}
 			d.nodeAnnouncements[deDupKey] = mws
 
@@ -563,6 +602,7 @@ func (d *deDupedAnnouncements) addMsg(message networkMsg) {
 			return
 		}
 
+		// Same timestamp as stored, so just record the sender.
 		mws.msg = msg
 		mws.senders[sender] = struct{}{}
 		d.nodeAnnouncements[deDupKey] = mws
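The de-duplication rule introduced above is the same for channel updates and node announcements: keep only the copy with the highest timestamp, and on a timestamp tie just record the additional sender so the message is not echoed back to it later. The following is a condensed, self-contained sketch of that decision logic; the entry and upsert names are illustrative only and do not appear in this patch.

package main

import "fmt"

// entry mirrors msgWithSenders in spirit: one message (represented here
// only by its timestamp) plus the peers that already sent it to us.
type entry struct {
	timestamp uint32
	senders   map[string]struct{}
}

// upsert applies the de-duplication rule used by addMsg above: drop
// older messages, replace the entry on a newer timestamp, and merge the
// sender on a timestamp tie.
func upsert(m map[string]*entry, key string, ts uint32, sender string) {
	old, ok := m[key]
	if ok && old.timestamp > ts {
		// Strictly older than what we have stored: discard it.
		return
	}
	if !ok || old.timestamp < ts {
		// First time seen, or strictly newer: replace the entry
		// and start a fresh senders set.
		m[key] = &entry{
			timestamp: ts,
			senders:   map[string]struct{}{sender: {}},
		}
		return
	}
	// Same timestamp: only record the extra sender.
	old.senders[sender] = struct{}{}
}

func main() {
	m := make(map[string]*entry)
	upsert(m, "chan-1", 100, "alice")
	upsert(m, "chan-1", 100, "bob")  // tie: bob is merged into senders
	upsert(m, "chan-1", 99, "carol") // older: ignored
	upsert(m, "chan-1", 101, "dave") // newer: replaces, senders reset
	fmt.Println(m["chan-1"].timestamp, len(m["chan-1"].senders)) // 101 1
}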
diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go
index 76e62dd27..ebb1d58e8 100644
--- a/discovery/gossiper_test.go
+++ b/discovery/gossiper_test.go
@@ -274,13 +274,14 @@ type annBatch struct {
 func createAnnouncements(blockHeight uint32) (*annBatch, error) {
 	var err error
 	var batch annBatch
+	timestamp := uint32(123456)
 
-	batch.nodeAnn1, err = createNodeAnnouncement(nodeKeyPriv1)
+	batch.nodeAnn1, err = createNodeAnnouncement(nodeKeyPriv1, timestamp)
 	if err != nil {
 		return nil, err
 	}
 
-	batch.nodeAnn2, err = createNodeAnnouncement(nodeKeyPriv2)
+	batch.nodeAnn2, err = createNodeAnnouncement(nodeKeyPriv2, timestamp)
 	if err != nil {
 		return nil, err
 	}
@@ -310,14 +311,14 @@ func createAnnouncements(blockHeight uint32) (*annBatch, error) {
 	batch.localChanAnn.NodeSig2 = nil
 
 	batch.chanUpdAnn1, err = createUpdateAnnouncement(
-		blockHeight, 0, nodeKeyPriv1,
+		blockHeight, 0, nodeKeyPriv1, timestamp,
 	)
 	if err != nil {
 		return nil, err
 	}
 
 	batch.chanUpdAnn2, err = createUpdateAnnouncement(
-		blockHeight, 1, nodeKeyPriv2,
+		blockHeight, 1, nodeKeyPriv2, timestamp,
 	)
 	if err != nil {
 		return nil, err
 	}
@@ -327,7 +328,8 @@
 }
 
-func createNodeAnnouncement(priv *btcec.PrivateKey) (*lnwire.NodeAnnouncement,
+func createNodeAnnouncement(priv *btcec.PrivateKey,
+	timestamp uint32) (*lnwire.NodeAnnouncement,
 	error) {
 
 	var err error
@@ -338,7 +340,7 @@ func createNodeAnnouncement(priv *btcec.PrivateKey) (*lnwire.NodeAnnouncement,
 	}
 
 	a := &lnwire.NodeAnnouncement{
-		Timestamp: uint32(prand.Int31()),
+		Timestamp: timestamp,
 		Addresses: testAddrs,
 		NodeID:    priv.PubKey(),
 		Alias:     alias,
@@ -355,7 +357,8 @@ func createNodeAnnouncement(priv *btcec.PrivateKey) (*lnwire.NodeAnnouncement,
 }
 
 func createUpdateAnnouncement(blockHeight uint32, flags lnwire.ChanUpdateFlag,
-	nodeKey *btcec.PrivateKey) (*lnwire.ChannelUpdate, error) {
+	nodeKey *btcec.PrivateKey, timestamp uint32) (*lnwire.ChannelUpdate,
+	error) {
 
 	var err error
@@ -363,7 +366,7 @@ func createUpdateAnnouncement(blockHeight uint32, flags lnwire.ChanUpdateFlag,
 		ShortChannelID: lnwire.ShortChannelID{
 			BlockHeight: blockHeight,
 		},
-		Timestamp:       uint32(prand.Int31()),
+		Timestamp:       timestamp,
 		TimeLockDelta:   uint16(prand.Int63()),
 		Flags:           flags,
 		HtlcMinimumMsat: lnwire.MilliSatoshi(prand.Int63()),
@@ -494,6 +497,8 @@ func createTestCtx(startHeight uint32) (*testCtx, func(), error) {
 func TestProcessAnnouncement(t *testing.T) {
 	t.Parallel()
 
+	timestamp := uint32(123456)
+
 	ctx, cleanup, err := createTestCtx(0)
 	if err != nil {
 		t.Fatalf("can't create context: %v", err)
 	}
@@ -511,7 +516,7 @@ func TestProcessAnnouncement(t *testing.T) {
 	// gossiper service, check that valid announcement have been
 	// propagated farther into the lightning network, and check that we
 	// added new node into router.
-	na, err := createNodeAnnouncement(nodeKeyPriv1)
+	na, err := createNodeAnnouncement(nodeKeyPriv1, timestamp)
 	if err != nil {
 		t.Fatalf("can't create node announcement: %v", err)
 	}
@@ -567,7 +572,7 @@ func TestProcessAnnouncement(t *testing.T) {
 	// Pretending that we received valid channel policy update from remote
 	// side, and check that we broadcasted it to the other network, and
 	// added updates to the router.
-	ua, err := createUpdateAnnouncement(0, 0, nodeKeyPriv1)
+	ua, err := createUpdateAnnouncement(0, 0, nodeKeyPriv1, timestamp)
 	if err != nil {
 		t.Fatalf("can't create update announcement: %v", err)
 	}
@@ -599,13 +604,15 @@ func TestProcessAnnouncement(t *testing.T) {
 func TestPrematureAnnouncement(t *testing.T) {
 	t.Parallel()
 
+	timestamp := uint32(123456)
+
 	ctx, cleanup, err := createTestCtx(0)
 	if err != nil {
 		t.Fatalf("can't create context: %v", err)
 	}
 	defer cleanup()
 
-	na, err := createNodeAnnouncement(nodeKeyPriv1)
+	na, err := createNodeAnnouncement(nodeKeyPriv1, timestamp)
 	if err != nil {
 		t.Fatalf("can't create node announcement: %v", err)
 	}
@@ -633,7 +640,7 @@ func TestPrematureAnnouncement(t *testing.T) {
 	// remote side, but block height of this announcement is greater than
 	// highest know to us, for that reason it should be added to the
 	// repeat/premature batch.
-	ua, err := createUpdateAnnouncement(1, 0, nodeKeyPriv1)
+	ua, err := createUpdateAnnouncement(1, 0, nodeKeyPriv1, timestamp)
 	if err != nil {
 		t.Fatalf("can't create update announcement: %v", err)
 	}
@@ -1568,6 +1575,7 @@ func TestSignatureAnnouncementFullProofWhenRemoteProof(t *testing.T) {
 func TestDeDuplicatedAnnouncements(t *testing.T) {
 	t.Parallel()
 
+	timestamp := uint32(123456)
 	announcements := deDupedAnnouncements{}
 	announcements.Reset()
 
@@ -1610,7 +1618,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
 	// Next, we'll ensure that channel update announcements are properly
 	// stored and de-duplicated. We do this by creating two updates
 	// announcements with the same short ID and flag.
-	ua, err := createUpdateAnnouncement(0, 0, nodeKeyPriv1)
+	ua, err := createUpdateAnnouncement(0, 0, nodeKeyPriv1, timestamp)
 	if err != nil {
 		t.Fatalf("can't create update announcement: %v", err)
 	}
@@ -1621,7 +1629,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
 
 	// Adding the very same announcement shouldn't cause an increase in the
 	// number of ChannelUpdate announcements stored.
-	ua2, err := createUpdateAnnouncement(0, 0, nodeKeyPriv1)
+	ua2, err := createUpdateAnnouncement(0, 0, nodeKeyPriv1, timestamp)
 	if err != nil {
 		t.Fatalf("can't create update announcement: %v", err)
 	}
@@ -1630,9 +1638,51 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
 		t.Fatal("channel update not replaced in batch")
 	}
 
+	// Adding an announcement with a later timestamp should replace the
+	// stored one.
+	ua3, err := createUpdateAnnouncement(0, 0, nodeKeyPriv1, timestamp+1)
+	if err != nil {
+		t.Fatalf("can't create update announcement: %v", err)
+	}
+	announcements.AddMsgs(networkMsg{msg: ua3, peer: bitcoinKeyPub2})
+	if len(announcements.channelUpdates) != 1 {
+		t.Fatal("channel update not replaced in batch")
+	}
+
+	assertChannelUpdate := func(channelUpdate *lnwire.ChannelUpdate) {
+		channelKey := channelUpdateID{
+			ua3.ShortChannelID,
+			ua3.Flags,
+		}
+
+		mws, ok := announcements.channelUpdates[channelKey]
+		if !ok {
+			t.Fatal("channel update not in batch")
+		}
+		if mws.msg != channelUpdate {
+			t.Fatalf("expected channel update %v, got %v",
+				channelUpdate, mws.msg)
+		}
+	}
+
+	// Check that ua3 is the currently stored channel update.
+	assertChannelUpdate(ua3)
+
+	// Adding a channel update with an earlier timestamp should NOT
+	// replace the one stored.
+	ua4, err := createUpdateAnnouncement(0, 0, nodeKeyPriv1, timestamp)
+	if err != nil {
+		t.Fatalf("can't create update announcement: %v", err)
+	}
+	announcements.AddMsgs(networkMsg{msg: ua4, peer: bitcoinKeyPub2})
+	if len(announcements.channelUpdates) != 1 {
+		t.Fatal("channel update not in batch")
+	}
+	assertChannelUpdate(ua3)
+
 	// Next well ensure that node announcements are properly de-duplicated.
 	// We'll first add a single instance with a node's private key.
-	na, err := createNodeAnnouncement(nodeKeyPriv1)
+	na, err := createNodeAnnouncement(nodeKeyPriv1, timestamp)
 	if err != nil {
 		t.Fatalf("can't create node announcement: %v", err)
 	}
@@ -1642,7 +1692,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
 	}
 
 	// We'll now add another node to the batch.
-	na2, err := createNodeAnnouncement(nodeKeyPriv2)
+	na2, err := createNodeAnnouncement(nodeKeyPriv2, timestamp)
 	if err != nil {
 		t.Fatalf("can't create node announcement: %v", err)
 	}
@@ -1653,7 +1703,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
 
 	// Adding a new instance of the _same_ node shouldn't increase the size
 	// of the node ann batch.
-	na3, err := createNodeAnnouncement(nodeKeyPriv2)
+	na3, err := createNodeAnnouncement(nodeKeyPriv2, timestamp)
 	if err != nil {
 		t.Fatalf("can't create node announcement: %v", err)
 	}
@@ -1665,7 +1715,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
 	// Ensure that node announcement with different pointer to same public
 	// key is still de-duplicated.
 	newNodeKeyPointer := nodeKeyPriv2
-	na4, err := createNodeAnnouncement(newNodeKeyPointer)
+	na4, err := createNodeAnnouncement(newNodeKeyPointer, timestamp)
 	if err != nil {
 		t.Fatalf("can't create node announcement: %v", err)
 	}
@@ -1674,6 +1724,26 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
 		t.Fatal("second node announcement not replaced again in batch")
 	}
 
+	// Ensure that node announcement with increased timestamp replaces
+	// what is currently stored.
+	na5, err := createNodeAnnouncement(nodeKeyPriv2, timestamp+1)
+	if err != nil {
+		t.Fatalf("can't create node announcement: %v", err)
+	}
+	announcements.AddMsgs(networkMsg{msg: na5, peer: bitcoinKeyPub2})
+	if len(announcements.nodeAnnouncements) != 2 {
+		t.Fatal("node announcement not replaced in batch")
+	}
+	nodeID := routing.NewVertex(nodeKeyPriv2.PubKey())
+	stored, ok := announcements.nodeAnnouncements[nodeID]
+	if !ok {
+		t.Fatalf("node announcement not found in batch")
+	}
+	if stored.msg != na5 {
+		t.Fatalf("expected de-duped node announcement to be %v, got %v",
+			na5, stored.msg)
+	}
+
 	// Ensure that announcement batch delivers channel announcements,
 	// channel updates, and node announcements in proper order.
 	batch := announcements.Emit()
@@ -1686,7 +1756,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
 			"expected %v", spew.Sdump(batch[0].msg), spew.Sdump(ca2))
 	}
 
-	if !reflect.DeepEqual(batch[1].msg, ua2) {
+	if !reflect.DeepEqual(batch[1].msg, ua3) {
 		t.Fatalf("channel update not next in batch: got %v, "+
 			"expected %v", spew.Sdump(batch[1].msg), spew.Sdump(ua2))
 	}
@@ -1699,10 +1769,10 @@ func TestDeDuplicatedAnnouncements(t *testing.T) {
 			"got %v, expected %v", batch[2].msg, na)
 	}
 
-	if !reflect.DeepEqual(batch[2].msg, na4) && !reflect.DeepEqual(batch[3].msg, na4) {
+	if !reflect.DeepEqual(batch[2].msg, na5) && !reflect.DeepEqual(batch[3].msg, na5) {
 		t.Fatalf("second node announcement not in last part of batch: "+
 			"got %v, expected %v", batch[3].msg,
-			na2)
+			na5)
 	}
 
 	// Ensure that after reset, storage of each announcement type
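One detail worth calling out from the assertChannelUpdate helper above: channel updates are de-duplicated per direction, since the map key combines the short channel ID with the flags field. A newer timestamp therefore only displaces the stored update for the same direction. A minimal illustration of that keying, using simplified stand-in types rather than code from the patch:

package main

import "fmt"

// updateKey is a simplified stand-in for channelUpdateID: the short
// channel ID plus the flags that select the channel direction.
type updateKey struct {
	shortChanID uint64
	flags       uint8
}

func main() {
	// key -> latest timestamp seen for that (channel, direction).
	updates := make(map[updateKey]uint32)

	// Two updates for the same channel but opposite directions get
	// separate entries.
	updates[updateKey{shortChanID: 42, flags: 0}] = 100
	updates[updateKey{shortChanID: 42, flags: 1}] = 100

	// A newer update for direction 0 only replaces that one entry.
	key := updateKey{shortChanID: 42, flags: 0}
	if ts := updates[key]; 101 > ts {
		updates[key] = 101
	}

	fmt.Println(len(updates), updates[key]) // 2 101
}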
diff --git a/routing/router.go b/routing/router.go
index 8ee2c09c4..07041845d 100644
--- a/routing/router.go
+++ b/routing/router.go
@@ -150,6 +150,97 @@ func newRouteTuple(amt lnwire.MilliSatoshi, dest []byte) routeTuple {
 	return r
 }
 
+// cntMutex is a struct that wraps a counter and a mutex, and is used
+// to keep track of the number of goroutines waiting for access to the
+// mutex, such that we can forget about it when the counter is zero.
+type cntMutex struct {
+	cnt int
+	sync.Mutex
+}
+
+// mutexForID is a struct that keeps track of a set of mutexes, each
+// keyed by an ID. It is used to make sure only one goroutine holds
+// the mutex for a given ID at a time. Here it is currently used to
+// make sure we only process one ChannelEdgePolicy per channelID at a
+// given time.
+type mutexForID struct {
+	// mutexes is a map of IDs to a cntMutex. The cntMutex for
+	// a given ID will hold the mutex to be used by all
+	// callers requesting access for the ID, in addition to
+	// the count of callers.
+	mutexes map[uint64]*cntMutex
+
+	// mapMtx is used to synchronize concurrent access
+	// to the mutexes map.
+	mapMtx sync.Mutex
+}
+
+func newMutexForID() *mutexForID {
+	return &mutexForID{
+		mutexes: make(map[uint64]*cntMutex),
+	}
+}
+
+// Lock locks the mutex identified by the given ID. If the mutex is
+// already locked for this ID, Lock blocks until it is available.
+func (c *mutexForID) Lock(id uint64) {
+	c.mapMtx.Lock()
+	mtx, ok := c.mutexes[id]
+	if ok {
+		// If the mutex already existed in the map, we
+		// increment its counter, to indicate that there
+		// is now one more goroutine waiting for it.
+		mtx.cnt++
+	} else {
+		// If it was not in the map, it means no other
+		// goroutine has locked the mutex for this ID,
+		// and we can create a new mutex with count 1
+		// and add it to the map.
+		mtx = &cntMutex{
+			cnt: 1,
+		}
+		c.mutexes[id] = mtx
+	}
+	c.mapMtx.Unlock()
+
+	// Acquire the mutex for this ID.
+	mtx.Lock()
+}
+
+// Unlock unlocks the mutex identified by the given ID. It is a
+// run-time error if the mutex is not locked on entry to Unlock.
+func (c *mutexForID) Unlock(id uint64) {
+	// Since we are done with all the work for this
+	// update, we update the map to reflect that.
+	c.mapMtx.Lock()
+
+	mtx, ok := c.mutexes[id]
+	if !ok {
+		// The mutex not existing in the map means
+		// an unlock for an ID not currently locked
+		// was attempted.
+		panic(fmt.Sprintf("double unlock for id %v",
+			id))
+	}
+
+	// Decrement the counter. If the count goes to
+	// zero, it means this caller was the last one
+	// to wait for the mutex, and we can delete it
+	// from the map. We can do this safely since we
+	// are under the mapMtx, meaning that all other
+	// goroutines waiting for the mutex already
+	// have incremented the counter, or will create
+	// a new mutex when they get the mapMtx.
+	mtx.cnt--
+	if mtx.cnt == 0 {
+		delete(c.mutexes, id)
+	}
+	c.mapMtx.Unlock()
+
+	// Unlock the mutex for this ID.
+	mtx.Unlock()
+}
+
 // ChannelRouter is the layer 3 router within the Lightning stack. Below the
 // ChannelRouter is the HtlcSwitch, and below that is the Bitcoin blockchain
 // itself. The primary role of the ChannelRouter is to respond to queries for
@@ -219,6 +310,11 @@ type ChannelRouter struct {
 	// gained to the next execution.
 	missionControl *missionControl
 
+	// channelEdgeMtx is a mutex we use to make sure we process only one
+	// ChannelEdgePolicy at a time for a given channelID, to ensure
+	// consistency between the various database accesses.
+	channelEdgeMtx *mutexForID
+
 	sync.RWMutex
 
 	quit chan struct{}
@@ -247,6 +343,7 @@ func New(cfg Config) (*ChannelRouter, error) {
 		topologyClients:   make(map[uint64]*topologyClient),
 		ntfnClientUpdates: make(chan *topologyClientUpdate),
 		missionControl:    newMissionControl(cfg.Graph, selfNode),
+		channelEdgeMtx:    newMutexForID(),
 		selfNode:          selfNode,
 		routeCache:        make(map[routeTuple][]*Route),
 		quit:              make(chan struct{}),
@@ -942,6 +1039,13 @@ func (r *ChannelRouter) processUpdate(msg interface{}) error {
 	case *channeldb.ChannelEdgePolicy:
 		channelID := lnwire.NewShortChanIDFromInt(msg.ChannelID)
+
+		// We make sure to hold the mutex for this channel ID,
+		// such that no other goroutine is concurrently doing
+		// database accesses for the same channel ID.
+		r.channelEdgeMtx.Lock(msg.ChannelID)
+		defer r.channelEdgeMtx.Unlock(msg.ChannelID)
+
 		edge1Timestamp, edge2Timestamp, exists, err := r.cfg.Graph.HasChannelEdge(
 			msg.ChannelID,
 		)
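The per-ID locking added to the router follows a small, reusable pattern: look up (or create) a counted mutex under mapMtx, acquire it, and drop the map entry once the last holder unlocks. Below is a self-contained sketch of the same pattern, trimmed for illustration; idMutex and its lower-case methods are hypothetical names, not the router's unexported API.

package main

import (
	"fmt"
	"sync"
)

// cntMutex pairs a mutex with a count of the goroutines interested in
// it, so the owning map entry can be dropped once nobody needs it.
type cntMutex struct {
	cnt int
	sync.Mutex
}

// idMutex mirrors the mutexForID pattern: one mutex per ID, created on
// demand and forgotten again when the last holder unlocks it.
type idMutex struct {
	mapMtx  sync.Mutex
	mutexes map[uint64]*cntMutex
}

func newIDMutex() *idMutex {
	return &idMutex{mutexes: make(map[uint64]*cntMutex)}
}

func (m *idMutex) lock(id uint64) {
	m.mapMtx.Lock()
	mtx, ok := m.mutexes[id]
	if !ok {
		mtx = &cntMutex{}
		m.mutexes[id] = mtx
	}
	mtx.cnt++
	m.mapMtx.Unlock()

	mtx.Lock()
}

func (m *idMutex) unlock(id uint64) {
	m.mapMtx.Lock()
	mtx := m.mutexes[id] // the real code panics if the entry is missing
	mtx.cnt--
	if mtx.cnt == 0 {
		delete(m.mutexes, id)
	}
	m.mapMtx.Unlock()

	mtx.Unlock()
}

func main() {
	mu := newIDMutex()
	var wg sync.WaitGroup

	// Updates for the same channel ID run strictly one at a time;
	// updates for different IDs may run concurrently.
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			mu.lock(42)
			defer mu.unlock(42)
			fmt.Println("processing update", n, "for channel 42")
		}(i)
	}
	wg.Wait()
}

With this in place, concurrent ChannelEdgePolicy updates for the same channel ID are processed one at a time, while updates for different channels do not block each other.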