diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index 72e8ae3b7..13f70dcf2 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -30,64 +30,260 @@ var ( // Settle - routes UpdateFulfillHTLC to the originating link. // Fail - routes UpdateFailHTLC to the originating link. type InterceptableSwitch struct { - sync.RWMutex - // htlcSwitch is the underline switch htlcSwitch *Switch - // fwdInterceptor is the callback that is called for each forward of - // an incoming htlc. It should return true if it is interested in handling - // it. - fwdInterceptor ForwardInterceptor + // intercepted is where we stream all intercepted packets coming from + // the switch. + intercepted chan *interceptedPackets + + // resolutionChan is where we stream all responses coming from the + // interceptor client. + resolutionChan chan *fwdResolution + + // interceptorRegistration is a channel that we use to synchronize + // client connect and disconnect. + interceptorRegistration chan ForwardInterceptor + + // interceptor is the handler for intercepted packets. + interceptor ForwardInterceptor + + // holdForwards keeps track of outstanding intercepted forwards. + holdForwards map[channeldb.CircuitKey]InterceptedForward + + wg sync.WaitGroup + quit chan struct{} +} + +type interceptedPackets struct { + packets []*htlcPacket + linkQuit chan struct{} +} + +// FwdAction defines the various resolution types. +type FwdAction int + +const ( + // FwdActionResume forwards the intercepted packet to the switch. + FwdActionResume FwdAction = iota + + // FwdActionSettle settles the intercepted packet with a preimage. + FwdActionSettle + + // FwdActionFail fails the intercepted packet back to the sender. + FwdActionFail +) + +// FwdResolution defines the action to be taken on an intercepted packet. +type FwdResolution struct { + // Key is the incoming circuit key of the htlc. + Key channeldb.CircuitKey + + // Action is the action to take on the intercepted htlc. + Action FwdAction + + // Preimage is the preimage that is to be used for settling if Action is + // FwdActionSettle. + Preimage lntypes.Preimage + + // FailureMessage is the encrypted failure message that is to be passed + // back to the sender if action is FwdActionFail. + FailureMessage []byte + + // FailureCode is the failure code that is to be passed back to the + // sender if action is FwdActionFail. + FailureCode lnwire.FailCode +} + +type fwdResolution struct { + resolution *FwdResolution + errChan chan error } // NewInterceptableSwitch returns an instance of InterceptableSwitch. func NewInterceptableSwitch(s *Switch) *InterceptableSwitch { - return &InterceptableSwitch{htlcSwitch: s} + return &InterceptableSwitch{ + htlcSwitch: s, + intercepted: make(chan *interceptedPackets), + interceptorRegistration: make(chan ForwardInterceptor), + holdForwards: make(map[channeldb.CircuitKey]InterceptedForward), + resolutionChan: make(chan *fwdResolution), + + quit: make(chan struct{}), + } } -// SetInterceptor sets the ForwardInterceptor to be used. +// SetInterceptor sets the ForwardInterceptor to be used. A nil argument +// unregisters the current interceptor. func (s *InterceptableSwitch) SetInterceptor( interceptor ForwardInterceptor) { - s.Lock() - defer s.Unlock() - s.fwdInterceptor = interceptor + // Synchronize setting the handler with the main loop to prevent race + // conditions. + select { + case s.interceptorRegistration <- interceptor: + + case <-s.quit: + } } -// ForwardPackets attempts to forward the batch of htlcs through the -// switch, any failed packets will be returned to the provided -// ChannelLink. The link's quit signal should be provided to allow +func (s *InterceptableSwitch) Start() error { + s.wg.Add(1) + go func() { + defer s.wg.Done() + + s.run() + }() + + return nil +} + +func (s *InterceptableSwitch) Stop() error { + close(s.quit) + s.wg.Wait() + + return nil +} + +func (s *InterceptableSwitch) run() { + for { + select { + // An interceptor registration or de-registration came in. + case interceptor := <-s.interceptorRegistration: + s.setInterceptor(interceptor) + + case packets := <-s.intercepted: + var notIntercepted []*htlcPacket + for _, p := range packets.packets { + if s.interceptor == nil || + !s.interceptForward(p) { + + notIntercepted = append( + notIntercepted, p, + ) + } + } + err := s.htlcSwitch.ForwardPackets( + packets.linkQuit, notIntercepted..., + ) + if err != nil { + log.Errorf("Cannot forward packets: %v", err) + } + + case res := <-s.resolutionChan: + res.errChan <- s.resolve(res.resolution) + + case <-s.quit: + return + } + } +} + +func (s *InterceptableSwitch) sendForward(fwd InterceptedForward) { + err := s.interceptor(fwd.Packet()) + if err != nil { + // Only log the error. If we couldn't send the packet, we assume + // that the interceptor will reconnect so that we can retry. + log.Debugf("Interceptor cannot handle forward: %v", err) + } +} + +func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) { + s.interceptor = interceptor + + if interceptor != nil { + log.Debugf("Interceptor connected") + + return + } + + log.Infof("Interceptor disconnected, resolving held packets") + + for _, fwd := range s.holdForwards { + if err := fwd.Resume(); err != nil { + log.Errorf("Failed to resume hold forward %v", err) + } + } + s.holdForwards = make(map[channeldb.CircuitKey]InterceptedForward) +} + +func (s *InterceptableSwitch) resolve(res *FwdResolution) error { + intercepted, ok := s.holdForwards[res.Key] + if !ok { + return fmt.Errorf("fwd %v not found", res.Key) + } + delete(s.holdForwards, res.Key) + + switch res.Action { + case FwdActionResume: + return intercepted.Resume() + + case FwdActionSettle: + return intercepted.Settle(res.Preimage) + + case FwdActionFail: + if len(res.FailureMessage) > 0 { + return intercepted.Fail(res.FailureMessage) + } + + return intercepted.FailWithCode(res.FailureCode) + + default: + return fmt.Errorf("unrecognized action %v", res.Action) + } +} + +// Resolve resolves an intercepted packet. +func (s *InterceptableSwitch) Resolve(res *FwdResolution) error { + internalRes := &fwdResolution{ + resolution: res, + errChan: make(chan error, 1), + } + + select { + case s.resolutionChan <- internalRes: + + case <-s.quit: + return errors.New("switch shutting down") + } + + select { + case err := <-internalRes.errChan: + return err + + case <-s.quit: + return errors.New("switch shutting down") + } +} + +// ForwardPackets attempts to forward the batch of htlcs to a connected +// interceptor. If the interceptor signals the resume action, the htlcs are +// forwarded to the switch. The link's quit signal should be provided to allow // cancellation of forwarding during link shutdown. func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{}, packets ...*htlcPacket) error { - var interceptor ForwardInterceptor - s.Lock() - interceptor = s.fwdInterceptor - s.Unlock() + // Synchronize with the main event loop. This should be light in the + // case where there is no interceptor. + select { + case s.intercepted <- &interceptedPackets{ + packets: packets, + linkQuit: linkQuit, + }: - // Optimize for the case we don't have an interceptor. - if interceptor == nil { - return s.htlcSwitch.ForwardPackets(linkQuit, packets...) + case <-linkQuit: + log.Debugf("Forward cancelled because link quit") + + case <-s.quit: + return errors.New("interceptable switch quit") } - var notIntercepted []*htlcPacket - for _, p := range packets { - if !s.interceptForward(p, interceptor, linkQuit) { - notIntercepted = append(notIntercepted, p) - } - } - return s.htlcSwitch.ForwardPackets(linkQuit, notIntercepted...) + return nil } -// interceptForward checks if there is any external interceptor interested in -// this packet. Currently only htlc type of UpdateAddHTLC that are forwarded -// are being checked for interception. It can be extended in the future given -// the right use case. -func (s *InterceptableSwitch) interceptForward(packet *htlcPacket, - interceptor ForwardInterceptor, linkQuit chan struct{}) bool { - +// interceptForward forwards the packet to the external interceptor after +// checking the interception criteria. +func (s *InterceptableSwitch) interceptForward(packet *htlcPacket) bool { switch htlc := packet.htlc.(type) { case *lnwire.UpdateAddHTLC: // We are not interested in intercepting initiated payments. @@ -95,15 +291,28 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket, return false } + inKey := channeldb.CircuitKey{ + ChanID: packet.incomingChanID, + HtlcID: packet.incomingHTLCID, + } + + // Ignore already held htlcs. + if _, ok := s.holdForwards[inKey]; ok { + return true + } + intercepted := &interceptedForward{ - linkQuit: linkQuit, htlc: htlc, packet: packet, htlcSwitch: s.htlcSwitch, } - // If this htlc was intercepted, don't handle the forward. - return interceptor(intercepted) + s.holdForwards[inKey] = intercepted + + s.sendForward(intercepted) + + return true + default: return false } @@ -113,7 +322,6 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket, // It is passed from the switch to external interceptors that are interested // in holding forwards and resolve them manually. type interceptedForward struct { - linkQuit chan struct{} htlc *lnwire.UpdateAddHTLC packet *htlcPacket htlcSwitch *Switch @@ -139,10 +347,12 @@ func (f *interceptedForward) Packet() InterceptedPacket { // Resume resumes the default behavior as if the packet was not intercepted. func (f *interceptedForward) Resume() error { - return f.htlcSwitch.ForwardPackets(f.linkQuit, f.packet) + // Forward to the switch. A link quit channel isn't needed, because we + // are on a different thread now. + return f.htlcSwitch.ForwardPackets(nil, f.packet) } -// Fail notifies the intention to fail an existing hold forward with an +// Fail notifies the intention to Fail an existing hold forward with an // encrypted failure reason. func (f *interceptedForward) Fail(reason []byte) error { obfuscatedReason := f.packet.obfuscator.IntermediateEncrypt(reason) diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 4b201ee1e..1b80adab3 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -234,6 +234,9 @@ type TowerClient interface { type InterceptableHtlcForwarder interface { // SetInterceptor sets a ForwardInterceptor. SetInterceptor(interceptor ForwardInterceptor) + + // Resolve resolves an intercepted packet. + Resolve(res *FwdResolution) error } // ForwardInterceptor is a function that is invoked from the switch for every @@ -242,7 +245,7 @@ type InterceptableHtlcForwarder interface { // to resolve it manually later in case it is held. // The return value indicates if this handler will take control of this forward // and resolve it later or let the switch execute its default behavior. -type ForwardInterceptor func(InterceptedForward) bool +type ForwardInterceptor func(InterceptedPacket) error // InterceptedPacket contains the relevant information for the interceptor about // an htlc. diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index ffbca4907..4c22b5e8e 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -3140,32 +3140,29 @@ func getThreeHopEvents(channels *clusterChannels, htlcID uint64, } type mockForwardInterceptor struct { - intercepted InterceptedForward + t *testing.T + + interceptedChan chan InterceptedPacket } func (m *mockForwardInterceptor) InterceptForwardHtlc( - intercepted InterceptedForward) bool { + intercepted InterceptedPacket) error { - m.intercepted = intercepted - return true + m.interceptedChan <- intercepted + + return nil } -func (m *mockForwardInterceptor) settle(preimage lntypes.Preimage) error { - return m.intercepted.Settle(preimage) -} +func (m *mockForwardInterceptor) getIntercepted() InterceptedPacket { + select { + case p := <-m.interceptedChan: + return p -func (m *mockForwardInterceptor) fail(reason []byte) error { - return m.intercepted.Fail(reason) -} + case <-time.After(time.Second): + require.Fail(m.t, "timeout") -func (m *mockForwardInterceptor) failWithCode( - code lnwire.FailCode) error { - - return m.intercepted.FailWithCode(code) -} - -func (m *mockForwardInterceptor) resume() error { - return m.intercepted.Resume() + return InterceptedPacket{} + } } func assertNumCircuits(t *testing.T, s *Switch, pending, opened int) { @@ -3272,12 +3269,17 @@ func TestSwitchHoldForward(t *testing.T) { }, } - forwardInterceptor := &mockForwardInterceptor{} + forwardInterceptor := &mockForwardInterceptor{ + t: t, + interceptedChan: make(chan InterceptedPacket), + } switchForwardInterceptor := NewInterceptableSwitch(s) + require.NoError(t, switchForwardInterceptor.Start()) + switchForwardInterceptor.SetInterceptor(forwardInterceptor.InterceptForwardHtlc) linkQuit := make(chan struct{}) - // Test resume a hold forward + // Test resume a hold forward. assertNumCircuits(t, s, 0, 0) if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil { t.Fatalf("can't forward htlc packet: %v", err) @@ -3285,9 +3287,10 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, s, 0, 0) assertOutgoingLinkReceive(t, bobChannelLink, false) - if err := forwardInterceptor.resume(); err != nil { - t.Fatalf("failed to resume forward") - } + require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{ + Action: FwdActionResume, + Key: forwardInterceptor.getIntercepted().IncomingCircuit, + })) assertOutgoingLinkReceive(t, bobChannelLink, true) assertNumCircuits(t, s, 1, 1) @@ -3306,16 +3309,46 @@ func TestSwitchHoldForward(t *testing.T) { assertOutgoingLinkReceive(t, aliceChannelLink, true) assertNumCircuits(t, s, 0, 0) + // Test resume a hold forward after disconnection. + err = switchForwardInterceptor.ForwardPackets(nil, ogPacket) + require.NoError(t, err) + + // Wait until the packet is offered to the interceptor. + _ = forwardInterceptor.getIntercepted() + + // No forward expected yet. + assertNumCircuits(t, s, 0, 0) + assertOutgoingLinkReceive(t, bobChannelLink, false) + + // Disconnect should resume the forwarding. + switchForwardInterceptor.SetInterceptor(nil) + + assertOutgoingLinkReceive(t, bobChannelLink, true) + assertNumCircuits(t, s, 1, 1) + + // Settle the htlc to close the circuit. + settle.outgoingHTLCID = 1 + require.NoError(t, switchForwardInterceptor.ForwardPackets(nil, settle)) + + assertOutgoingLinkReceive(t, aliceChannelLink, true) + assertNumCircuits(t, s, 0, 0) + // Test failing a hold forward + switchForwardInterceptor.SetInterceptor( + forwardInterceptor.InterceptForwardHtlc, + ) + if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil { t.Fatalf("can't forward htlc packet: %v", err) } assertNumCircuits(t, s, 0, 0) assertOutgoingLinkReceive(t, bobChannelLink, false) - if err := forwardInterceptor.fail(nil); err != nil { - t.Fatalf("failed to cancel forward %v", err) - } + require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{ + Action: FwdActionFail, + Key: forwardInterceptor.getIntercepted().IncomingCircuit, + FailureCode: lnwire.CodeTemporaryChannelFailure, + })) assertOutgoingLinkReceive(t, bobChannelLink, false) assertOutgoingLinkReceive(t, aliceChannelLink, true) assertNumCircuits(t, s, 0, 0) @@ -3328,7 +3361,11 @@ func TestSwitchHoldForward(t *testing.T) { assertOutgoingLinkReceive(t, bobChannelLink, false) reason := lnwire.OpaqueReason([]byte{1, 2, 3}) - require.NoError(t, forwardInterceptor.fail(reason)) + require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{ + Action: FwdActionFail, + Key: forwardInterceptor.getIntercepted().IncomingCircuit, + FailureMessage: reason, + })) assertOutgoingLinkReceive(t, bobChannelLink, false) packet := assertOutgoingLinkReceive(t, aliceChannelLink, true) @@ -3345,7 +3382,11 @@ func TestSwitchHoldForward(t *testing.T) { assertOutgoingLinkReceive(t, bobChannelLink, false) code := lnwire.CodeInvalidOnionKey - require.NoError(t, forwardInterceptor.failWithCode(code)) + require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{ + Action: FwdActionFail, + Key: forwardInterceptor.getIntercepted().IncomingCircuit, + FailureCode: code, + })) assertOutgoingLinkReceive(t, bobChannelLink, false) packet = assertOutgoingLinkReceive(t, aliceChannelLink, true) @@ -3369,12 +3410,16 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, s, 0, 0) assertOutgoingLinkReceive(t, bobChannelLink, false) - if err := forwardInterceptor.settle(preimage); err != nil { - t.Fatal("failed to cancel forward") - } + require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{ + Key: forwardInterceptor.getIntercepted().IncomingCircuit, + Action: FwdActionSettle, + Preimage: preimage, + })) assertOutgoingLinkReceive(t, bobChannelLink, false) assertOutgoingLinkReceive(t, aliceChannelLink, true) assertNumCircuits(t, s, 0, 0) + + require.NoError(t, switchForwardInterceptor.Stop()) } // TestSwitchDustForwarding tests that the switch properly fails HTLC's which diff --git a/lnrpc/routerrpc/forward_interceptor.go b/lnrpc/routerrpc/forward_interceptor.go index c52a23d2a..9fba60f1f 100644 --- a/lnrpc/routerrpc/forward_interceptor.go +++ b/lnrpc/routerrpc/forward_interceptor.go @@ -2,7 +2,6 @@ package routerrpc import ( "errors" - "sync" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/htlcswitch" @@ -27,36 +26,19 @@ var ( // interceptor streaming session. // It is created when the stream opens and disconnects when the stream closes. type forwardInterceptor struct { - // server is the Server reference - server *Server - - // holdForwards is a map of current hold forwards and their corresponding - // ForwardResolver. - holdForwards map[channeldb.CircuitKey]htlcswitch.InterceptedForward - // stream is the bidirectional RPC stream stream Router_HtlcInterceptorServer - // quit is a channel that is closed when this forwardInterceptor is shutting - // down. - quit chan struct{} - - // intercepted is where we stream all intercepted packets coming from - // the switch. - intercepted chan htlcswitch.InterceptedForward - - wg sync.WaitGroup + htlcSwitch htlcswitch.InterceptableHtlcForwarder } // newForwardInterceptor creates a new forwardInterceptor. -func newForwardInterceptor(server *Server, stream Router_HtlcInterceptorServer) *forwardInterceptor { +func newForwardInterceptor(htlcSwitch htlcswitch.InterceptableHtlcForwarder, + stream Router_HtlcInterceptorServer) *forwardInterceptor { + return &forwardInterceptor{ - server: server, - stream: stream, - holdForwards: make( - map[channeldb.CircuitKey]htlcswitch.InterceptedForward), - quit: make(chan struct{}), - intercepted: make(chan htlcswitch.InterceptedForward), + htlcSwitch: htlcSwitch, + stream: stream, } } @@ -67,42 +49,18 @@ func newForwardInterceptor(server *Server, stream Router_HtlcInterceptorServer) // To coordinate all this and make sure it is safe for concurrent access all // packets are sent to the main where they are handled. func (r *forwardInterceptor) run() error { - // make sure we disconnect and resolves all remaining packets if any. - defer r.onDisconnect() - // Register our interceptor so we receive all forwarded packets. - interceptableForwarder := r.server.cfg.RouterBackend.InterceptableForwarder - interceptableForwarder.SetInterceptor(r.onIntercept) - defer interceptableForwarder.SetInterceptor(nil) + r.htlcSwitch.SetInterceptor(r.onIntercept) + defer r.htlcSwitch.SetInterceptor(nil) - // start a go routine that reads client resolutions. - errChan := make(chan error) - resolutionRequests := make(chan *ForwardHtlcInterceptResponse) - r.wg.Add(1) - go r.readClientResponses(resolutionRequests, errChan) - - // run the main loop that synchronizes both sides input into one go routine. for { - select { - case intercepted := <-r.intercepted: - log.Tracef("sending intercepted packet to client %v", intercepted) - // in case we couldn't forward we exit the loop and drain the - // current interceptor as this indicates on a connection problem. - if err := r.holdAndForwardToClient(intercepted); err != nil { - return err - } - case resolution := <-resolutionRequests: - log.Tracef("resolving intercepted packet %v", resolution) - // in case we couldn't resolve we just add a log line since this - // does not indicate on any connection problem. - if err := r.resolveFromClient(resolution); err != nil { - log.Warnf("client resolution of intercepted "+ - "packet failed %v", err) - } - case err := <-errChan: + resp, err := r.stream.Recv() + if err != nil { + return err + } + + if err := r.resolveFromClient(resp); err != nil { return err - case <-r.server.quit: - return nil } } } @@ -111,54 +69,14 @@ func (r *forwardInterceptor) run() error { // packet. Our interceptor makes sure we hold the packet and then signal to the // main loop to handle the packet. We only return true if we were able // to deliver the packet to the main loop. -func (r *forwardInterceptor) onIntercept(p htlcswitch.InterceptedForward) bool { - select { - case r.intercepted <- p: - return true - case <-r.quit: - return false - case <-r.server.quit: - return false - } -} +func (r *forwardInterceptor) onIntercept( + htlc htlcswitch.InterceptedPacket) error { -func (r *forwardInterceptor) readClientResponses( - resolutionChan chan *ForwardHtlcInterceptResponse, errChan chan error) { + log.Tracef("Sending intercepted packet to client %v", htlc) - defer r.wg.Done() - for { - resp, err := r.stream.Recv() - if err != nil { - errChan <- err - return - } - - // Now that we have the response from the RPC client, send it to - // the responses chan. - select { - case resolutionChan <- resp: - case <-r.quit: - return - case <-r.server.quit: - return - } - } -} - -// holdAndForwardToClient forwards the intercepted htlc to the client. -func (r *forwardInterceptor) holdAndForwardToClient( - forward htlcswitch.InterceptedForward) error { - - htlc := forward.Packet() inKey := htlc.IncomingCircuit - // Ignore already held htlcs. - if _, ok := r.holdForwards[inKey]; ok { - return nil - } - // First hold the forward, then send to client. - r.holdForwards[inKey] = forward interceptionRequest := &ForwardHtlcInterceptRequest{ IncomingCircuitKey: &CircuitKey{ ChanId: inKey.ChanID.ToUint64(), @@ -181,20 +99,19 @@ func (r *forwardInterceptor) holdAndForwardToClient( func (r *forwardInterceptor) resolveFromClient( in *ForwardHtlcInterceptResponse) error { + log.Tracef("Resolving intercepted packet %v", in) + circuitKey := channeldb.CircuitKey{ ChanID: lnwire.NewShortChanIDFromInt(in.IncomingCircuitKey.ChanId), HtlcID: in.IncomingCircuitKey.HtlcId, } - var interceptedForward htlcswitch.InterceptedForward - interceptedForward, ok := r.holdForwards[circuitKey] - if !ok { - return ErrFwdNotExists - } - delete(r.holdForwards, circuitKey) switch in.Action { case ResolveHoldForwardAction_RESUME: - return interceptedForward.Resume() + return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{ + Key: circuitKey, + Action: htlcswitch.FwdActionResume, + }) case ResolveHoldForwardAction_FAIL: // Fail with an encrypted reason. @@ -219,7 +136,11 @@ func (r *forwardInterceptor) resolveFromClient( ) } - return interceptedForward.Fail(in.FailureMessage) + return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{ + Key: circuitKey, + Action: htlcswitch.FwdActionFail, + FailureMessage: in.FailureMessage, + }) } var code lnwire.FailCode @@ -244,14 +165,11 @@ func (r *forwardInterceptor) resolveFromClient( ) } - err := interceptedForward.FailWithCode(code) - if err == htlcswitch.ErrUnsupportedFailureCode { - return status.Errorf( - codes.InvalidArgument, err.Error(), - ) - } - - return err + return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{ + Key: circuitKey, + Action: htlcswitch.FwdActionFail, + FailureCode: code, + }) case ResolveHoldForwardAction_SETTLE: if in.Preimage == nil { @@ -261,7 +179,12 @@ func (r *forwardInterceptor) resolveFromClient( if err != nil { return err } - return interceptedForward.Settle(preimage) + + return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{ + Key: circuitKey, + Action: htlcswitch.FwdActionSettle, + Preimage: preimage, + }) default: return status.Errorf( @@ -270,20 +193,3 @@ func (r *forwardInterceptor) resolveFromClient( ) } } - -// onDisconnect removes all previousely held forwards from -// the store. Before they are removed it ensure to resume as the default -// behavior. -func (r *forwardInterceptor) onDisconnect() { - // Then close the channel so all go routine will exit. - close(r.quit) - - log.Infof("RPC interceptor disconnected, resolving held packets") - for key, forward := range r.holdForwards { - if err := forward.Resume(); err != nil { - log.Errorf("failed to resume hold forward %v", err) - } - delete(r.holdForwards, key) - } - r.wg.Wait() -} diff --git a/lnrpc/routerrpc/router_server.go b/lnrpc/routerrpc/router_server.go index dc14af4ef..0c60a62d5 100644 --- a/lnrpc/routerrpc/router_server.go +++ b/lnrpc/routerrpc/router_server.go @@ -890,7 +890,9 @@ func (s *Server) HtlcInterceptor(stream Router_HtlcInterceptorServer) error { defer atomic.CompareAndSwapInt32(&s.forwardInterceptorActive, 1, 0) // run the forward interceptor. - return newForwardInterceptor(s, stream).run() + return newForwardInterceptor( + s.cfg.RouterBackend.InterceptableForwarder, stream, + ).run() } func extractOutPoint(req *UpdateChanStatusRequest) (*wire.OutPoint, error) { diff --git a/server.go b/server.go index 632620526..b2d3b1c9d 100644 --- a/server.go +++ b/server.go @@ -1786,6 +1786,12 @@ func (s *server) Start() error { } cleanup = cleanup.add(s.htlcSwitch.Stop) + if err := s.interceptableSwitch.Start(); err != nil { + startErr = err + return + } + cleanup = cleanup.add(s.interceptableSwitch.Stop) + if err := s.chainArb.Start(); err != nil { startErr = err return