routerrpc+htlcswitch: move intercepted htlc tracking to switch

In this commit we move the tracking of the outstanding intercepted htlcs
to InterceptableSwitch. This is a preparation for making the htlc
interceptor required.

Required interception involves tracking outstanding htlcs across
multiple grpc client sessions. The per-session routerrpc
forwardInterceptor object is therefore no longer the best place for
that.
This commit is contained in:
Joost Jager
2022-02-07 08:53:10 +01:00
parent 95c270d1f8
commit 169f0c0bf4
6 changed files with 379 additions and 207 deletions

View File

@@ -2,7 +2,6 @@ package routerrpc
import (
"errors"
"sync"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/htlcswitch"
@@ -27,36 +26,19 @@ var (
// interceptor streaming session.
// It is created when the stream opens and disconnects when the stream closes.
type forwardInterceptor struct {
// server is the Server reference
server *Server
// holdForwards is a map of current hold forwards and their corresponding
// ForwardResolver.
holdForwards map[channeldb.CircuitKey]htlcswitch.InterceptedForward
// stream is the bidirectional RPC stream
stream Router_HtlcInterceptorServer
// quit is a channel that is closed when this forwardInterceptor is shutting
// down.
quit chan struct{}
// intercepted is where we stream all intercepted packets coming from
// the switch.
intercepted chan htlcswitch.InterceptedForward
wg sync.WaitGroup
htlcSwitch htlcswitch.InterceptableHtlcForwarder
}
// newForwardInterceptor creates a new forwardInterceptor.
func newForwardInterceptor(server *Server, stream Router_HtlcInterceptorServer) *forwardInterceptor {
func newForwardInterceptor(htlcSwitch htlcswitch.InterceptableHtlcForwarder,
stream Router_HtlcInterceptorServer) *forwardInterceptor {
return &forwardInterceptor{
server: server,
stream: stream,
holdForwards: make(
map[channeldb.CircuitKey]htlcswitch.InterceptedForward),
quit: make(chan struct{}),
intercepted: make(chan htlcswitch.InterceptedForward),
htlcSwitch: htlcSwitch,
stream: stream,
}
}
@@ -67,42 +49,18 @@ func newForwardInterceptor(server *Server, stream Router_HtlcInterceptorServer)
// To coordinate all this and make sure it is safe for concurrent access all
// packets are sent to the main where they are handled.
func (r *forwardInterceptor) run() error {
// make sure we disconnect and resolves all remaining packets if any.
defer r.onDisconnect()
// Register our interceptor so we receive all forwarded packets.
interceptableForwarder := r.server.cfg.RouterBackend.InterceptableForwarder
interceptableForwarder.SetInterceptor(r.onIntercept)
defer interceptableForwarder.SetInterceptor(nil)
r.htlcSwitch.SetInterceptor(r.onIntercept)
defer r.htlcSwitch.SetInterceptor(nil)
// start a go routine that reads client resolutions.
errChan := make(chan error)
resolutionRequests := make(chan *ForwardHtlcInterceptResponse)
r.wg.Add(1)
go r.readClientResponses(resolutionRequests, errChan)
// run the main loop that synchronizes both sides input into one go routine.
for {
select {
case intercepted := <-r.intercepted:
log.Tracef("sending intercepted packet to client %v", intercepted)
// in case we couldn't forward we exit the loop and drain the
// current interceptor as this indicates on a connection problem.
if err := r.holdAndForwardToClient(intercepted); err != nil {
return err
}
case resolution := <-resolutionRequests:
log.Tracef("resolving intercepted packet %v", resolution)
// in case we couldn't resolve we just add a log line since this
// does not indicate on any connection problem.
if err := r.resolveFromClient(resolution); err != nil {
log.Warnf("client resolution of intercepted "+
"packet failed %v", err)
}
case err := <-errChan:
resp, err := r.stream.Recv()
if err != nil {
return err
}
if err := r.resolveFromClient(resp); err != nil {
return err
case <-r.server.quit:
return nil
}
}
}
@@ -111,54 +69,14 @@ func (r *forwardInterceptor) run() error {
// packet. Our interceptor makes sure we hold the packet and then signal to the
// main loop to handle the packet. We only return true if we were able
// to deliver the packet to the main loop.
func (r *forwardInterceptor) onIntercept(p htlcswitch.InterceptedForward) bool {
select {
case r.intercepted <- p:
return true
case <-r.quit:
return false
case <-r.server.quit:
return false
}
}
func (r *forwardInterceptor) onIntercept(
htlc htlcswitch.InterceptedPacket) error {
func (r *forwardInterceptor) readClientResponses(
resolutionChan chan *ForwardHtlcInterceptResponse, errChan chan error) {
log.Tracef("Sending intercepted packet to client %v", htlc)
defer r.wg.Done()
for {
resp, err := r.stream.Recv()
if err != nil {
errChan <- err
return
}
// Now that we have the response from the RPC client, send it to
// the responses chan.
select {
case resolutionChan <- resp:
case <-r.quit:
return
case <-r.server.quit:
return
}
}
}
// holdAndForwardToClient forwards the intercepted htlc to the client.
func (r *forwardInterceptor) holdAndForwardToClient(
forward htlcswitch.InterceptedForward) error {
htlc := forward.Packet()
inKey := htlc.IncomingCircuit
// Ignore already held htlcs.
if _, ok := r.holdForwards[inKey]; ok {
return nil
}
// First hold the forward, then send to client.
r.holdForwards[inKey] = forward
interceptionRequest := &ForwardHtlcInterceptRequest{
IncomingCircuitKey: &CircuitKey{
ChanId: inKey.ChanID.ToUint64(),
@@ -181,20 +99,19 @@ func (r *forwardInterceptor) holdAndForwardToClient(
func (r *forwardInterceptor) resolveFromClient(
in *ForwardHtlcInterceptResponse) error {
log.Tracef("Resolving intercepted packet %v", in)
circuitKey := channeldb.CircuitKey{
ChanID: lnwire.NewShortChanIDFromInt(in.IncomingCircuitKey.ChanId),
HtlcID: in.IncomingCircuitKey.HtlcId,
}
var interceptedForward htlcswitch.InterceptedForward
interceptedForward, ok := r.holdForwards[circuitKey]
if !ok {
return ErrFwdNotExists
}
delete(r.holdForwards, circuitKey)
switch in.Action {
case ResolveHoldForwardAction_RESUME:
return interceptedForward.Resume()
return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{
Key: circuitKey,
Action: htlcswitch.FwdActionResume,
})
case ResolveHoldForwardAction_FAIL:
// Fail with an encrypted reason.
@@ -219,7 +136,11 @@ func (r *forwardInterceptor) resolveFromClient(
)
}
return interceptedForward.Fail(in.FailureMessage)
return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{
Key: circuitKey,
Action: htlcswitch.FwdActionFail,
FailureMessage: in.FailureMessage,
})
}
var code lnwire.FailCode
@@ -244,14 +165,11 @@ func (r *forwardInterceptor) resolveFromClient(
)
}
err := interceptedForward.FailWithCode(code)
if err == htlcswitch.ErrUnsupportedFailureCode {
return status.Errorf(
codes.InvalidArgument, err.Error(),
)
}
return err
return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{
Key: circuitKey,
Action: htlcswitch.FwdActionFail,
FailureCode: code,
})
case ResolveHoldForwardAction_SETTLE:
if in.Preimage == nil {
@@ -261,7 +179,12 @@ func (r *forwardInterceptor) resolveFromClient(
if err != nil {
return err
}
return interceptedForward.Settle(preimage)
return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{
Key: circuitKey,
Action: htlcswitch.FwdActionSettle,
Preimage: preimage,
})
default:
return status.Errorf(
@@ -270,20 +193,3 @@ func (r *forwardInterceptor) resolveFromClient(
)
}
}
// onDisconnect removes all previousely held forwards from
// the store. Before they are removed it ensure to resume as the default
// behavior.
func (r *forwardInterceptor) onDisconnect() {
// Then close the channel so all go routine will exit.
close(r.quit)
log.Infof("RPC interceptor disconnected, resolving held packets")
for key, forward := range r.holdForwards {
if err := forward.Resume(); err != nil {
log.Errorf("failed to resume hold forward %v", err)
}
delete(r.holdForwards, key)
}
r.wg.Wait()
}

View File

@@ -890,7 +890,9 @@ func (s *Server) HtlcInterceptor(stream Router_HtlcInterceptorServer) error {
defer atomic.CompareAndSwapInt32(&s.forwardInterceptorActive, 1, 0)
// run the forward interceptor.
return newForwardInterceptor(s, stream).run()
return newForwardInterceptor(
s.cfg.RouterBackend.InterceptableForwarder, stream,
).run()
}
func extractOutPoint(req *UpdateChanStatusRequest) (*wire.OutPoint, error) {