Mirror of https://github.com/lightningnetwork/lnd.git, synced 2025-05-02 16:00:30 +02:00
Merge pull request #9573 from ellemouton/checkUpdateStalenessBeforeRateLimit
discovery: obtain channelMtx before doing any DB calls in `handleChannelUpdate`
This commit is contained in: a5f54d1d6b

@@ -2997,6 +2997,12 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg,
         graphScid = upd.ShortChannelID
     }

+    // We make sure to obtain the mutex for this channel ID before we access
+    // the database. This ensures the state we read from the database has
+    // not changed between this point and when we call UpdateEdge() later.
+    d.channelMtx.Lock(graphScid.ToUint64())
+    defer d.channelMtx.Unlock(graphScid.ToUint64())
+
     if d.cfg.Graph.IsStaleEdgePolicy(
         graphScid, timestamp, upd.ChannelFlags,
     ) {
@@ -3029,14 +3035,6 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg,
     // Get the node pub key as far since we don't have it in the channel
     // update announcement message. We'll need this to properly verify the
     // message's signature.
-    //
-    // We make sure to obtain the mutex for this channel ID before we
-    // access the database. This ensures the state we read from the
-    // database has not changed between this point and when we call
-    // UpdateEdge() later.
-    d.channelMtx.Lock(graphScid.ToUint64())
-    defer d.channelMtx.Unlock(graphScid.ToUint64())
-
     chanInfo, e1, e2, err := d.cfg.Graph.GetChannelByID(graphScid)
     switch {
     // No error, break.

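Taken together, the two hunks above move the per-channel mutex acquisition so that the IsStaleEdgePolicy check and the eventual UpdateEdge call happen under the same lock, instead of taking the lock only just before GetChannelByID. The following standalone sketch (not lnd code; keyedMutex, store, isStale and update are made-up stand-ins) illustrates the check-then-act race this ordering closes: without the outer lock, two identical updates can both pass the staleness check and both be written, which is what let duplicates count against the rate limiter.

package main

import (
    "fmt"
    "sync"
)

// keyedMutex serializes callers per channel ID, loosely mirroring the role
// of the gossiper's channelMtx in this sketch.
type keyedMutex struct {
    mu    sync.Mutex
    locks map[uint64]*sync.Mutex
}

func (k *keyedMutex) lock(id uint64) *sync.Mutex {
    k.mu.Lock()
    if k.locks == nil {
        k.locks = make(map[uint64]*sync.Mutex)
    }
    m, ok := k.locks[id]
    if !ok {
        m = &sync.Mutex{}
        k.locks[id] = m
    }
    k.mu.Unlock()

    m.Lock()
    return m
}

// store is a toy stand-in for the graph database.
type store struct {
    mu        sync.Mutex
    timestamp map[uint64]uint32
    writes    int
}

// isStale mimics IsStaleEdgePolicy: an update is stale if we already stored
// an equal or newer timestamp for the channel.
func (s *store) isStale(id uint64, ts uint32) bool {
    s.mu.Lock()
    defer s.mu.Unlock()
    return ts <= s.timestamp[id]
}

// update mimics UpdateEdge: persist the newer timestamp.
func (s *store) update(id uint64, ts uint32) {
    s.mu.Lock()
    defer s.mu.Unlock()
    s.timestamp[id] = ts
    s.writes++
}

func main() {
    var (
        km keyedMutex
        db = store{timestamp: map[uint64]uint32{42: 1}}
        wg sync.WaitGroup
    )

    // Two identical updates for channel 42 arrive at once. Because the
    // per-channel mutex is taken *before* the staleness check, only one of
    // them can pass the check and reach the write; the other then sees the
    // freshly written timestamp and is dropped as stale.
    for i := 0; i < 2; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()

            m := km.lock(42)
            defer m.Unlock()

            if db.isStale(42, 2) {
                return
            }
            db.update(42, 2)
        }()
    }
    wg.Wait()

    fmt.Println("writes:", db.writes) // prints 1, not 2
}
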
@@ -76,7 +76,9 @@ var (
     rebroadcastInterval = time.Hour * 1000000
 )

+// TODO(elle): replace mockGraphSource with testify.Mock.
 type mockGraphSource struct {
+    t *testing.T
     bestHeight uint32

     mu sync.Mutex
@@ -85,15 +87,22 @@ type mockGraphSource struct {
     edges         map[uint64][]models.ChannelEdgePolicy
     zombies       map[uint64][][33]byte
     chansToReject map[uint64]struct{}
+
+    updateEdgeCount int
+    pauseGetChannelByID chan chan struct{}
 }

-func newMockRouter(height uint32) *mockGraphSource {
+func newMockRouter(t *testing.T, height uint32) *mockGraphSource {
     return &mockGraphSource{
-        bestHeight:    height,
-        infos:         make(map[uint64]models.ChannelEdgeInfo),
-        edges:         make(map[uint64][]models.ChannelEdgePolicy),
-        zombies:       make(map[uint64][][33]byte),
-        chansToReject: make(map[uint64]struct{}),
+        t:          t,
+        bestHeight: height,
+        infos:      make(map[uint64]models.ChannelEdgeInfo),
+        edges: make(
+            map[uint64][]models.ChannelEdgePolicy,
+        ),
+        zombies:             make(map[uint64][][33]byte),
+        chansToReject:       make(map[uint64]struct{}),
+        pauseGetChannelByID: make(chan chan struct{}, 1),
     }
 }

@@ -155,7 +164,10 @@ func (r *mockGraphSource) UpdateEdge(edge *models.ChannelEdgePolicy,
     _ ...batch.SchedulerOption) error {

     r.mu.Lock()
-    defer r.mu.Unlock()
+    defer func() {
+        r.updateEdgeCount++
+        r.mu.Unlock()
+    }()

     if len(r.edges[edge.ChannelID]) == 0 {
         r.edges[edge.ChannelID] = make([]models.ChannelEdgePolicy, 2)
@@ -234,6 +246,18 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) (
     *models.ChannelEdgePolicy,
     *models.ChannelEdgePolicy, error) {

+    select {
+    // Check if a pause request channel has been loaded. If one has, then we
+    // wait for it to be closed before continuing.
+    case pauseChan := <-r.pauseGetChannelByID:
+        select {
+        case <-pauseChan:
+        case <-time.After(time.Second * 30):
+            r.t.Fatal("timeout waiting for pause channel")
+        }
+    default:
+    }
+
     r.mu.Lock()
     defer r.mu.Unlock()

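pauseGetChannelByID is a test hook: the test buffers a "pause" channel into it, and the first GetChannelByID caller to drain it blocks until the test closes that channel, which lets TestRateLimitDeDup hold one update-processing goroutine at a known point while a second one queues up behind the channel mutex. A minimal standalone sketch of the same pattern (slowStore and Get are hypothetical names, not lnd APIs):

package main

import (
    "fmt"
    "sync"
    "time"
)

type slowStore struct {
    // pause is buffered so a test can pre-load at most one pause request
    // without blocking.
    pause chan chan struct{}
}

// Get blocks on a pre-loaded pause channel (if any) before returning, giving
// a test a deterministic point at which a caller can be held.
func (s *slowStore) Get(id uint64) uint64 {
    select {
    case p := <-s.pause:
        select {
        case <-p:
        case <-time.After(30 * time.Second):
            panic("timeout waiting for pause channel")
        }
    default:
    }
    return id
}

func main() {
    s := &slowStore{pause: make(chan chan struct{}, 1)}

    // Load a pause request, then start a caller that will block inside Get
    // until we close the pause channel.
    p := make(chan struct{})
    s.pause <- p

    var wg sync.WaitGroup
    wg.Add(1)
    go func() {
        defer wg.Done()
        fmt.Println("got:", s.Get(42))
    }()

    // The test can now arrange other state while the caller is held.
    close(p)
    wg.Wait()
}
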
@@ -874,7 +898,7 @@ func createTestCtx(t *testing.T, startHeight uint32, isChanPeer bool) (
     // any p2p functionality, the peer send and switch send,
     // broadcast functions won't be populated.
     notifier := newMockNotifier()
-    router := newMockRouter(startHeight)
+    router := newMockRouter(t, startHeight)
     chain := &lnmock.MockChain{}
     t.Cleanup(func() {
         chain.AssertExpectations(t)
@@ -3977,6 +4001,202 @@ func TestBroadcastAnnsAfterGraphSynced(t *testing.T) {
     assertBroadcast(chanAnn2, true, true)
 }

+// TestRateLimitDeDup tests that if we get the same channel update in very
+// quick succession, then these updates should not be individually considered
+// in our rate limiting logic.
+//
+// NOTE: this only tests the deduplication logic. The main rate limiting logic
+// is tested by TestRateLimitChannelUpdates.
+func TestRateLimitDeDup(t *testing.T) {
+    t.Parallel()
+
+    // Create our test harness.
+    const blockHeight = 100
+    ctx, err := createTestCtx(t, blockHeight, false)
+    require.NoError(t, err, "can't create context")
+    ctx.gossiper.cfg.RebroadcastInterval = time.Hour
+
+    var findBaseByAliasCount atomic.Int32
+    ctx.gossiper.cfg.FindBaseByAlias = func(alias lnwire.ShortChannelID) (
+        lnwire.ShortChannelID, error) {
+
+        findBaseByAliasCount.Add(1)
+
+        return lnwire.ShortChannelID{}, fmt.Errorf("none")
+    }
+
+    getUpdateEdgeCount := func() int {
+        ctx.router.mu.Lock()
+        defer ctx.router.mu.Unlock()
+
+        return ctx.router.updateEdgeCount
+    }
+
+    // We set the burst to 2 here. The very first update should not count
+    // towards this _and_ any duplicates should also not count towards it.
+    ctx.gossiper.cfg.MaxChannelUpdateBurst = 2
+    ctx.gossiper.cfg.ChannelUpdateInterval = time.Minute
+
+    // The graph should start empty.
+    require.Empty(t, ctx.router.infos)
+    require.Empty(t, ctx.router.edges)
+
+    // We'll create a batch of signed announcements, including updates for
+    // both sides, for a channel and process them. They should all be
+    // forwarded as this is our first time learning about the channel.
+    batch, err := ctx.createRemoteAnnouncements(blockHeight)
+    require.NoError(t, err)
+
+    nodePeer1 := &mockPeer{
+        remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{},
+    }
+    select {
+    case err := <-ctx.gossiper.ProcessRemoteAnnouncement(
+        batch.chanAnn, nodePeer1,
+    ):
+        require.NoError(t, err)
+    case <-time.After(time.Second):
+        t.Fatal("remote announcement not processed")
+    }
+
+    select {
+    case err := <-ctx.gossiper.ProcessRemoteAnnouncement(
+        batch.chanUpdAnn1, nodePeer1,
+    ):
+        require.NoError(t, err)
+    case <-time.After(time.Second):
+        t.Fatal("remote announcement not processed")
+    }
+
+    nodePeer2 := &mockPeer{
+        remoteKeyPriv2.PubKey(), nil, nil, atomic.Bool{},
+    }
+    select {
+    case err := <-ctx.gossiper.ProcessRemoteAnnouncement(
+        batch.chanUpdAnn2, nodePeer2,
+    ):
+        require.NoError(t, err)
+    case <-time.After(time.Second):
+        t.Fatal("remote announcement not processed")
+    }
+
+    timeout := time.After(2 * trickleDelay)
+    for i := 0; i < 3; i++ {
+        select {
+        case <-ctx.broadcastedMessage:
+        case <-timeout:
+            t.Fatal("expected announcement to be broadcast")
+        }
+    }
+
+    shortChanID := batch.chanAnn.ShortChannelID.ToUint64()
+    require.Contains(t, ctx.router.infos, shortChanID)
+    require.Contains(t, ctx.router.edges, shortChanID)
+
+    // Before we send anymore updates, we want to let our test harness
+    // hang during GetChannelByID so that we can ensure that two threads are
+    // waiting for the chan.
+    pause := make(chan struct{})
+    ctx.router.pauseGetChannelByID <- pause
+
+    // Take note of how many times FindBaseByAlias has been called.
+    // It should be 2 since we have processed two channel updates.
+    require.EqualValues(t, 2, findBaseByAliasCount.Load())
+
+    // The same is expected for the UpdateEdge call.
+    require.EqualValues(t, 2, getUpdateEdgeCount())
+
+    update := *batch.chanUpdAnn1
+
+    // refreshUpdate is a helper that helps us ensure that the update
+    // is not seen as stale or as a keep-alive.
+    refreshUpdate := func() {
+        update.Timestamp++
+        update.BaseFee++
+        require.NoError(t, signUpdate(remoteKeyPriv1, &update))
+    }
+
+    refreshUpdate()
+
+    // Ok, now we will send the same channel update twice in quick
+    // succession. We wait for both to have hit the FindBaseByAlias check
+    // before we un-pause the GetChannelByID call.
+    go func() {
+        ctx.gossiper.ProcessRemoteAnnouncement(&update, nodePeer1)
+    }()
+    go func() {
+        ctx.gossiper.ProcessRemoteAnnouncement(&update, nodePeer1)
+    }()
+
+    // We know that both are being processed once the count for
+    // FindBaseByAlias has increased by 2.
+    err = wait.NoError(func() error {
+        count := findBaseByAliasCount.Load()
+
+        if count != 4 {
+            return fmt.Errorf("expected 4 calls to "+
+                "FindBaseByAlias, got %v", count)
+        }
+
+        return nil
+    }, time.Second*5)
+    require.NoError(t, err)
+
+    // Now we can un-pause the thread that grabbed the mutex first.
+    close(pause)
+
+    // Only 1 call should have made it past the staleness check to the
+    // graph's UpdateEdge call.
+    err = wait.NoError(func() error {
+        count := getUpdateEdgeCount()
+        if count != 3 {
+            return fmt.Errorf("expected 3 calls to UpdateEdge, "+
+                "got %v", count)
+        }
+
+        return nil
+    }, time.Second*5)
+    require.NoError(t, err)
+
+    // We'll define a helper to assert whether update was broadcast or not.
+    assertBroadcast := func(shouldBroadcast bool) {
+        t.Helper()
+
+        select {
+        case <-ctx.broadcastedMessage:
+            require.True(t, shouldBroadcast)
+        case <-time.After(2 * trickleDelay):
+            require.False(t, shouldBroadcast)
+        }
+    }
+
+    processUpdate := func(msg lnwire.Message, peer lnpeer.Peer) {
+        select {
+        case err := <-ctx.gossiper.ProcessRemoteAnnouncement(
+            msg, peer,
+        ):
+            require.NoError(t, err)
+        case <-time.After(time.Second):
+            t.Fatal("remote announcement not processed")
+        }
+    }
+
+    // Show that the last update was broadcast.
+    assertBroadcast(true)
+
+    // We should be allowed to send another update now since the rate limit
+    // has still not been met.
+    refreshUpdate()
+    processUpdate(&update, nodePeer1)
+    assertBroadcast(true)
+
+    // Our rate limit should be hit now, so a new update should not be
+    // broadcast.
+    refreshUpdate()
+    processUpdate(&update, nodePeer1)
+    assertBroadcast(false)
+}
+
 // TestRateLimitChannelUpdates ensures that we properly rate limit incoming
 // channel updates.
 func TestRateLimitChannelUpdates(t *testing.T) {

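The test above configures MaxChannelUpdateBurst = 2 and ChannelUpdateInterval = time.Minute and then checks that duplicates do not eat into that budget. As a rough illustration of the token-bucket semantics those two settings imply, here is a standalone sketch using golang.org/x/time/rate; it is not the gossiper's actual limiter wiring (which, per the test's comments, also exempts the initial update and duplicates), it only shows why a third distinct update inside the same interval would not be broadcast.

package main

import (
    "fmt"
    "time"

    "golang.org/x/time/rate"
)

func main() {
    const (
        maxChannelUpdateBurst = 2
        channelUpdateInterval = time.Minute
    )

    // One token is refilled per interval, with room for a burst of two.
    limiter := rate.NewLimiter(
        rate.Every(channelUpdateInterval), maxChannelUpdateBurst,
    )

    // Three distinct (non-duplicate) updates arriving back to back: the
    // first two consume the burst and are allowed, the third is not.
    for i := 1; i <= 3; i++ {
        fmt.Printf("update %d allowed: %v\n", i, limiter.Allow())
    }
    // Prints:
    //   update 1 allowed: true
    //   update 2 allowed: true
    //   update 3 allowed: false
}
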
@@ -77,6 +77,11 @@
   restarts, for details check [this
   issue](https://github.com/lightningnetwork/lnd/issues/8975#issuecomment-2270528222).

+* [Fix a bug](https://github.com/lightningnetwork/lnd/pull/9573) where
+  processing duplicate ChannelUpdates from different peers in quick succession
+  could lead to our ChannelUpdate rate limiting logic being prematurely
+  triggered.
+
 # New Features

 * [Support](https://github.com/lightningnetwork/lnd/pull/8390) for