diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index fb3a28458..f4e02c712 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -111,19 +111,32 @@ type fwdResolution struct { errChan chan error } -// NewInterceptableSwitch returns an instance of InterceptableSwitch. -func NewInterceptableSwitch(s *Switch, cltvRejectDelta uint32, - requireInterceptor bool) *InterceptableSwitch { +// InterceptableSwitchConfig contains the configuration of InterceptableSwitch. +type InterceptableSwitchConfig struct { + // Switch is a reference to the actual switch implementation that + // packets get sent to on resume. + Switch *Switch + // CltvRejectDelta defines the number of blocks before the expiry of the + // htlc where we no longer intercept it and instead cancel it back. + CltvRejectDelta uint32 + + // RequireInterceptor indicates whether processing should block if no + // interceptor is connected. + RequireInterceptor bool +} + +// NewInterceptableSwitch returns an instance of InterceptableSwitch. +func NewInterceptableSwitch(cfg *InterceptableSwitchConfig) *InterceptableSwitch { return &InterceptableSwitch{ - htlcSwitch: s, + htlcSwitch: cfg.Switch, intercepted: make(chan *interceptedPackets), onchainIntercepted: make(chan InterceptedForward), interceptorRegistration: make(chan ForwardInterceptor), holdForwards: make(map[channeldb.CircuitKey]InterceptedForward), resolutionChan: make(chan *fwdResolution), - requireInterceptor: requireInterceptor, - cltvRejectDelta: cltvRejectDelta, + requireInterceptor: cfg.RequireInterceptor, + cltvRejectDelta: cfg.CltvRejectDelta, quit: make(chan struct{}), } diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 61c7514ce..83389a374 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -3846,7 +3846,10 @@ func TestSwitchHoldForward(t *testing.T) { interceptedChan: make(chan InterceptedPacket), } switchForwardInterceptor := NewInterceptableSwitch( - s, cltvRejectDelta, false, + &InterceptableSwitchConfig{ + Switch: s, + CltvRejectDelta: cltvRejectDelta, + }, ) require.NoError(t, switchForwardInterceptor.Start()) @@ -4037,7 +4040,13 @@ func TestSwitchHoldForward(t *testing.T) { require.NoError(t, switchForwardInterceptor.Stop()) // Test always-on interception. - switchForwardInterceptor = NewInterceptableSwitch(s, cltvRejectDelta, true) + switchForwardInterceptor = NewInterceptableSwitch( + &InterceptableSwitchConfig{ + Switch: s, + CltvRejectDelta: cltvRejectDelta, + RequireInterceptor: true, + }, + ) require.NoError(t, switchForwardInterceptor.Start()) // Forward a fresh packet. It is expected to be failed immediately, @@ -5329,7 +5338,11 @@ func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) { t: t, interceptedChan: make(chan InterceptedPacket), } - interceptSwitch := NewInterceptableSwitch(s, 0, false) + interceptSwitch := NewInterceptableSwitch( + &InterceptableSwitchConfig{ + Switch: s, + }, + ) require.NoError(t, interceptSwitch.Start()) interceptSwitch.SetInterceptor(forwardInterceptor.InterceptForwardHtlc) diff --git a/peer/test_utils.go b/peer/test_utils.go index 938c3e51c..1bbde70b1 100644 --- a/peer/test_utils.go +++ b/peer/test_utils.go @@ -367,7 +367,9 @@ func createTestPeer(t *testing.T, notifier chainntnfs.ChainNotifier, Switch: mockSwitch, ChanActiveTimeout: chanActiveTimeout, InterceptSwitch: htlcswitch.NewInterceptableSwitch( - nil, testCltvRejectDelta, false, + &htlcswitch.InterceptableSwitchConfig{ + CltvRejectDelta: testCltvRejectDelta, + }, ), ChannelDB: dbAlice.ChannelStateDB(), FeeEstimator: estimator, diff --git a/server.go b/server.go index e3b86da6c..50b072357 100644 --- a/server.go +++ b/server.go @@ -667,8 +667,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return nil, err } s.interceptableSwitch = htlcswitch.NewInterceptableSwitch( - s.htlcSwitch, lncfg.DefaultFinalCltvRejectDelta, - s.cfg.RequireInterceptor, + &htlcswitch.InterceptableSwitchConfig{ + Switch: s.htlcSwitch, + CltvRejectDelta: lncfg.DefaultFinalCltvRejectDelta, + RequireInterceptor: s.cfg.RequireInterceptor, + }, ) s.witnessBeacon = newPreimageBeacon(