diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 81c88ee3b..b34c63eb6 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -473,6 +473,9 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase, return next, nil }, Clock: clock.NewTestClock(time.Unix(1, 0)), + IsAlias: func(scid lnwire.ShortChannelID) bool { + return false + }, }) if err != nil { t.Fatalf("unable to create router %v", err) diff --git a/routing/router.go b/routing/router.go index 71e72c3d1..9e97291e0 100644 --- a/routing/router.go +++ b/routing/router.go @@ -358,6 +358,10 @@ type Config struct { // Otherwise, we'll only prune the channel when both edges have a very // dated last update. StrictZombiePruning bool + + // IsAlias returns whether a passed ShortChannelID is an alias. This is + // only used for our local channels. + IsAlias func(scid lnwire.ShortChannelID) bool } // EdgeLocator is a struct used to identify a specific edge. @@ -1455,8 +1459,12 @@ func (r *ChannelRouter) processUpdate(msg interface{}, // If AssumeChannelValid is present, then we are unable to // perform any of the expensive checks below, so we'll // short-circuit our path straight to adding the edge to our - // graph. - if r.cfg.AssumeChannelValid { + // graph. If the passed ShortChannelID is an alias, then we'll + // skip validation as it will not map to a legitimate tx. This + // is not a DoS vector as only we can add an alias + // ChannelAnnouncement from the gossiper. + scid := lnwire.NewShortChanIDFromInt(msg.ChannelID) + if r.cfg.AssumeChannelValid || r.cfg.IsAlias(scid) { if err := r.cfg.Graph.AddChannelEdge(msg, op...); err != nil { return fmt.Errorf("unable to add edge: %v", err) } diff --git a/routing/router_test.go b/routing/router_test.go index 4ffee7be7..a5425a3db 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -82,6 +82,9 @@ func (c *testCtx) RestartRouter(t *testing.T) { Control: makeMockControlTower(), ChannelPruneExpiry: time.Hour * 24, GraphPruneInterval: time.Hour * 2, + IsAlias: func(scid lnwire.ShortChannelID) bool { + return false + }, }) require.NoError(t, err, "unable to create router") require.NoError(t, router.Start(), "unable to start router") @@ -165,6 +168,9 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, Clock: clock.NewTestClock(time.Unix(1, 0)), AssumeChannelValid: assumeValid, StrictZombiePruning: strictPruning, + IsAlias: func(scid lnwire.ShortChannelID) bool { + return false + }, }) require.NoError(t, err, "unable to create router") require.NoError(t, router.Start(), "unable to start router") @@ -1738,6 +1744,10 @@ func TestWakeUpOnStaleBranch(t *testing.T) { // We'll set the delay to zero to prune immediately. FirstTimePruneDelay: 0, + + IsAlias: func(scid lnwire.ShortChannelID) bool { + return false + }, }) if err != nil { t.Fatalf("unable to create router %v", err) @@ -3485,6 +3495,10 @@ func TestSendMPPaymentSucceed(t *testing.T) { next := atomic.AddUint64(&uniquePaymentID, 1) return next, nil }, + + IsAlias: func(scid lnwire.ShortChannelID) bool { + return false + }, }) require.NoError(t, err, "failed to create router") @@ -3648,6 +3662,10 @@ func TestSendMPPaymentSucceedOnExtraShards(t *testing.T) { next := atomic.AddUint64(&uniquePaymentID, 1) return next, nil }, + + IsAlias: func(scid lnwire.ShortChannelID) bool { + return false + }, }) require.NoError(t, err, "failed to create router") @@ -3856,6 +3874,10 @@ func TestSendMPPaymentFailed(t *testing.T) { next := atomic.AddUint64(&uniquePaymentID, 1) return next, nil }, + + IsAlias: func(scid lnwire.ShortChannelID) bool { + return false + }, }) require.NoError(t, err, "failed to create router") @@ -4056,6 +4078,10 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) { next := atomic.AddUint64(&uniquePaymentID, 1) return next, nil }, + + IsAlias: func(scid lnwire.ShortChannelID) bool { + return false + }, }) require.NoError(t, err, "failed to create router") diff --git a/server.go b/server.go index da69c02b0..0f79612b9 100644 --- a/server.go +++ b/server.go @@ -920,6 +920,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, PathFindingConfig: pathFindingConfig, Clock: clock.NewDefaultClock(), StrictZombiePruning: strictPruning, + IsAlias: aliasmgr.IsAlias, }) if err != nil { return nil, fmt.Errorf("can't create router: %v", err)