Merge pull request #6232 from bottlepay/require-interceptor

htlcswitch: add an always on mode to htlc interceptor
This commit is contained in:
Oliver Gugger 2022-03-19 10:47:22 +01:00 committed by GitHub
commit 57840bba36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 570 additions and 251 deletions

View File

@ -354,6 +354,10 @@ type Config struct {
RejectHTLC bool `long:"rejecthtlc" description:"If true, lnd will not forward any HTLCs that are meant as onward payments. This option will still allow lnd to send HTLCs and receive HTLCs but lnd won't be used as a hop."`
// RequireInterceptor determines whether the HTLC interceptor is
// registered regardless of whether the RPC is called or not.
RequireInterceptor bool `long:"requireinterceptor" description:"Whether to always intercept HTLCs, even if no stream is attached"`
StaggerInitialReconnect bool `long:"stagger-initial-reconnect" description:"If true, will apply a randomized staggering between 0s and 30s when reconnecting to persistent peers on startup. The first 10 reconnections will be attempted instantly, regardless of the flag's value"`
MaxOutgoingCltvExpiry uint32 `long:"max-cltv-expiry" description:"The maximum number of blocks funds could be locked up for when forwarding payments."`

View File

@ -109,6 +109,10 @@
change, it allows encrypted failure messages to be returned to the sender.
Additionally it is possible to signal a malformed htlc.
* Add an [always on](https://github.com/lightningnetwork/lnd/pull/6232) mode to
the HTLC interceptor API. This enables interception applications where every
packet must be intercepted.
## Database
* [Add ForAll implementation for etcd to speed up

View File

@ -30,63 +30,289 @@ var (
// Settle - routes UpdateFulfillHTLC to the originating link.
// Fail - routes UpdateFailHTLC to the originating link.
type InterceptableSwitch struct {
sync.RWMutex
// htlcSwitch is the underline switch
htlcSwitch *Switch
// fwdInterceptor is the callback that is called for each forward of
// an incoming htlc. It should return true if it is interested in handling
// it.
fwdInterceptor ForwardInterceptor
// intercepted is where we stream all intercepted packets coming from
// the switch.
intercepted chan *interceptedPackets
// resolutionChan is where we stream all responses coming from the
// interceptor client.
resolutionChan chan *fwdResolution
// interceptorRegistration is a channel that we use to synchronize
// client connect and disconnect.
interceptorRegistration chan ForwardInterceptor
// requireInterceptor indicates whether processing should block if no
// interceptor is connected.
requireInterceptor bool
// interceptor is the handler for intercepted packets.
interceptor ForwardInterceptor
// holdForwards keeps track of outstanding intercepted forwards.
holdForwards map[channeldb.CircuitKey]InterceptedForward
wg sync.WaitGroup
quit chan struct{}
}
type interceptedPackets struct {
packets []*htlcPacket
linkQuit chan struct{}
isReplay bool
}
// FwdAction defines the various resolution types.
type FwdAction int
const (
// FwdActionResume forwards the intercepted packet to the switch.
FwdActionResume FwdAction = iota
// FwdActionSettle settles the intercepted packet with a preimage.
FwdActionSettle
// FwdActionFail fails the intercepted packet back to the sender.
FwdActionFail
)
// FwdResolution defines the action to be taken on an intercepted packet.
type FwdResolution struct {
// Key is the incoming circuit key of the htlc.
Key channeldb.CircuitKey
// Action is the action to take on the intercepted htlc.
Action FwdAction
// Preimage is the preimage that is to be used for settling if Action is
// FwdActionSettle.
Preimage lntypes.Preimage
// FailureMessage is the encrypted failure message that is to be passed
// back to the sender if action is FwdActionFail.
FailureMessage []byte
// FailureCode is the failure code that is to be passed back to the
// sender if action is FwdActionFail.
FailureCode lnwire.FailCode
}
type fwdResolution struct {
resolution *FwdResolution
errChan chan error
}
// NewInterceptableSwitch returns an instance of InterceptableSwitch.
func NewInterceptableSwitch(s *Switch) *InterceptableSwitch {
return &InterceptableSwitch{htlcSwitch: s}
func NewInterceptableSwitch(s *Switch,
requireInterceptor bool) *InterceptableSwitch {
return &InterceptableSwitch{
htlcSwitch: s,
intercepted: make(chan *interceptedPackets),
interceptorRegistration: make(chan ForwardInterceptor),
holdForwards: make(map[channeldb.CircuitKey]InterceptedForward),
resolutionChan: make(chan *fwdResolution),
requireInterceptor: requireInterceptor,
quit: make(chan struct{}),
}
}
// SetInterceptor sets the ForwardInterceptor to be used.
// SetInterceptor sets the ForwardInterceptor to be used. A nil argument
// unregisters the current interceptor.
func (s *InterceptableSwitch) SetInterceptor(
interceptor ForwardInterceptor) {
s.Lock()
defer s.Unlock()
s.fwdInterceptor = interceptor
// Synchronize setting the handler with the main loop to prevent race
// conditions.
select {
case s.interceptorRegistration <- interceptor:
case <-s.quit:
}
}
// ForwardPackets attempts to forward the batch of htlcs through the
// switch, any failed packets will be returned to the provided
// ChannelLink. The link's quit signal should be provided to allow
// cancellation of forwarding during link shutdown.
func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{},
packets ...*htlcPacket) error {
func (s *InterceptableSwitch) Start() error {
s.wg.Add(1)
go func() {
defer s.wg.Done()
var interceptor ForwardInterceptor
s.Lock()
interceptor = s.fwdInterceptor
s.Unlock()
s.run()
}()
// Optimize for the case we don't have an interceptor.
if interceptor == nil {
return s.htlcSwitch.ForwardPackets(linkQuit, packets...)
}
return nil
}
var notIntercepted []*htlcPacket
for _, p := range packets {
if !s.interceptForward(p, interceptor, linkQuit) {
notIntercepted = append(notIntercepted, p)
func (s *InterceptableSwitch) Stop() error {
close(s.quit)
s.wg.Wait()
return nil
}
func (s *InterceptableSwitch) run() {
for {
select {
// An interceptor registration or de-registration came in.
case interceptor := <-s.interceptorRegistration:
s.setInterceptor(interceptor)
case packets := <-s.intercepted:
var notIntercepted []*htlcPacket
for _, p := range packets.packets {
if !s.interceptForward(p, packets.isReplay) {
notIntercepted = append(
notIntercepted, p,
)
}
}
err := s.htlcSwitch.ForwardPackets(
packets.linkQuit, notIntercepted...,
)
if err != nil {
log.Errorf("Cannot forward packets: %v", err)
}
case res := <-s.resolutionChan:
res.errChan <- s.resolve(res.resolution)
case <-s.quit:
return
}
}
return s.htlcSwitch.ForwardPackets(linkQuit, notIntercepted...)
}
func (s *InterceptableSwitch) sendForward(fwd InterceptedForward) {
err := s.interceptor(fwd.Packet())
if err != nil {
// Only log the error. If we couldn't send the packet, we assume
// that the interceptor will reconnect so that we can retry.
log.Debugf("Interceptor cannot handle forward: %v", err)
}
}
// interceptForward checks if there is any external interceptor interested in
// this packet. Currently only htlc type of UpdateAddHTLC that are forwarded
// are being checked for interception. It can be extended in the future given
// the right use case.
func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) {
s.interceptor = interceptor
// Replay all currently held htlcs. When an interceptor is not required,
// there may be none because they've been cleared after the previous
// disconnect.
if interceptor != nil {
log.Debugf("Interceptor connected")
for _, fwd := range s.holdForwards {
s.sendForward(fwd)
}
return
}
// The interceptor disconnects. If an interceptor is required, keep the
// held htlcs.
if s.requireInterceptor {
log.Infof("Interceptor disconnected, retaining held packets")
return
}
// Interceptor is not required. Release held forwards.
log.Infof("Interceptor disconnected, resolving held packets")
for _, fwd := range s.holdForwards {
if err := fwd.Resume(); err != nil {
log.Errorf("Failed to resume hold forward %v", err)
}
}
s.holdForwards = make(map[channeldb.CircuitKey]InterceptedForward)
}
func (s *InterceptableSwitch) resolve(res *FwdResolution) error {
intercepted, ok := s.holdForwards[res.Key]
if !ok {
return fmt.Errorf("fwd %v not found", res.Key)
}
delete(s.holdForwards, res.Key)
switch res.Action {
case FwdActionResume:
return intercepted.Resume()
case FwdActionSettle:
return intercepted.Settle(res.Preimage)
case FwdActionFail:
if len(res.FailureMessage) > 0 {
return intercepted.Fail(res.FailureMessage)
}
return intercepted.FailWithCode(res.FailureCode)
default:
return fmt.Errorf("unrecognized action %v", res.Action)
}
}
// Resolve resolves an intercepted packet.
func (s *InterceptableSwitch) Resolve(res *FwdResolution) error {
internalRes := &fwdResolution{
resolution: res,
errChan: make(chan error, 1),
}
select {
case s.resolutionChan <- internalRes:
case <-s.quit:
return errors.New("switch shutting down")
}
select {
case err := <-internalRes.errChan:
return err
case <-s.quit:
return errors.New("switch shutting down")
}
}
// ForwardPackets attempts to forward the batch of htlcs to a connected
// interceptor. If the interceptor signals the resume action, the htlcs are
// forwarded to the switch. The link's quit signal should be provided to allow
// cancellation of forwarding during link shutdown.
func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{}, isReplay bool,
packets ...*htlcPacket) error {
// Synchronize with the main event loop. This should be light in the
// case where there is no interceptor.
select {
case s.intercepted <- &interceptedPackets{
packets: packets,
linkQuit: linkQuit,
isReplay: isReplay,
}:
case <-linkQuit:
log.Debugf("Forward cancelled because link quit")
case <-s.quit:
return errors.New("interceptable switch quit")
}
return nil
}
// interceptForward forwards the packet to the external interceptor after
// checking the interception criteria.
func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
interceptor ForwardInterceptor, linkQuit chan struct{}) bool {
isReplay bool) bool {
// Process normally if an interceptor is not required and not
// registered.
if !s.requireInterceptor && s.interceptor == nil {
return false
}
switch htlc := packet.htlc.(type) {
case *lnwire.UpdateAddHTLC:
@ -95,15 +321,50 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
return false
}
inKey := channeldb.CircuitKey{
ChanID: packet.incomingChanID,
HtlcID: packet.incomingHTLCID,
}
// Ignore already held htlcs.
if _, ok := s.holdForwards[inKey]; ok {
return true
}
intercepted := &interceptedForward{
linkQuit: linkQuit,
htlc: htlc,
packet: packet,
htlcSwitch: s.htlcSwitch,
}
// If this htlc was intercepted, don't handle the forward.
return interceptor(intercepted)
if s.interceptor == nil && !isReplay {
// There is no interceptor registered, we are in
// interceptor-required mode, and this is a new packet
//
// Because the interceptor has never seen this packet
// yet, it is still safe to fail back. This limits the
// backlog of htlcs when the interceptor is down.
err := intercepted.FailWithCode(
lnwire.CodeTemporaryChannelFailure,
)
if err != nil {
log.Errorf("Cannot fail packet: %v", err)
}
return true
}
s.holdForwards[inKey] = intercepted
// If there is no interceptor registered, we must be in
// interceptor-required mode. The packet is kept in the queue
// until the interceptor registers itself.
if s.interceptor != nil {
s.sendForward(intercepted)
}
return true
default:
return false
}
@ -113,7 +374,6 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
// It is passed from the switch to external interceptors that are interested
// in holding forwards and resolve them manually.
type interceptedForward struct {
linkQuit chan struct{}
htlc *lnwire.UpdateAddHTLC
packet *htlcPacket
htlcSwitch *Switch
@ -139,10 +399,12 @@ func (f *interceptedForward) Packet() InterceptedPacket {
// Resume resumes the default behavior as if the packet was not intercepted.
func (f *interceptedForward) Resume() error {
return f.htlcSwitch.ForwardPackets(f.linkQuit, f.packet)
// Forward to the switch. A link quit channel isn't needed, because we
// are on a different thread now.
return f.htlcSwitch.ForwardPackets(nil, f.packet)
}
// Fail notifies the intention to fail an existing hold forward with an
// Fail notifies the intention to Fail an existing hold forward with an
// encrypted failure reason.
func (f *interceptedForward) Fail(reason []byte) error {
obfuscatedReason := f.packet.obfuscator.IntermediateEncrypt(reason)

View File

@ -234,6 +234,9 @@ type TowerClient interface {
type InterceptableHtlcForwarder interface {
// SetInterceptor sets a ForwardInterceptor.
SetInterceptor(interceptor ForwardInterceptor)
// Resolve resolves an intercepted packet.
Resolve(res *FwdResolution) error
}
// ForwardInterceptor is a function that is invoked from the switch for every
@ -242,7 +245,7 @@ type InterceptableHtlcForwarder interface {
// to resolve it manually later in case it is held.
// The return value indicates if this handler will take control of this forward
// and resolve it later or let the switch execute its default behavior.
type ForwardInterceptor func(InterceptedForward) bool
type ForwardInterceptor func(InterceptedPacket) error
// InterceptedPacket contains the relevant information for the interceptor about
// an htlc.

View File

@ -141,7 +141,7 @@ type ChannelLinkConfig struct {
// switch. The function returns and error in case it fails to send one or
// more packets. The link's quit signal should be provided to allow
// cancellation of forwarding during link shutdown.
ForwardPackets func(chan struct{}, ...*htlcPacket) error
ForwardPackets func(chan struct{}, bool, ...*htlcPacket) error
// DecodeHopIterators facilitates batched decoding of HTLC Sphinx onion
// blobs, which are then used to inform how to forward an HTLC.
@ -1720,7 +1720,7 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
l.uncommittedPreimages = append(l.uncommittedPreimages, pre)
// Pipeline this settle, send it to the switch.
go l.forwardBatch(settlePacket)
go l.forwardBatch(false, settlePacket)
case *lnwire.UpdateFailMalformedHTLC:
// Convert the failure type encoded within the HTLC fail
@ -2744,7 +2744,7 @@ func (l *channelLink) processRemoteSettleFails(fwdPkg *channeldb.FwdPkg,
// Only spawn the task forward packets we have a non-zero number.
if len(switchPackets) > 0 {
go l.forwardBatch(switchPackets...)
go l.forwardBatch(false, switchPackets...)
}
}
@ -3043,14 +3043,17 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
return
}
l.log.Debugf("forwarding %d packets to switch", len(switchPackets))
replay := fwdPkg.State != channeldb.FwdStateLockedIn
l.log.Debugf("forwarding %d packets to switch: replay=%v",
len(switchPackets), replay)
// NOTE: This call is made synchronous so that we ensure all circuits
// are committed in the exact order that they are processed in the link.
// Failing to do this could cause reorderings/gaps in the range of
// opened circuits, which violates assumptions made by the circuit
// trimming.
l.forwardBatch(switchPackets...)
l.forwardBatch(replay, switchPackets...)
}
// processExitHop handles an htlc for which this link is the exit hop. It
@ -3184,7 +3187,7 @@ func (l *channelLink) settleHTLC(preimage lntypes.Preimage,
// forwardBatch forwards the given htlcPackets to the switch, and waits on the
// err chan for the individual responses. This method is intended to be spawned
// as a goroutine so the responses can be handled in the background.
func (l *channelLink) forwardBatch(packets ...*htlcPacket) {
func (l *channelLink) forwardBatch(replay bool, packets ...*htlcPacket) {
// Don't forward packets for which we already have a response in our
// mailbox. This could happen if a packet fails and is buffered in the
// mailbox, and the incoming link flaps.
@ -3197,7 +3200,8 @@ func (l *channelLink) forwardBatch(packets ...*htlcPacket) {
filteredPkts = append(filteredPkts, pkt)
}
if err := l.cfg.ForwardPackets(l.quit, filteredPkts...); err != nil {
err := l.cfg.ForwardPackets(l.quit, replay, filteredPkts...)
if err != nil {
log.Errorf("Unhandled error while reforwarding htlc "+
"settle/fail over htlcswitch: %v", err)
}

View File

@ -1940,12 +1940,14 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
// the firing via force feeding.
bticker := ticker.NewForce(time.Hour)
aliceCfg := ChannelLinkConfig{
FwrdingPolicy: globalPolicy,
Peer: alicePeer,
Switch: aliceSwitch,
BestHeight: aliceSwitch.BestHeight,
Circuits: aliceSwitch.CircuitModifier(),
ForwardPackets: aliceSwitch.ForwardPackets,
FwrdingPolicy: globalPolicy,
Peer: alicePeer,
Switch: aliceSwitch,
BestHeight: aliceSwitch.BestHeight,
Circuits: aliceSwitch.CircuitModifier(),
ForwardPackets: func(linkQuit chan struct{}, _ bool, packets ...*htlcPacket) error {
return aliceSwitch.ForwardPackets(linkQuit, packets...)
},
DecodeHopIterators: decoder.DecodeHopIterators,
ExtractErrorEncrypter: func(*btcec.PublicKey) (
hop.ErrorEncrypter, lnwire.FailCode) {
@ -4491,12 +4493,14 @@ func (h *persistentLinkHarness) restartLink(
// the firing via force feeding.
bticker := ticker.NewForce(time.Hour)
aliceCfg := ChannelLinkConfig{
FwrdingPolicy: globalPolicy,
Peer: alicePeer,
Switch: aliceSwitch,
BestHeight: aliceSwitch.BestHeight,
Circuits: aliceSwitch.CircuitModifier(),
ForwardPackets: aliceSwitch.ForwardPackets,
FwrdingPolicy: globalPolicy,
Peer: alicePeer,
Switch: aliceSwitch,
BestHeight: aliceSwitch.BestHeight,
Circuits: aliceSwitch.CircuitModifier(),
ForwardPackets: func(linkQuit chan struct{}, _ bool, packets ...*htlcPacket) error {
return aliceSwitch.ForwardPackets(linkQuit, packets...)
},
DecodeHopIterators: decoder.DecodeHopIterators,
ExtractErrorEncrypter: func(*btcec.PublicKey) (
hop.ErrorEncrypter, lnwire.FailCode) {
@ -6694,7 +6698,7 @@ func TestPipelineSettle(t *testing.T) {
// erroneously forwarded. If the forwardChan is closed before the last
// step, then the test will fail.
forwardChan := make(chan struct{})
fwdPkts := func(c chan struct{}, hp ...*htlcPacket) error {
fwdPkts := func(c chan struct{}, _ bool, hp ...*htlcPacket) error {
close(forwardChan)
return nil
}

View File

@ -3140,32 +3140,29 @@ func getThreeHopEvents(channels *clusterChannels, htlcID uint64,
}
type mockForwardInterceptor struct {
intercepted InterceptedForward
t *testing.T
interceptedChan chan InterceptedPacket
}
func (m *mockForwardInterceptor) InterceptForwardHtlc(
intercepted InterceptedForward) bool {
intercepted InterceptedPacket) error {
m.intercepted = intercepted
return true
m.interceptedChan <- intercepted
return nil
}
func (m *mockForwardInterceptor) settle(preimage lntypes.Preimage) error {
return m.intercepted.Settle(preimage)
}
func (m *mockForwardInterceptor) getIntercepted() InterceptedPacket {
select {
case p := <-m.interceptedChan:
return p
func (m *mockForwardInterceptor) fail(reason []byte) error {
return m.intercepted.Fail(reason)
}
case <-time.After(time.Second):
require.Fail(m.t, "timeout")
func (m *mockForwardInterceptor) failWithCode(
code lnwire.FailCode) error {
return m.intercepted.FailWithCode(code)
}
func (m *mockForwardInterceptor) resume() error {
return m.intercepted.Resume()
return InterceptedPacket{}
}
}
func assertNumCircuits(t *testing.T, s *Switch, pending, opened int) {
@ -3272,22 +3269,28 @@ func TestSwitchHoldForward(t *testing.T) {
},
}
forwardInterceptor := &mockForwardInterceptor{}
switchForwardInterceptor := NewInterceptableSwitch(s)
forwardInterceptor := &mockForwardInterceptor{
t: t,
interceptedChan: make(chan InterceptedPacket),
}
switchForwardInterceptor := NewInterceptableSwitch(s, false)
require.NoError(t, switchForwardInterceptor.Start())
switchForwardInterceptor.SetInterceptor(forwardInterceptor.InterceptForwardHtlc)
linkQuit := make(chan struct{})
// Test resume a hold forward
// Test resume a hold forward.
assertNumCircuits(t, s, 0, 0)
if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil {
t.Fatalf("can't forward htlc packet: %v", err)
}
err = switchForwardInterceptor.ForwardPackets(linkQuit, false, ogPacket)
require.NoError(t, err)
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, bobChannelLink, false)
if err := forwardInterceptor.resume(); err != nil {
t.Fatalf("failed to resume forward")
}
require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{
Action: FwdActionResume,
Key: forwardInterceptor.getIntercepted().IncomingCircuit,
}))
assertOutgoingLinkReceive(t, bobChannelLink, true)
assertNumCircuits(t, s, 1, 1)
@ -3300,35 +3303,72 @@ func TestSwitchHoldForward(t *testing.T) {
PaymentPreimage: preimage,
},
}
if err := switchForwardInterceptor.ForwardPackets(linkQuit, settle); err != nil {
t.Fatalf("can't forward htlc packet: %v", err)
}
err = switchForwardInterceptor.ForwardPackets(linkQuit, false, settle)
require.NoError(t, err)
assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertNumCircuits(t, s, 0, 0)
// Test resume a hold forward after disconnection.
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false, ogPacket,
))
// Wait until the packet is offered to the interceptor.
_ = forwardInterceptor.getIntercepted()
// No forward expected yet.
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, bobChannelLink, false)
// Disconnect should resume the forwarding.
switchForwardInterceptor.SetInterceptor(nil)
assertOutgoingLinkReceive(t, bobChannelLink, true)
assertNumCircuits(t, s, 1, 1)
// Settle the htlc to close the circuit.
settle.outgoingHTLCID = 1
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false, settle,
))
assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertNumCircuits(t, s, 0, 0)
// Test failing a hold forward
if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil {
t.Fatalf("can't forward htlc packet: %v", err)
}
switchForwardInterceptor.SetInterceptor(
forwardInterceptor.InterceptForwardHtlc,
)
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false, ogPacket,
))
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, bobChannelLink, false)
if err := forwardInterceptor.fail(nil); err != nil {
t.Fatalf("failed to cancel forward %v", err)
}
require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{
Action: FwdActionFail,
Key: forwardInterceptor.getIntercepted().IncomingCircuit,
FailureCode: lnwire.CodeTemporaryChannelFailure,
}))
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertNumCircuits(t, s, 0, 0)
// Test failing a hold forward with a failure message.
require.NoError(t,
switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket),
switchForwardInterceptor.ForwardPackets(linkQuit, false, ogPacket),
)
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, bobChannelLink, false)
reason := lnwire.OpaqueReason([]byte{1, 2, 3})
require.NoError(t, forwardInterceptor.fail(reason))
require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{
Action: FwdActionFail,
Key: forwardInterceptor.getIntercepted().IncomingCircuit,
FailureMessage: reason,
}))
assertOutgoingLinkReceive(t, bobChannelLink, false)
packet := assertOutgoingLinkReceive(t, aliceChannelLink, true)
@ -3338,14 +3378,18 @@ func TestSwitchHoldForward(t *testing.T) {
assertNumCircuits(t, s, 0, 0)
// Test failing a hold forward with a malformed htlc failure.
err = switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket)
err = switchForwardInterceptor.ForwardPackets(linkQuit, false, ogPacket)
require.NoError(t, err)
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, bobChannelLink, false)
code := lnwire.CodeInvalidOnionKey
require.NoError(t, forwardInterceptor.failWithCode(code))
require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{
Action: FwdActionFail,
Key: forwardInterceptor.getIntercepted().IncomingCircuit,
FailureCode: code,
}))
assertOutgoingLinkReceive(t, bobChannelLink, false)
packet = assertOutgoingLinkReceive(t, aliceChannelLink, true)
@ -3363,18 +3407,89 @@ func TestSwitchHoldForward(t *testing.T) {
assertNumCircuits(t, s, 0, 0)
// Test settling a hold forward
if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil {
t.Fatalf("can't forward htlc packet: %v", err)
}
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false, ogPacket,
))
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, bobChannelLink, false)
if err := forwardInterceptor.settle(preimage); err != nil {
t.Fatal("failed to cancel forward")
}
require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{
Key: forwardInterceptor.getIntercepted().IncomingCircuit,
Action: FwdActionSettle,
Preimage: preimage,
}))
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertNumCircuits(t, s, 0, 0)
require.NoError(t, switchForwardInterceptor.Stop())
// Test always-on interception.
switchForwardInterceptor = NewInterceptableSwitch(s, true)
require.NoError(t, switchForwardInterceptor.Start())
// Forward a fresh packet. It is expected to be failed immediately,
// because there is no interceptor registered.
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false, ogPacket,
))
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertNumCircuits(t, s, 0, 0)
// Forward a replayed packet. It is expected to be held until the
// interceptor connects. To continue the test, it needs to be ran in a
// goroutine.
errChan := make(chan error)
go func() {
errChan <- switchForwardInterceptor.ForwardPackets(
linkQuit, true, ogPacket,
)
}()
// Assert that nothing is forward to the switch.
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertNumCircuits(t, s, 0, 0)
// Register an interceptor.
switchForwardInterceptor.SetInterceptor(
forwardInterceptor.InterceptForwardHtlc,
)
// Expect the ForwardPackets call to unblock.
require.NoError(t, <-errChan)
// Now expect the queued packet to come through.
forwardInterceptor.getIntercepted()
// Disconnect and reconnect interceptor.
switchForwardInterceptor.SetInterceptor(nil)
switchForwardInterceptor.SetInterceptor(
forwardInterceptor.InterceptForwardHtlc,
)
// A replay of the held packet is expected.
intercepted := forwardInterceptor.getIntercepted()
// Settle the packet.
require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{
Key: intercepted.IncomingCircuit,
Action: FwdActionSettle,
Preimage: preimage,
}))
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertNumCircuits(t, s, 0, 0)
require.NoError(t, switchForwardInterceptor.Stop())
select {
case <-forwardInterceptor.interceptedChan:
require.Fail(t, "unexpected interception")
default:
}
}
// TestSwitchDustForwarding tests that the switch properly fails HTLC's which

View File

@ -1135,12 +1135,14 @@ func (h *hopNetwork) createChannelLink(server, peer *mockServer,
link := NewChannelLink(
ChannelLinkConfig{
Switch: server.htlcSwitch,
BestHeight: server.htlcSwitch.BestHeight,
FwrdingPolicy: h.globalPolicy,
Peer: peer,
Circuits: server.htlcSwitch.CircuitModifier(),
ForwardPackets: server.htlcSwitch.ForwardPackets,
Switch: server.htlcSwitch,
BestHeight: server.htlcSwitch.BestHeight,
FwrdingPolicy: h.globalPolicy,
Peer: peer,
Circuits: server.htlcSwitch.CircuitModifier(),
ForwardPackets: func(linkQuit chan struct{}, _ bool, packets ...*htlcPacket) error {
return server.htlcSwitch.ForwardPackets(linkQuit, packets...)
},
DecodeHopIterators: decoder.DecodeHopIterators,
ExtractErrorEncrypter: func(*btcec.PublicKey) (
hop.ErrorEncrypter, lnwire.FailCode) {

View File

@ -2,7 +2,6 @@ package routerrpc
import (
"errors"
"sync"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/htlcswitch"
@ -27,36 +26,19 @@ var (
// interceptor streaming session.
// It is created when the stream opens and disconnects when the stream closes.
type forwardInterceptor struct {
// server is the Server reference
server *Server
// holdForwards is a map of current hold forwards and their corresponding
// ForwardResolver.
holdForwards map[channeldb.CircuitKey]htlcswitch.InterceptedForward
// stream is the bidirectional RPC stream
stream Router_HtlcInterceptorServer
// quit is a channel that is closed when this forwardInterceptor is shutting
// down.
quit chan struct{}
// intercepted is where we stream all intercepted packets coming from
// the switch.
intercepted chan htlcswitch.InterceptedForward
wg sync.WaitGroup
htlcSwitch htlcswitch.InterceptableHtlcForwarder
}
// newForwardInterceptor creates a new forwardInterceptor.
func newForwardInterceptor(server *Server, stream Router_HtlcInterceptorServer) *forwardInterceptor {
func newForwardInterceptor(htlcSwitch htlcswitch.InterceptableHtlcForwarder,
stream Router_HtlcInterceptorServer) *forwardInterceptor {
return &forwardInterceptor{
server: server,
stream: stream,
holdForwards: make(
map[channeldb.CircuitKey]htlcswitch.InterceptedForward),
quit: make(chan struct{}),
intercepted: make(chan htlcswitch.InterceptedForward),
htlcSwitch: htlcSwitch,
stream: stream,
}
}
@ -67,42 +49,18 @@ func newForwardInterceptor(server *Server, stream Router_HtlcInterceptorServer)
// To coordinate all this and make sure it is safe for concurrent access all
// packets are sent to the main where they are handled.
func (r *forwardInterceptor) run() error {
// make sure we disconnect and resolves all remaining packets if any.
defer r.onDisconnect()
// Register our interceptor so we receive all forwarded packets.
interceptableForwarder := r.server.cfg.RouterBackend.InterceptableForwarder
interceptableForwarder.SetInterceptor(r.onIntercept)
defer interceptableForwarder.SetInterceptor(nil)
r.htlcSwitch.SetInterceptor(r.onIntercept)
defer r.htlcSwitch.SetInterceptor(nil)
// start a go routine that reads client resolutions.
errChan := make(chan error)
resolutionRequests := make(chan *ForwardHtlcInterceptResponse)
r.wg.Add(1)
go r.readClientResponses(resolutionRequests, errChan)
// run the main loop that synchronizes both sides input into one go routine.
for {
select {
case intercepted := <-r.intercepted:
log.Tracef("sending intercepted packet to client %v", intercepted)
// in case we couldn't forward we exit the loop and drain the
// current interceptor as this indicates on a connection problem.
if err := r.holdAndForwardToClient(intercepted); err != nil {
return err
}
case resolution := <-resolutionRequests:
log.Tracef("resolving intercepted packet %v", resolution)
// in case we couldn't resolve we just add a log line since this
// does not indicate on any connection problem.
if err := r.resolveFromClient(resolution); err != nil {
log.Warnf("client resolution of intercepted "+
"packet failed %v", err)
}
case err := <-errChan:
resp, err := r.stream.Recv()
if err != nil {
return err
}
if err := r.resolveFromClient(resp); err != nil {
return err
case <-r.server.quit:
return nil
}
}
}
@ -111,54 +69,14 @@ func (r *forwardInterceptor) run() error {
// packet. Our interceptor makes sure we hold the packet and then signal to the
// main loop to handle the packet. We only return true if we were able
// to deliver the packet to the main loop.
func (r *forwardInterceptor) onIntercept(p htlcswitch.InterceptedForward) bool {
select {
case r.intercepted <- p:
return true
case <-r.quit:
return false
case <-r.server.quit:
return false
}
}
func (r *forwardInterceptor) onIntercept(
htlc htlcswitch.InterceptedPacket) error {
func (r *forwardInterceptor) readClientResponses(
resolutionChan chan *ForwardHtlcInterceptResponse, errChan chan error) {
log.Tracef("Sending intercepted packet to client %v", htlc)
defer r.wg.Done()
for {
resp, err := r.stream.Recv()
if err != nil {
errChan <- err
return
}
// Now that we have the response from the RPC client, send it to
// the responses chan.
select {
case resolutionChan <- resp:
case <-r.quit:
return
case <-r.server.quit:
return
}
}
}
// holdAndForwardToClient forwards the intercepted htlc to the client.
func (r *forwardInterceptor) holdAndForwardToClient(
forward htlcswitch.InterceptedForward) error {
htlc := forward.Packet()
inKey := htlc.IncomingCircuit
// Ignore already held htlcs.
if _, ok := r.holdForwards[inKey]; ok {
return nil
}
// First hold the forward, then send to client.
r.holdForwards[inKey] = forward
interceptionRequest := &ForwardHtlcInterceptRequest{
IncomingCircuitKey: &CircuitKey{
ChanId: inKey.ChanID.ToUint64(),
@ -181,20 +99,19 @@ func (r *forwardInterceptor) holdAndForwardToClient(
func (r *forwardInterceptor) resolveFromClient(
in *ForwardHtlcInterceptResponse) error {
log.Tracef("Resolving intercepted packet %v", in)
circuitKey := channeldb.CircuitKey{
ChanID: lnwire.NewShortChanIDFromInt(in.IncomingCircuitKey.ChanId),
HtlcID: in.IncomingCircuitKey.HtlcId,
}
var interceptedForward htlcswitch.InterceptedForward
interceptedForward, ok := r.holdForwards[circuitKey]
if !ok {
return ErrFwdNotExists
}
delete(r.holdForwards, circuitKey)
switch in.Action {
case ResolveHoldForwardAction_RESUME:
return interceptedForward.Resume()
return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{
Key: circuitKey,
Action: htlcswitch.FwdActionResume,
})
case ResolveHoldForwardAction_FAIL:
// Fail with an encrypted reason.
@ -219,7 +136,11 @@ func (r *forwardInterceptor) resolveFromClient(
)
}
return interceptedForward.Fail(in.FailureMessage)
return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{
Key: circuitKey,
Action: htlcswitch.FwdActionFail,
FailureMessage: in.FailureMessage,
})
}
var code lnwire.FailCode
@ -244,14 +165,11 @@ func (r *forwardInterceptor) resolveFromClient(
)
}
err := interceptedForward.FailWithCode(code)
if err == htlcswitch.ErrUnsupportedFailureCode {
return status.Errorf(
codes.InvalidArgument, err.Error(),
)
}
return err
return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{
Key: circuitKey,
Action: htlcswitch.FwdActionFail,
FailureCode: code,
})
case ResolveHoldForwardAction_SETTLE:
if in.Preimage == nil {
@ -261,7 +179,12 @@ func (r *forwardInterceptor) resolveFromClient(
if err != nil {
return err
}
return interceptedForward.Settle(preimage)
return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{
Key: circuitKey,
Action: htlcswitch.FwdActionSettle,
Preimage: preimage,
})
default:
return status.Errorf(
@ -270,20 +193,3 @@ func (r *forwardInterceptor) resolveFromClient(
)
}
}
// onDisconnect removes all previousely held forwards from
// the store. Before they are removed it ensure to resume as the default
// behavior.
func (r *forwardInterceptor) onDisconnect() {
// Then close the channel so all go routine will exit.
close(r.quit)
log.Infof("RPC interceptor disconnected, resolving held packets")
for key, forward := range r.holdForwards {
if err := forward.Resume(); err != nil {
log.Errorf("failed to resume hold forward %v", err)
}
delete(r.holdForwards, key)
}
r.wg.Wait()
}

View File

@ -890,7 +890,9 @@ func (s *Server) HtlcInterceptor(stream Router_HtlcInterceptorServer) error {
defer atomic.CompareAndSwapInt32(&s.forwardInterceptorActive, 1, 0)
// run the forward interceptor.
return newForwardInterceptor(s, stream).run()
return newForwardInterceptor(
s.cfg.RouterBackend.InterceptableForwarder, stream,
).run()
}
func extractOutPoint(req *UpdateChanStatusRequest) (*wire.OutPoint, error) {

View File

@ -367,7 +367,9 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
Switch: mockSwitch,
ChanActiveTimeout: chanActiveTimeout,
InterceptSwitch: htlcswitch.NewInterceptableSwitch(nil),
InterceptSwitch: htlcswitch.NewInterceptableSwitch(
nil, false,
),
ChannelDB: dbAlice.ChannelStateDB(),
FeeEstimator: estimator,

View File

@ -347,6 +347,9 @@
; used as a hop.
; rejecthtlc=true
; If true, all HTLCs will be held until they are handled by an interceptor
; requireinterceptor=true
; If true, will apply a randomized staggering between 0s and 30s when
; reconnecting to persistent peers on startup. The first 10 reconnections will be
; attempted instantly, regardless of the flag's value

View File

@ -654,7 +654,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
if err != nil {
return nil, err
}
s.interceptableSwitch = htlcswitch.NewInterceptableSwitch(s.htlcSwitch)
s.interceptableSwitch = htlcswitch.NewInterceptableSwitch(
s.htlcSwitch, s.cfg.RequireInterceptor,
)
chanStatusMgrCfg := &netann.ChanStatusConfig{
ChanStatusSampleInterval: cfg.ChanStatusSampleInterval,
@ -1786,6 +1788,12 @@ func (s *server) Start() error {
}
cleanup = cleanup.add(s.htlcSwitch.Stop)
if err := s.interceptableSwitch.Start(); err != nil {
startErr = err
return
}
cleanup = cleanup.add(s.interceptableSwitch.Stop)
if err := s.chainArb.Start(); err != nil {
startErr = err
return