Merge pull request #9573 from ellemouton/checkUpdateStalenessBeforeRateLimit

discovery: obtain channelMtx before doing any DB calls in `handleChannelUpdate`
Oliver Gugger 2025-03-07 04:17:00 -06:00 committed by GitHub
commit a5f54d1d6b
3 changed files with 239 additions and 16 deletions

discovery/gossiper.go

@@ -2997,6 +2997,12 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg,
graphScid = upd.ShortChannelID
}
// We make sure to obtain the mutex for this channel ID before we access
// the database. This ensures the state we read from the database has
// not changed between this point and when we call UpdateEdge() later.
d.channelMtx.Lock(graphScid.ToUint64())
defer d.channelMtx.Unlock(graphScid.ToUint64())
if d.cfg.Graph.IsStaleEdgePolicy(
graphScid, timestamp, upd.ChannelFlags,
) {
@@ -3029,14 +3035,6 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg,
// Get the node pub key since we don't have it in the channel
// update announcement message. We'll need this to properly verify the
// message's signature.
//
// We make sure to obtain the mutex for this channel ID before we
// access the database. This ensures the state we read from the
// database has not changed between this point and when we call
// UpdateEdge() later.
d.channelMtx.Lock(graphScid.ToUint64())
defer d.channelMtx.Unlock(graphScid.ToUint64())
chanInfo, e1, e2, err := d.cfg.Graph.GetChannelByID(graphScid)
switch {
// No error, break.
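The two hunks above move the channelMtx acquisition from just before GetChannelByID up to before the IsStaleEdgePolicy check, closing a classic check-then-act race: previously, two copies of the same update could both pass the unlocked staleness check and both be counted by the rate limiter. Below is a minimal, self-contained sketch of the fixed ordering; keyedMutex, isStale and updateEdge are illustrative stand-ins, not lnd's actual multimutex API.

```go
package sketch

import "sync"

// keyedMutex is an illustrative stand-in for the gossiper's channelMtx:
// one mutex per channel ID, so updates for different channels never
// block each other. (Unlike lnd's real implementation, this sketch
// never frees unused entries.)
type keyedMutex struct {
	mu    sync.Mutex
	locks map[uint64]*sync.Mutex
}

func (k *keyedMutex) forID(id uint64) *sync.Mutex {
	k.mu.Lock()
	defer k.mu.Unlock()

	if k.locks == nil {
		k.locks = make(map[uint64]*sync.Mutex)
	}
	if _, ok := k.locks[id]; !ok {
		k.locks[id] = &sync.Mutex{}
	}
	return k.locks[id]
}

// handleUpdate shows the fixed ordering: the per-channel lock is taken
// before the staleness read, so the read and the later write behave as
// one atomic step per channel.
func handleUpdate(km *keyedMutex, scid uint64,
	isStale func() bool, updateEdge func() error) error {

	l := km.forID(scid)
	l.Lock()
	defer l.Unlock()

	// A concurrent duplicate now waits on the lock above, then
	// observes the state written by the first copy and is dropped
	// here instead of reaching the rate limiter a second time.
	if isStale() {
		return nil
	}
	return updateEdge()
}
```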

discovery/gossiper_test.go

@@ -76,7 +76,9 @@ var (
rebroadcastInterval = time.Hour * 1000000
)
// TODO(elle): replace mockGraphSource with testify.Mock.
type mockGraphSource struct {
t *testing.T
bestHeight uint32
mu sync.Mutex
@@ -85,15 +87,22 @@ type mockGraphSource struct {
edges map[uint64][]models.ChannelEdgePolicy
zombies map[uint64][][33]byte
chansToReject map[uint64]struct{}
updateEdgeCount int
pauseGetChannelByID chan chan struct{}
}
func newMockRouter(height uint32) *mockGraphSource {
func newMockRouter(t *testing.T, height uint32) *mockGraphSource {
return &mockGraphSource{
bestHeight: height,
infos: make(map[uint64]models.ChannelEdgeInfo),
edges: make(map[uint64][]models.ChannelEdgePolicy),
zombies: make(map[uint64][][33]byte),
chansToReject: make(map[uint64]struct{}),
t: t,
bestHeight: height,
infos: make(map[uint64]models.ChannelEdgeInfo),
edges: make(
map[uint64][]models.ChannelEdgePolicy,
),
zombies: make(map[uint64][][33]byte),
chansToReject: make(map[uint64]struct{}),
pauseGetChannelByID: make(chan chan struct{}, 1),
}
}
@@ -155,7 +164,10 @@ func (r *mockGraphSource) UpdateEdge(edge *models.ChannelEdgePolicy,
_ ...batch.SchedulerOption) error {
r.mu.Lock()
defer r.mu.Unlock()
defer func() {
r.updateEdgeCount++
r.mu.Unlock()
}()
if len(r.edges[edge.ChannelID]) == 0 {
r.edges[edge.ChannelID] = make([]models.ChannelEdgePolicy, 2)
@@ -234,6 +246,18 @@ func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) (
*models.ChannelEdgePolicy,
*models.ChannelEdgePolicy, error) {
select {
// Check if a pause request channel has been loaded. If one has, then we
// wait for it to be closed before continuing.
case pauseChan := <-r.pauseGetChannelByID:
select {
case <-pauseChan:
case <-time.After(time.Second * 30):
r.t.Fatal("timeout waiting for pause channel")
}
default:
}
r.mu.Lock()
defer r.mu.Unlock()
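The select at the top of GetChannelByID is a channel-of-channels pause hook: the test loads a request channel into pauseGetChannelByID (buffered, capacity one), the next call receives it and blocks until the test closes it, and calls that find no pending request fall straight through the default arm. A condensed, self-contained sketch of the pattern (pausableMock is a hypothetical name):

```go
package sketch

import (
	"fmt"
	"sync"
)

// pausableMock condenses the pattern above: a buffered channel of
// request channels acts as a one-shot pause hook.
type pausableMock struct {
	pauseGetChannelByID chan chan struct{}
}

func (m *pausableMock) GetChannelByID(id uint64) {
	select {
	case pauseChan := <-m.pauseGetChannelByID:
		// A pause request is pending: block until the test
		// closes it.
		<-pauseChan
	default:
	}
	fmt.Println("looked up channel", id)
}

func Example() {
	m := &pausableMock{
		pauseGetChannelByID: make(chan chan struct{}, 1),
	}

	// Queue exactly one pause request; the buffer of one means
	// this send never blocks.
	pause := make(chan struct{})
	m.pauseGetChannelByID <- pause

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		// The first caller consumes the request and stalls here
		// until close(pause); later callers fall through the
		// default arm.
		m.GetChannelByID(42)
	}()

	close(pause) // release the stalled call
	wg.Wait()
}
```

In the test below, this lets one goroutine be held inside GetChannelByID, while it holds the per-channel mutex, long enough for a duplicate update to queue up behind it.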
@@ -874,7 +898,7 @@ func createTestCtx(t *testing.T, startHeight uint32, isChanPeer bool) (
// any p2p functionality, the peer send and switch send,
// broadcast functions won't be populated.
notifier := newMockNotifier()
router := newMockRouter(startHeight)
router := newMockRouter(t, startHeight)
chain := &lnmock.MockChain{}
t.Cleanup(func() {
chain.AssertExpectations(t)
@@ -3977,6 +4001,202 @@ func TestBroadcastAnnsAfterGraphSynced(t *testing.T) {
assertBroadcast(chanAnn2, true, true)
}
// TestRateLimitDeDup tests that if we receive the same channel update
// multiple times in quick succession, the duplicates are not individually
// counted by our rate limiting logic.
//
// NOTE: this only tests the deduplication logic. The main rate limiting logic
// is tested by TestRateLimitChannelUpdates.
func TestRateLimitDeDup(t *testing.T) {
t.Parallel()
// Create our test harness.
const blockHeight = 100
ctx, err := createTestCtx(t, blockHeight, false)
require.NoError(t, err, "can't create context")
ctx.gossiper.cfg.RebroadcastInterval = time.Hour
var findBaseByAliasCount atomic.Int32
ctx.gossiper.cfg.FindBaseByAlias = func(alias lnwire.ShortChannelID) (
lnwire.ShortChannelID, error) {
findBaseByAliasCount.Add(1)
return lnwire.ShortChannelID{}, fmt.Errorf("none")
}
getUpdateEdgeCount := func() int {
ctx.router.mu.Lock()
defer ctx.router.mu.Unlock()
return ctx.router.updateEdgeCount
}
// We set the burst to 2 here. The very first update should not count
// towards this _and_ any duplicates should also not count towards it.
ctx.gossiper.cfg.MaxChannelUpdateBurst = 2
ctx.gossiper.cfg.ChannelUpdateInterval = time.Minute
// The graph should start empty.
require.Empty(t, ctx.router.infos)
require.Empty(t, ctx.router.edges)
// We'll create a batch of signed announcements, including updates for
// both sides, for a channel and process them. They should all be
// forwarded as this is our first time learning about the channel.
batch, err := ctx.createRemoteAnnouncements(blockHeight)
require.NoError(t, err)
nodePeer1 := &mockPeer{
remoteKeyPriv1.PubKey(), nil, nil, atomic.Bool{},
}
select {
case err := <-ctx.gossiper.ProcessRemoteAnnouncement(
batch.chanAnn, nodePeer1,
):
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("remote announcement not processed")
}
select {
case err := <-ctx.gossiper.ProcessRemoteAnnouncement(
batch.chanUpdAnn1, nodePeer1,
):
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("remote announcement not processed")
}
nodePeer2 := &mockPeer{
remoteKeyPriv2.PubKey(), nil, nil, atomic.Bool{},
}
select {
case err := <-ctx.gossiper.ProcessRemoteAnnouncement(
batch.chanUpdAnn2, nodePeer2,
):
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("remote announcement not processed")
}
timeout := time.After(2 * trickleDelay)
for i := 0; i < 3; i++ {
select {
case <-ctx.broadcastedMessage:
case <-timeout:
t.Fatal("expected announcement to be broadcast")
}
}
shortChanID := batch.chanAnn.ShortChannelID.ToUint64()
require.Contains(t, ctx.router.infos, shortChanID)
require.Contains(t, ctx.router.edges, shortChanID)
// Before we send any more updates, we want to let our test harness
// pause during GetChannelByID so that we can ensure that two goroutines
// are handling updates for the channel at the same time.
pause := make(chan struct{})
ctx.router.pauseGetChannelByID <- pause
// Take note of how many times FindBaseByAlias has been called.
// It should be 2 since we have processed two channel updates.
require.EqualValues(t, 2, findBaseByAliasCount.Load())
// The same is expected for the UpdateEdge call.
require.EqualValues(t, 2, getUpdateEdgeCount())
update := *batch.chanUpdAnn1
// refreshUpdate is a helper that ensures the update is not seen as
// stale or as a keep-alive.
refreshUpdate := func() {
update.Timestamp++
update.BaseFee++
require.NoError(t, signUpdate(remoteKeyPriv1, &update))
}
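// Note: bumping only the timestamp would make the update look like a
// keep-alive (no policy field changed), which is rate limited under a
// separate rule; bumping BaseFee as well marks it as a real policy
// change, and re-signing keeps the signature valid.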
refreshUpdate()
// Ok, now we will send the same channel update twice in quick
// succession. We wait for both to have hit the FindBaseByAlias check
// before we un-pause the GetChannelByID call.
go func() {
ctx.gossiper.ProcessRemoteAnnouncement(&update, nodePeer1)
}()
go func() {
ctx.gossiper.ProcessRemoteAnnouncement(&update, nodePeer1)
}()
// We know that both are being processed once the count for
// FindBaseByAlias has increased by 2.
err = wait.NoError(func() error {
count := findBaseByAliasCount.Load()
if count != 4 {
return fmt.Errorf("expected 4 calls to "+
"FindBaseByAlias, got %v", count)
}
return nil
}, time.Second*5)
require.NoError(t, err)
// Now we can un-pause the goroutine that grabbed the mutex first.
close(pause)
// Only 1 call should have made it past the staleness check to the
// graph's UpdateEdge call.
err = wait.NoError(func() error {
count := getUpdateEdgeCount()
if count != 3 {
return fmt.Errorf("expected 3 calls to UpdateEdge, "+
"got %v", count)
}
return nil
}, time.Second*5)
require.NoError(t, err)
// We'll define a helper to assert whether the update was broadcast or not.
assertBroadcast := func(shouldBroadcast bool) {
t.Helper()
select {
case <-ctx.broadcastedMessage:
require.True(t, shouldBroadcast)
case <-time.After(2 * trickleDelay):
require.False(t, shouldBroadcast)
}
}
processUpdate := func(msg lnwire.Message, peer lnpeer.Peer) {
select {
case err := <-ctx.gossiper.ProcessRemoteAnnouncement(
msg, peer,
):
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("remote announcement not processed")
}
}
// Show that the last update was broadcast.
assertBroadcast(true)
// We should be allowed to send another update now since the rate limit
// has still not been met.
refreshUpdate()
processUpdate(&update, nodePeer1)
assertBroadcast(true)
// Our rate limit should be hit now, so a new update should not be
// broadcast.
refreshUpdate()
processUpdate(&update, nodePeer1)
assertBroadcast(false)
}
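For reference, the rate limiting exercised here and in TestRateLimitChannelUpdates is a token bucket per channel direction: roughly MaxChannelUpdateBurst tokens up front, with one token refilled per ChannelUpdateInterval. A rough standalone sketch in that spirit using golang.org/x/time/rate; the type and field names are illustrative, not lnd's actual code:

```go
package sketch

import (
	"sync"
	"time"

	"golang.org/x/time/rate"
)

// updateLimiter keeps one token bucket per channel direction. Names
// are illustrative; lnd's real structure differs.
type updateLimiter struct {
	mu       sync.Mutex
	limiters map[uint64][2]*rate.Limiter

	interval time.Duration // cf. ChannelUpdateInterval
	burst    int           // cf. MaxChannelUpdateBurst
}

func (u *updateLimiter) allow(scid uint64, direction int) bool {
	u.mu.Lock()
	if u.limiters == nil {
		u.limiters = make(map[uint64][2]*rate.Limiter)
	}
	pair, ok := u.limiters[scid]
	if !ok {
		// `burst` updates may pass immediately; afterwards one
		// token is refilled per `interval`.
		pair = [2]*rate.Limiter{
			rate.NewLimiter(rate.Every(u.interval), u.burst),
			rate.NewLimiter(rate.Every(u.interval), u.burst),
		}
		u.limiters[scid] = pair
	}
	u.mu.Unlock()

	// Allow consumes one token. Because the gossiper now drops
	// duplicates as stale while holding the per-channel mutex, a
	// duplicate never reaches this point and cannot burn a token:
	// that is the invariant TestRateLimitDeDup pins down.
	return pair[direction].Allow()
}
```

With the burst of 2 configured in the test above, the de-duplicated pair consumes a single token, the next refreshed update consumes the second, and the final one finds the bucket empty, matching the closing assertBroadcast(false).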
// TestRateLimitChannelUpdates ensures that we properly rate limit incoming
// channel updates.
func TestRateLimitChannelUpdates(t *testing.T) {

docs/release-notes/release-notes-0.19.0.md

@@ -77,6 +77,11 @@
restarts, for details check [this
issue](https://github.com/lightningnetwork/lnd/issues/8975#issuecomment-2270528222).
* [Fix a bug](https://github.com/lightningnetwork/lnd/pull/9573) where
processing duplicate ChannelUpdates from different peers in quick succession
could lead to our ChannelUpdate rate limiting logic being prematurely
triggered.
# New Features
* [Support](https://github.com/lightningnetwork/lnd/pull/8390) for