diff --git a/discovery/sync_manager.go b/discovery/sync_manager.go
index 378085e49..e5cace1a4 100644
--- a/discovery/sync_manager.go
+++ b/discovery/sync_manager.go
@@ -200,7 +200,7 @@ type SyncManager struct {
 	gossipFilterSema chan struct{}
 
 	// rateLimiter dictates the frequency with which we will reply to gossip
-	// queries from a peer. This is used to delay responses to peers to
+	// queries from all peers. This is used to delay responses to peers to
 	// prevent DOS vulnerabilities if they are spamming with an unreasonable
 	// number of queries.
 	rateLimiter *rate.Limiter
@@ -554,8 +554,8 @@ func (m *SyncManager) isPinnedSyncer(s *GossipSyncer) bool {
 
 // deriveRateLimitReservation will take the current message and derive a
 // reservation that can be used to wait on the rate limiter.
-func (m *SyncManager) deriveRateLimitReservation(msg lnwire.Message,
-) (*rate.Reservation, error) {
+func deriveRateLimitReservation(rl *rate.Limiter,
+	msg lnwire.Message) (*rate.Reservation, error) {
 
 	var (
 		msgSize uint32
@@ -575,12 +575,12 @@ func (m *SyncManager) deriveRateLimitReservation(msg lnwire.Message,
 		msgSize = assumedMsgSize
 	}
 
-	return m.rateLimiter.ReserveN(time.Now(), int(msgSize)), nil
+	return rl.ReserveN(time.Now(), int(msgSize)), nil
 }
 
 // waitMsgDelay takes a delay, and waits until it has finished.
-func (m *SyncManager) waitMsgDelay(ctx context.Context, peerPub [33]byte,
-	limitReservation *rate.Reservation) error {
+func waitMsgDelay(ctx context.Context, peerPub [33]byte,
+	limitReservation *rate.Reservation, quit <-chan struct{}) error {
 
 	// If we've already replied a handful of times, we will start to delay
 	// responses back to the remote peer. This can help prevent DOS attacks
@@ -602,7 +602,7 @@ func (m *SyncManager) waitMsgDelay(ctx context.Context, peerPub [33]byte,
 
 			return ErrGossipSyncerExiting
 
-		case <-m.quit:
+		case <-quit:
 			limitReservation.Cancel()
 
 			return ErrGossipSyncerExiting
@@ -614,15 +614,15 @@ func (m *SyncManager) waitMsgDelay(ctx context.Context, peerPub [33]byte,
 
 // maybeRateLimitMsg takes a message, and may wait a period of time to rate
 // limit the msg.
-func (m *SyncManager) maybeRateLimitMsg(ctx context.Context, peerPub [33]byte,
-	msg lnwire.Message) error {
+func maybeRateLimitMsg(ctx context.Context, rl *rate.Limiter, peerPub [33]byte,
+	msg lnwire.Message, quit <-chan struct{}) error {
 
-	delay, err := m.deriveRateLimitReservation(msg)
+	delay, err := deriveRateLimitReservation(rl, msg)
 	if err != nil {
 		return nil
 	}
 
-	return m.waitMsgDelay(ctx, peerPub, delay)
+	return waitMsgDelay(ctx, peerPub, delay, quit)
 }
 
 // sendMessages sends a set of messages to the remote peer.
@@ -630,9 +630,13 @@ func (m *SyncManager) sendMessages(ctx context.Context, sync bool,
 	peer lnpeer.Peer, nodeID route.Vertex, msgs ...lnwire.Message) error {
 
 	for _, msg := range msgs {
-		if err := m.maybeRateLimitMsg(ctx, nodeID, msg); err != nil {
+		err := maybeRateLimitMsg(
+			ctx, m.rateLimiter, nodeID, msg, m.quit,
+		)
+		if err != nil {
 			return err
 		}
+
 		if err := peer.SendMessageLazy(sync, msg); err != nil {
 			return err
 		}
diff --git a/discovery/sync_manager_test.go b/discovery/sync_manager_test.go
index 4aff5b631..f78e0ce6e 100644
--- a/discovery/sync_manager_test.go
+++ b/discovery/sync_manager_test.go
@@ -748,7 +748,7 @@ func TestDeriveRateLimitReservation(t *testing.T) {
 	}
 
 	// First message should have no delay as it fits within burst.
-	delay1, err := sm.deriveRateLimitReservation(msg)
+	delay1, err := deriveRateLimitReservation(sm.rateLimiter, msg)
 	require.NoError(t, err)
 	require.Equal(
 		t, time.Duration(0), delay1.Delay(), "first message "+
@@ -757,7 +757,7 @@ func TestDeriveRateLimitReservation(t *testing.T) {
 
 	// Second message should have a non-zero delay as the token
 	// bucket is now depleted.
-	delay2, err := sm.deriveRateLimitReservation(msg)
+	delay2, err := deriveRateLimitReservation(sm.rateLimiter, msg)
 	require.NoError(t, err)
 	require.True(
 		t, delay2.Delay() > 0, "second message should have "+
@@ -766,7 +766,7 @@ func TestDeriveRateLimitReservation(t *testing.T) {
 
 	// Third message should have an even longer delay since the
 	// token bucket is still refilling at a constant rate.
-	delay3, err := sm.deriveRateLimitReservation(msg)
+	delay3, err := deriveRateLimitReservation(sm.rateLimiter, msg)
 	require.NoError(t, err)
 	require.True(t, delay3.Delay() > delay2.Delay(), "third "+
 		"message should have longer delay than second: %s > %s",
@@ -798,7 +798,7 @@ func TestDeriveRateLimitReservation(t *testing.T) {
 
 	// The error should propagate through
 	// deriveRateLimitReservation.
-	_, err := sm.deriveRateLimitReservation(msg)
+	_, err := deriveRateLimitReservation(sm.rateLimiter, msg)
 	require.Error(t, err)
 	require.Equal(
 		t, errMsg, err, "Error should be propagated unchanged",
@@ -815,7 +815,7 @@ func TestDeriveRateLimitReservation(t *testing.T) {
 	initialMsg := &TestSizeableMessage{
 		size: uint32(bytesBurst),
 	}
-	_, err := sm.deriveRateLimitReservation(initialMsg)
+	_, err := deriveRateLimitReservation(sm.rateLimiter, initialMsg)
 	require.NoError(t, err)
 
 	// Now send two messages of different sizes and compare their
@@ -828,18 +828,22 @@ func TestDeriveRateLimitReservation(t *testing.T) {
 	}
 
 	// Send the small message first.
-	smallDelay, err := sm.deriveRateLimitReservation(smallMsg)
+	smallDelay, err := deriveRateLimitReservation(
+		sm.rateLimiter, smallMsg,
+	)
 	require.NoError(t, err)
 
 	// Reset the limiter to the same state, then empty the bucket.
 	sm.rateLimiter = rate.NewLimiter(
 		rate.Limit(bytesPerSec), int(bytesBurst),
 	)
-	_, err = sm.deriveRateLimitReservation(initialMsg)
+	_, err = deriveRateLimitReservation(sm.rateLimiter, initialMsg)
 	require.NoError(t, err)
 
 	// Now send the large message.
-	largeDelay, err := sm.deriveRateLimitReservation(largeMsg)
+	largeDelay, err := deriveRateLimitReservation(
+		sm.rateLimiter, largeMsg,
+	)
 	require.NoError(t, err)
 
 	// The large message should have a longer delay than the small
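
For context, here is a minimal, self-contained sketch of the reserve-then-wait-or-cancel pattern the now-standalone helpers implement, using golang.org/x/time/rate directly. It is illustrative only: waitForReservation, errQuit, and the limiter parameters are made up for the example and are not part of this patch; the flow simply mirrors deriveRateLimitReservation (ReserveN sized by message bytes) followed by waitMsgDelay (wait out the delay, or Cancel on shutdown).

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

var errQuit = errors.New("shutting down")

// waitForReservation waits out a token-bucket reservation, but aborts (and
// returns the reserved tokens to the bucket) if the context is cancelled or
// the quit channel closes first.
func waitForReservation(ctx context.Context, res *rate.Reservation,
	quit <-chan struct{}) error {

	delay := res.Delay()
	if delay == 0 {
		return nil
	}

	timer := time.NewTimer(delay)
	defer timer.Stop()

	select {
	case <-timer.C:
		return nil

	case <-ctx.Done():
		res.Cancel()
		return ctx.Err()

	case <-quit:
		res.Cancel()
		return errQuit
	}
}

func main() {
	// One bucket shared by all callers: a 100 byte/sec refill rate with a
	// 200-byte burst (arbitrary numbers for the example).
	limiter := rate.NewLimiter(rate.Limit(100), 200)
	quit := make(chan struct{})

	for i, size := range []int{200, 50, 50} {
		// Reserve size tokens now; the reservation tells us how long
		// to wait before the send fits within the rate limit.
		res := limiter.ReserveN(time.Now(), size)
		err := waitForReservation(context.Background(), res, quit)
		if err != nil {
			fmt.Println("aborted:", err)
			return
		}

		fmt.Printf("msg %d (%d bytes) sent\n", i, size)
	}
}

The Cancel calls are the important detail: an aborted wait returns its tokens to the bucket instead of permanently consuming capacity, which is why both shutdown paths in waitMsgDelay cancel the reservation before returning ErrGossipSyncerExiting.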