discovery: create common helper methods for rate limiter

This allows us to reuse them in the upcoming commits where we introduce
a rate limiter to the gossip syncer.
This commit is contained in:
yyforyongyu
2025-07-22 19:03:17 +08:00
parent 1666764690
commit 19bc941cbd
2 changed files with 28 additions and 20 deletions

View File

@@ -200,7 +200,7 @@ type SyncManager struct {
gossipFilterSema chan struct{} gossipFilterSema chan struct{}
// rateLimiter dictates the frequency with which we will reply to gossip // rateLimiter dictates the frequency with which we will reply to gossip
// queries from a peer. This is used to delay responses to peers to // queries from all peers. This is used to delay responses to peers to
// prevent DOS vulnerabilities if they are spamming with an unreasonable // prevent DOS vulnerabilities if they are spamming with an unreasonable
// number of queries. // number of queries.
rateLimiter *rate.Limiter rateLimiter *rate.Limiter
@@ -554,8 +554,8 @@ func (m *SyncManager) isPinnedSyncer(s *GossipSyncer) bool {
// deriveRateLimitReservation will take the current message and derive a // deriveRateLimitReservation will take the current message and derive a
// reservation that can be used to wait on the rate limiter. // reservation that can be used to wait on the rate limiter.
func (m *SyncManager) deriveRateLimitReservation(msg lnwire.Message, func deriveRateLimitReservation(rl *rate.Limiter,
) (*rate.Reservation, error) { msg lnwire.Message) (*rate.Reservation, error) {
var ( var (
msgSize uint32 msgSize uint32
@@ -575,12 +575,12 @@ func (m *SyncManager) deriveRateLimitReservation(msg lnwire.Message,
msgSize = assumedMsgSize msgSize = assumedMsgSize
} }
return m.rateLimiter.ReserveN(time.Now(), int(msgSize)), nil return rl.ReserveN(time.Now(), int(msgSize)), nil
} }
// waitMsgDelay takes a delay, and waits until it has finished. // waitMsgDelay takes a delay, and waits until it has finished.
func (m *SyncManager) waitMsgDelay(ctx context.Context, peerPub [33]byte, func waitMsgDelay(ctx context.Context, peerPub [33]byte,
limitReservation *rate.Reservation) error { limitReservation *rate.Reservation, quit <-chan struct{}) error {
// If we've already replied a handful of times, we will start to delay // If we've already replied a handful of times, we will start to delay
// responses back to the remote peer. This can help prevent DOS attacks // responses back to the remote peer. This can help prevent DOS attacks
@@ -602,7 +602,7 @@ func (m *SyncManager) waitMsgDelay(ctx context.Context, peerPub [33]byte,
return ErrGossipSyncerExiting return ErrGossipSyncerExiting
case <-m.quit: case <-quit:
limitReservation.Cancel() limitReservation.Cancel()
return ErrGossipSyncerExiting return ErrGossipSyncerExiting
@@ -614,15 +614,15 @@ func (m *SyncManager) waitMsgDelay(ctx context.Context, peerPub [33]byte,
// maybeRateLimitMsg takes a message, and may wait a period of time to rate // maybeRateLimitMsg takes a message, and may wait a period of time to rate
// limit the msg. // limit the msg.
func (m *SyncManager) maybeRateLimitMsg(ctx context.Context, peerPub [33]byte, func maybeRateLimitMsg(ctx context.Context, rl *rate.Limiter, peerPub [33]byte,
msg lnwire.Message) error { msg lnwire.Message, quit <-chan struct{}) error {
delay, err := m.deriveRateLimitReservation(msg) delay, err := deriveRateLimitReservation(rl, msg)
if err != nil { if err != nil {
return nil return nil
} }
return m.waitMsgDelay(ctx, peerPub, delay) return waitMsgDelay(ctx, peerPub, delay, quit)
} }
// sendMessages sends a set of messages to the remote peer. // sendMessages sends a set of messages to the remote peer.
@@ -630,9 +630,13 @@ func (m *SyncManager) sendMessages(ctx context.Context, sync bool,
peer lnpeer.Peer, nodeID route.Vertex, msgs ...lnwire.Message) error { peer lnpeer.Peer, nodeID route.Vertex, msgs ...lnwire.Message) error {
for _, msg := range msgs { for _, msg := range msgs {
if err := m.maybeRateLimitMsg(ctx, nodeID, msg); err != nil { err := maybeRateLimitMsg(
ctx, m.rateLimiter, nodeID, msg, m.quit,
)
if err != nil {
return err return err
} }
if err := peer.SendMessageLazy(sync, msg); err != nil { if err := peer.SendMessageLazy(sync, msg); err != nil {
return err return err
} }

View File

@@ -748,7 +748,7 @@ func TestDeriveRateLimitReservation(t *testing.T) {
} }
// First message should have no delay as it fits within burst. // First message should have no delay as it fits within burst.
delay1, err := sm.deriveRateLimitReservation(msg) delay1, err := deriveRateLimitReservation(sm.rateLimiter, msg)
require.NoError(t, err) require.NoError(t, err)
require.Equal( require.Equal(
t, time.Duration(0), delay1.Delay(), "first message "+ t, time.Duration(0), delay1.Delay(), "first message "+
@@ -757,7 +757,7 @@ func TestDeriveRateLimitReservation(t *testing.T) {
// Second message should have a non-zero delay as the token // Second message should have a non-zero delay as the token
// bucket is now depleted. // bucket is now depleted.
delay2, err := sm.deriveRateLimitReservation(msg) delay2, err := deriveRateLimitReservation(sm.rateLimiter, msg)
require.NoError(t, err) require.NoError(t, err)
require.True( require.True(
t, delay2.Delay() > 0, "second message should have "+ t, delay2.Delay() > 0, "second message should have "+
@@ -766,7 +766,7 @@ func TestDeriveRateLimitReservation(t *testing.T) {
// Third message should have an even longer delay since the // Third message should have an even longer delay since the
// token bucket is still refilling at a constant rate. // token bucket is still refilling at a constant rate.
delay3, err := sm.deriveRateLimitReservation(msg) delay3, err := deriveRateLimitReservation(sm.rateLimiter, msg)
require.NoError(t, err) require.NoError(t, err)
require.True(t, delay3.Delay() > delay2.Delay(), "third "+ require.True(t, delay3.Delay() > delay2.Delay(), "third "+
"message should have longer delay than second: %s > %s", "message should have longer delay than second: %s > %s",
@@ -798,7 +798,7 @@ func TestDeriveRateLimitReservation(t *testing.T) {
// The error should propagate through // The error should propagate through
// deriveRateLimitReservation. // deriveRateLimitReservation.
_, err := sm.deriveRateLimitReservation(msg) _, err := deriveRateLimitReservation(sm.rateLimiter, msg)
require.Error(t, err) require.Error(t, err)
require.Equal( require.Equal(
t, errMsg, err, "Error should be propagated unchanged", t, errMsg, err, "Error should be propagated unchanged",
@@ -815,7 +815,7 @@ func TestDeriveRateLimitReservation(t *testing.T) {
initialMsg := &TestSizeableMessage{ initialMsg := &TestSizeableMessage{
size: uint32(bytesBurst), size: uint32(bytesBurst),
} }
_, err := sm.deriveRateLimitReservation(initialMsg) _, err := deriveRateLimitReservation(sm.rateLimiter, initialMsg)
require.NoError(t, err) require.NoError(t, err)
// Now send two messages of different sizes and compare their // Now send two messages of different sizes and compare their
@@ -828,18 +828,22 @@ func TestDeriveRateLimitReservation(t *testing.T) {
} }
// Send the small message first. // Send the small message first.
smallDelay, err := sm.deriveRateLimitReservation(smallMsg) smallDelay, err := deriveRateLimitReservation(
sm.rateLimiter, smallMsg,
)
require.NoError(t, err) require.NoError(t, err)
// Reset the limiter to the same state, then empty the bucket. // Reset the limiter to the same state, then empty the bucket.
sm.rateLimiter = rate.NewLimiter( sm.rateLimiter = rate.NewLimiter(
rate.Limit(bytesPerSec), int(bytesBurst), rate.Limit(bytesPerSec), int(bytesBurst),
) )
_, err = sm.deriveRateLimitReservation(initialMsg) _, err = deriveRateLimitReservation(sm.rateLimiter, initialMsg)
require.NoError(t, err) require.NoError(t, err)
// Now send the large message. // Now send the large message.
largeDelay, err := sm.deriveRateLimitReservation(largeMsg) largeDelay, err := deriveRateLimitReservation(
sm.rateLimiter, largeMsg,
)
require.NoError(t, err) require.NoError(t, err)
// The large message should have a longer delay than the small // The large message should have a longer delay than the small