mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-11-10 06:07:16 +01:00
discovery: create common helper methods for rate limiter
This allows us to reuse them in the upcoming commits where we introduce a rate limiter to the gossip syncer.
This commit is contained in:
@@ -200,7 +200,7 @@ type SyncManager struct {
|
|||||||
gossipFilterSema chan struct{}
|
gossipFilterSema chan struct{}
|
||||||
|
|
||||||
// rateLimiter dictates the frequency with which we will reply to gossip
|
// rateLimiter dictates the frequency with which we will reply to gossip
|
||||||
// queries from a peer. This is used to delay responses to peers to
|
// queries from all peers. This is used to delay responses to peers to
|
||||||
// prevent DOS vulnerabilities if they are spamming with an unreasonable
|
// prevent DOS vulnerabilities if they are spamming with an unreasonable
|
||||||
// number of queries.
|
// number of queries.
|
||||||
rateLimiter *rate.Limiter
|
rateLimiter *rate.Limiter
|
||||||
@@ -554,8 +554,8 @@ func (m *SyncManager) isPinnedSyncer(s *GossipSyncer) bool {
|
|||||||
|
|
||||||
// deriveRateLimitReservation will take the current message and derive a
|
// deriveRateLimitReservation will take the current message and derive a
|
||||||
// reservation that can be used to wait on the rate limiter.
|
// reservation that can be used to wait on the rate limiter.
|
||||||
func (m *SyncManager) deriveRateLimitReservation(msg lnwire.Message,
|
func deriveRateLimitReservation(rl *rate.Limiter,
|
||||||
) (*rate.Reservation, error) {
|
msg lnwire.Message) (*rate.Reservation, error) {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
msgSize uint32
|
msgSize uint32
|
||||||
@@ -575,12 +575,12 @@ func (m *SyncManager) deriveRateLimitReservation(msg lnwire.Message,
|
|||||||
msgSize = assumedMsgSize
|
msgSize = assumedMsgSize
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.rateLimiter.ReserveN(time.Now(), int(msgSize)), nil
|
return rl.ReserveN(time.Now(), int(msgSize)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// waitMsgDelay takes a delay, and waits until it has finished.
|
// waitMsgDelay takes a delay, and waits until it has finished.
|
||||||
func (m *SyncManager) waitMsgDelay(ctx context.Context, peerPub [33]byte,
|
func waitMsgDelay(ctx context.Context, peerPub [33]byte,
|
||||||
limitReservation *rate.Reservation) error {
|
limitReservation *rate.Reservation, quit <-chan struct{}) error {
|
||||||
|
|
||||||
// If we've already replied a handful of times, we will start to delay
|
// If we've already replied a handful of times, we will start to delay
|
||||||
// responses back to the remote peer. This can help prevent DOS attacks
|
// responses back to the remote peer. This can help prevent DOS attacks
|
||||||
@@ -602,7 +602,7 @@ func (m *SyncManager) waitMsgDelay(ctx context.Context, peerPub [33]byte,
|
|||||||
|
|
||||||
return ErrGossipSyncerExiting
|
return ErrGossipSyncerExiting
|
||||||
|
|
||||||
case <-m.quit:
|
case <-quit:
|
||||||
limitReservation.Cancel()
|
limitReservation.Cancel()
|
||||||
|
|
||||||
return ErrGossipSyncerExiting
|
return ErrGossipSyncerExiting
|
||||||
@@ -614,15 +614,15 @@ func (m *SyncManager) waitMsgDelay(ctx context.Context, peerPub [33]byte,
|
|||||||
|
|
||||||
// maybeRateLimitMsg takes a message, and may wait a period of time to rate
|
// maybeRateLimitMsg takes a message, and may wait a period of time to rate
|
||||||
// limit the msg.
|
// limit the msg.
|
||||||
func (m *SyncManager) maybeRateLimitMsg(ctx context.Context, peerPub [33]byte,
|
func maybeRateLimitMsg(ctx context.Context, rl *rate.Limiter, peerPub [33]byte,
|
||||||
msg lnwire.Message) error {
|
msg lnwire.Message, quit <-chan struct{}) error {
|
||||||
|
|
||||||
delay, err := m.deriveRateLimitReservation(msg)
|
delay, err := deriveRateLimitReservation(rl, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.waitMsgDelay(ctx, peerPub, delay)
|
return waitMsgDelay(ctx, peerPub, delay, quit)
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendMessages sends a set of messages to the remote peer.
|
// sendMessages sends a set of messages to the remote peer.
|
||||||
@@ -630,9 +630,13 @@ func (m *SyncManager) sendMessages(ctx context.Context, sync bool,
|
|||||||
peer lnpeer.Peer, nodeID route.Vertex, msgs ...lnwire.Message) error {
|
peer lnpeer.Peer, nodeID route.Vertex, msgs ...lnwire.Message) error {
|
||||||
|
|
||||||
for _, msg := range msgs {
|
for _, msg := range msgs {
|
||||||
if err := m.maybeRateLimitMsg(ctx, nodeID, msg); err != nil {
|
err := maybeRateLimitMsg(
|
||||||
|
ctx, m.rateLimiter, nodeID, msg, m.quit,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := peer.SendMessageLazy(sync, msg); err != nil {
|
if err := peer.SendMessageLazy(sync, msg); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -748,7 +748,7 @@ func TestDeriveRateLimitReservation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// First message should have no delay as it fits within burst.
|
// First message should have no delay as it fits within burst.
|
||||||
delay1, err := sm.deriveRateLimitReservation(msg)
|
delay1, err := deriveRateLimitReservation(sm.rateLimiter, msg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(
|
require.Equal(
|
||||||
t, time.Duration(0), delay1.Delay(), "first message "+
|
t, time.Duration(0), delay1.Delay(), "first message "+
|
||||||
@@ -757,7 +757,7 @@ func TestDeriveRateLimitReservation(t *testing.T) {
|
|||||||
|
|
||||||
// Second message should have a non-zero delay as the token
|
// Second message should have a non-zero delay as the token
|
||||||
// bucket is now depleted.
|
// bucket is now depleted.
|
||||||
delay2, err := sm.deriveRateLimitReservation(msg)
|
delay2, err := deriveRateLimitReservation(sm.rateLimiter, msg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.True(
|
require.True(
|
||||||
t, delay2.Delay() > 0, "second message should have "+
|
t, delay2.Delay() > 0, "second message should have "+
|
||||||
@@ -766,7 +766,7 @@ func TestDeriveRateLimitReservation(t *testing.T) {
|
|||||||
|
|
||||||
// Third message should have an even longer delay since the
|
// Third message should have an even longer delay since the
|
||||||
// token bucket is still refilling at a constant rate.
|
// token bucket is still refilling at a constant rate.
|
||||||
delay3, err := sm.deriveRateLimitReservation(msg)
|
delay3, err := deriveRateLimitReservation(sm.rateLimiter, msg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.True(t, delay3.Delay() > delay2.Delay(), "third "+
|
require.True(t, delay3.Delay() > delay2.Delay(), "third "+
|
||||||
"message should have longer delay than second: %s > %s",
|
"message should have longer delay than second: %s > %s",
|
||||||
@@ -798,7 +798,7 @@ func TestDeriveRateLimitReservation(t *testing.T) {
|
|||||||
|
|
||||||
// The error should propagate through
|
// The error should propagate through
|
||||||
// deriveRateLimitReservation.
|
// deriveRateLimitReservation.
|
||||||
_, err := sm.deriveRateLimitReservation(msg)
|
_, err := deriveRateLimitReservation(sm.rateLimiter, msg)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Equal(
|
require.Equal(
|
||||||
t, errMsg, err, "Error should be propagated unchanged",
|
t, errMsg, err, "Error should be propagated unchanged",
|
||||||
@@ -815,7 +815,7 @@ func TestDeriveRateLimitReservation(t *testing.T) {
|
|||||||
initialMsg := &TestSizeableMessage{
|
initialMsg := &TestSizeableMessage{
|
||||||
size: uint32(bytesBurst),
|
size: uint32(bytesBurst),
|
||||||
}
|
}
|
||||||
_, err := sm.deriveRateLimitReservation(initialMsg)
|
_, err := deriveRateLimitReservation(sm.rateLimiter, initialMsg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Now send two messages of different sizes and compare their
|
// Now send two messages of different sizes and compare their
|
||||||
@@ -828,18 +828,22 @@ func TestDeriveRateLimitReservation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send the small message first.
|
// Send the small message first.
|
||||||
smallDelay, err := sm.deriveRateLimitReservation(smallMsg)
|
smallDelay, err := deriveRateLimitReservation(
|
||||||
|
sm.rateLimiter, smallMsg,
|
||||||
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Reset the limiter to the same state, then empty the bucket.
|
// Reset the limiter to the same state, then empty the bucket.
|
||||||
sm.rateLimiter = rate.NewLimiter(
|
sm.rateLimiter = rate.NewLimiter(
|
||||||
rate.Limit(bytesPerSec), int(bytesBurst),
|
rate.Limit(bytesPerSec), int(bytesBurst),
|
||||||
)
|
)
|
||||||
_, err = sm.deriveRateLimitReservation(initialMsg)
|
_, err = deriveRateLimitReservation(sm.rateLimiter, initialMsg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Now send the large message.
|
// Now send the large message.
|
||||||
largeDelay, err := sm.deriveRateLimitReservation(largeMsg)
|
largeDelay, err := deriveRateLimitReservation(
|
||||||
|
sm.rateLimiter, largeMsg,
|
||||||
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// The large message should have a longer delay than the small
|
// The large message should have a longer delay than the small
|
||||||
|
|||||||
Reference in New Issue
Block a user