mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-05-19 00:00:21 +02:00
routerrpc+htlcswitch: move intercepted htlc tracking to switch
In this commit we move the tracking of the outstanding intercepted htlcs to InterceptableSwitch. This is a preparation for making the htlc interceptor required. Required interception involves tracking outstanding htlcs across multiple grpc client sessions. The per-session routerrpc forwardInterceptor object is therefore no longer the best place for that.
This commit is contained in:
parent
95c270d1f8
commit
169f0c0bf4
@ -30,64 +30,260 @@ var (
|
|||||||
// Settle - routes UpdateFulfillHTLC to the originating link.
|
// Settle - routes UpdateFulfillHTLC to the originating link.
|
||||||
// Fail - routes UpdateFailHTLC to the originating link.
|
// Fail - routes UpdateFailHTLC to the originating link.
|
||||||
type InterceptableSwitch struct {
|
type InterceptableSwitch struct {
|
||||||
sync.RWMutex
|
|
||||||
|
|
||||||
// htlcSwitch is the underline switch
|
// htlcSwitch is the underline switch
|
||||||
htlcSwitch *Switch
|
htlcSwitch *Switch
|
||||||
|
|
||||||
// fwdInterceptor is the callback that is called for each forward of
|
// intercepted is where we stream all intercepted packets coming from
|
||||||
// an incoming htlc. It should return true if it is interested in handling
|
// the switch.
|
||||||
// it.
|
intercepted chan *interceptedPackets
|
||||||
fwdInterceptor ForwardInterceptor
|
|
||||||
|
// resolutionChan is where we stream all responses coming from the
|
||||||
|
// interceptor client.
|
||||||
|
resolutionChan chan *fwdResolution
|
||||||
|
|
||||||
|
// interceptorRegistration is a channel that we use to synchronize
|
||||||
|
// client connect and disconnect.
|
||||||
|
interceptorRegistration chan ForwardInterceptor
|
||||||
|
|
||||||
|
// interceptor is the handler for intercepted packets.
|
||||||
|
interceptor ForwardInterceptor
|
||||||
|
|
||||||
|
// holdForwards keeps track of outstanding intercepted forwards.
|
||||||
|
holdForwards map[channeldb.CircuitKey]InterceptedForward
|
||||||
|
|
||||||
|
wg sync.WaitGroup
|
||||||
|
quit chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type interceptedPackets struct {
|
||||||
|
packets []*htlcPacket
|
||||||
|
linkQuit chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FwdAction defines the various resolution types.
|
||||||
|
type FwdAction int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// FwdActionResume forwards the intercepted packet to the switch.
|
||||||
|
FwdActionResume FwdAction = iota
|
||||||
|
|
||||||
|
// FwdActionSettle settles the intercepted packet with a preimage.
|
||||||
|
FwdActionSettle
|
||||||
|
|
||||||
|
// FwdActionFail fails the intercepted packet back to the sender.
|
||||||
|
FwdActionFail
|
||||||
|
)
|
||||||
|
|
||||||
|
// FwdResolution defines the action to be taken on an intercepted packet.
|
||||||
|
type FwdResolution struct {
|
||||||
|
// Key is the incoming circuit key of the htlc.
|
||||||
|
Key channeldb.CircuitKey
|
||||||
|
|
||||||
|
// Action is the action to take on the intercepted htlc.
|
||||||
|
Action FwdAction
|
||||||
|
|
||||||
|
// Preimage is the preimage that is to be used for settling if Action is
|
||||||
|
// FwdActionSettle.
|
||||||
|
Preimage lntypes.Preimage
|
||||||
|
|
||||||
|
// FailureMessage is the encrypted failure message that is to be passed
|
||||||
|
// back to the sender if action is FwdActionFail.
|
||||||
|
FailureMessage []byte
|
||||||
|
|
||||||
|
// FailureCode is the failure code that is to be passed back to the
|
||||||
|
// sender if action is FwdActionFail.
|
||||||
|
FailureCode lnwire.FailCode
|
||||||
|
}
|
||||||
|
|
||||||
|
type fwdResolution struct {
|
||||||
|
resolution *FwdResolution
|
||||||
|
errChan chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewInterceptableSwitch returns an instance of InterceptableSwitch.
|
// NewInterceptableSwitch returns an instance of InterceptableSwitch.
|
||||||
func NewInterceptableSwitch(s *Switch) *InterceptableSwitch {
|
func NewInterceptableSwitch(s *Switch) *InterceptableSwitch {
|
||||||
return &InterceptableSwitch{htlcSwitch: s}
|
return &InterceptableSwitch{
|
||||||
|
htlcSwitch: s,
|
||||||
|
intercepted: make(chan *interceptedPackets),
|
||||||
|
interceptorRegistration: make(chan ForwardInterceptor),
|
||||||
|
holdForwards: make(map[channeldb.CircuitKey]InterceptedForward),
|
||||||
|
resolutionChan: make(chan *fwdResolution),
|
||||||
|
|
||||||
|
quit: make(chan struct{}),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetInterceptor sets the ForwardInterceptor to be used.
|
// SetInterceptor sets the ForwardInterceptor to be used. A nil argument
|
||||||
|
// unregisters the current interceptor.
|
||||||
func (s *InterceptableSwitch) SetInterceptor(
|
func (s *InterceptableSwitch) SetInterceptor(
|
||||||
interceptor ForwardInterceptor) {
|
interceptor ForwardInterceptor) {
|
||||||
|
|
||||||
s.Lock()
|
// Synchronize setting the handler with the main loop to prevent race
|
||||||
defer s.Unlock()
|
// conditions.
|
||||||
s.fwdInterceptor = interceptor
|
select {
|
||||||
|
case s.interceptorRegistration <- interceptor:
|
||||||
|
|
||||||
|
case <-s.quit:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForwardPackets attempts to forward the batch of htlcs through the
|
func (s *InterceptableSwitch) Start() error {
|
||||||
// switch, any failed packets will be returned to the provided
|
s.wg.Add(1)
|
||||||
// ChannelLink. The link's quit signal should be provided to allow
|
go func() {
|
||||||
|
defer s.wg.Done()
|
||||||
|
|
||||||
|
s.run()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InterceptableSwitch) Stop() error {
|
||||||
|
close(s.quit)
|
||||||
|
s.wg.Wait()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InterceptableSwitch) run() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
// An interceptor registration or de-registration came in.
|
||||||
|
case interceptor := <-s.interceptorRegistration:
|
||||||
|
s.setInterceptor(interceptor)
|
||||||
|
|
||||||
|
case packets := <-s.intercepted:
|
||||||
|
var notIntercepted []*htlcPacket
|
||||||
|
for _, p := range packets.packets {
|
||||||
|
if s.interceptor == nil ||
|
||||||
|
!s.interceptForward(p) {
|
||||||
|
|
||||||
|
notIntercepted = append(
|
||||||
|
notIntercepted, p,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err := s.htlcSwitch.ForwardPackets(
|
||||||
|
packets.linkQuit, notIntercepted...,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Cannot forward packets: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
case res := <-s.resolutionChan:
|
||||||
|
res.errChan <- s.resolve(res.resolution)
|
||||||
|
|
||||||
|
case <-s.quit:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InterceptableSwitch) sendForward(fwd InterceptedForward) {
|
||||||
|
err := s.interceptor(fwd.Packet())
|
||||||
|
if err != nil {
|
||||||
|
// Only log the error. If we couldn't send the packet, we assume
|
||||||
|
// that the interceptor will reconnect so that we can retry.
|
||||||
|
log.Debugf("Interceptor cannot handle forward: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InterceptableSwitch) setInterceptor(interceptor ForwardInterceptor) {
|
||||||
|
s.interceptor = interceptor
|
||||||
|
|
||||||
|
if interceptor != nil {
|
||||||
|
log.Debugf("Interceptor connected")
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Interceptor disconnected, resolving held packets")
|
||||||
|
|
||||||
|
for _, fwd := range s.holdForwards {
|
||||||
|
if err := fwd.Resume(); err != nil {
|
||||||
|
log.Errorf("Failed to resume hold forward %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.holdForwards = make(map[channeldb.CircuitKey]InterceptedForward)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InterceptableSwitch) resolve(res *FwdResolution) error {
|
||||||
|
intercepted, ok := s.holdForwards[res.Key]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("fwd %v not found", res.Key)
|
||||||
|
}
|
||||||
|
delete(s.holdForwards, res.Key)
|
||||||
|
|
||||||
|
switch res.Action {
|
||||||
|
case FwdActionResume:
|
||||||
|
return intercepted.Resume()
|
||||||
|
|
||||||
|
case FwdActionSettle:
|
||||||
|
return intercepted.Settle(res.Preimage)
|
||||||
|
|
||||||
|
case FwdActionFail:
|
||||||
|
if len(res.FailureMessage) > 0 {
|
||||||
|
return intercepted.Fail(res.FailureMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
return intercepted.FailWithCode(res.FailureCode)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unrecognized action %v", res.Action)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve resolves an intercepted packet.
|
||||||
|
func (s *InterceptableSwitch) Resolve(res *FwdResolution) error {
|
||||||
|
internalRes := &fwdResolution{
|
||||||
|
resolution: res,
|
||||||
|
errChan: make(chan error, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case s.resolutionChan <- internalRes:
|
||||||
|
|
||||||
|
case <-s.quit:
|
||||||
|
return errors.New("switch shutting down")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-internalRes.errChan:
|
||||||
|
return err
|
||||||
|
|
||||||
|
case <-s.quit:
|
||||||
|
return errors.New("switch shutting down")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForwardPackets attempts to forward the batch of htlcs to a connected
|
||||||
|
// interceptor. If the interceptor signals the resume action, the htlcs are
|
||||||
|
// forwarded to the switch. The link's quit signal should be provided to allow
|
||||||
// cancellation of forwarding during link shutdown.
|
// cancellation of forwarding during link shutdown.
|
||||||
func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{},
|
func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{},
|
||||||
packets ...*htlcPacket) error {
|
packets ...*htlcPacket) error {
|
||||||
|
|
||||||
var interceptor ForwardInterceptor
|
// Synchronize with the main event loop. This should be light in the
|
||||||
s.Lock()
|
// case where there is no interceptor.
|
||||||
interceptor = s.fwdInterceptor
|
select {
|
||||||
s.Unlock()
|
case s.intercepted <- &interceptedPackets{
|
||||||
|
packets: packets,
|
||||||
|
linkQuit: linkQuit,
|
||||||
|
}:
|
||||||
|
|
||||||
// Optimize for the case we don't have an interceptor.
|
case <-linkQuit:
|
||||||
if interceptor == nil {
|
log.Debugf("Forward cancelled because link quit")
|
||||||
return s.htlcSwitch.ForwardPackets(linkQuit, packets...)
|
|
||||||
|
case <-s.quit:
|
||||||
|
return errors.New("interceptable switch quit")
|
||||||
}
|
}
|
||||||
|
|
||||||
var notIntercepted []*htlcPacket
|
return nil
|
||||||
for _, p := range packets {
|
|
||||||
if !s.interceptForward(p, interceptor, linkQuit) {
|
|
||||||
notIntercepted = append(notIntercepted, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return s.htlcSwitch.ForwardPackets(linkQuit, notIntercepted...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// interceptForward checks if there is any external interceptor interested in
|
// interceptForward forwards the packet to the external interceptor after
|
||||||
// this packet. Currently only htlc type of UpdateAddHTLC that are forwarded
|
// checking the interception criteria.
|
||||||
// are being checked for interception. It can be extended in the future given
|
func (s *InterceptableSwitch) interceptForward(packet *htlcPacket) bool {
|
||||||
// the right use case.
|
|
||||||
func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
|
|
||||||
interceptor ForwardInterceptor, linkQuit chan struct{}) bool {
|
|
||||||
|
|
||||||
switch htlc := packet.htlc.(type) {
|
switch htlc := packet.htlc.(type) {
|
||||||
case *lnwire.UpdateAddHTLC:
|
case *lnwire.UpdateAddHTLC:
|
||||||
// We are not interested in intercepting initiated payments.
|
// We are not interested in intercepting initiated payments.
|
||||||
@ -95,15 +291,28 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inKey := channeldb.CircuitKey{
|
||||||
|
ChanID: packet.incomingChanID,
|
||||||
|
HtlcID: packet.incomingHTLCID,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ignore already held htlcs.
|
||||||
|
if _, ok := s.holdForwards[inKey]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
intercepted := &interceptedForward{
|
intercepted := &interceptedForward{
|
||||||
linkQuit: linkQuit,
|
|
||||||
htlc: htlc,
|
htlc: htlc,
|
||||||
packet: packet,
|
packet: packet,
|
||||||
htlcSwitch: s.htlcSwitch,
|
htlcSwitch: s.htlcSwitch,
|
||||||
}
|
}
|
||||||
|
|
||||||
// If this htlc was intercepted, don't handle the forward.
|
s.holdForwards[inKey] = intercepted
|
||||||
return interceptor(intercepted)
|
|
||||||
|
s.sendForward(intercepted)
|
||||||
|
|
||||||
|
return true
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -113,7 +322,6 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
|
|||||||
// It is passed from the switch to external interceptors that are interested
|
// It is passed from the switch to external interceptors that are interested
|
||||||
// in holding forwards and resolve them manually.
|
// in holding forwards and resolve them manually.
|
||||||
type interceptedForward struct {
|
type interceptedForward struct {
|
||||||
linkQuit chan struct{}
|
|
||||||
htlc *lnwire.UpdateAddHTLC
|
htlc *lnwire.UpdateAddHTLC
|
||||||
packet *htlcPacket
|
packet *htlcPacket
|
||||||
htlcSwitch *Switch
|
htlcSwitch *Switch
|
||||||
@ -139,10 +347,12 @@ func (f *interceptedForward) Packet() InterceptedPacket {
|
|||||||
|
|
||||||
// Resume resumes the default behavior as if the packet was not intercepted.
|
// Resume resumes the default behavior as if the packet was not intercepted.
|
||||||
func (f *interceptedForward) Resume() error {
|
func (f *interceptedForward) Resume() error {
|
||||||
return f.htlcSwitch.ForwardPackets(f.linkQuit, f.packet)
|
// Forward to the switch. A link quit channel isn't needed, because we
|
||||||
|
// are on a different thread now.
|
||||||
|
return f.htlcSwitch.ForwardPackets(nil, f.packet)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fail notifies the intention to fail an existing hold forward with an
|
// Fail notifies the intention to Fail an existing hold forward with an
|
||||||
// encrypted failure reason.
|
// encrypted failure reason.
|
||||||
func (f *interceptedForward) Fail(reason []byte) error {
|
func (f *interceptedForward) Fail(reason []byte) error {
|
||||||
obfuscatedReason := f.packet.obfuscator.IntermediateEncrypt(reason)
|
obfuscatedReason := f.packet.obfuscator.IntermediateEncrypt(reason)
|
||||||
|
@ -234,6 +234,9 @@ type TowerClient interface {
|
|||||||
type InterceptableHtlcForwarder interface {
|
type InterceptableHtlcForwarder interface {
|
||||||
// SetInterceptor sets a ForwardInterceptor.
|
// SetInterceptor sets a ForwardInterceptor.
|
||||||
SetInterceptor(interceptor ForwardInterceptor)
|
SetInterceptor(interceptor ForwardInterceptor)
|
||||||
|
|
||||||
|
// Resolve resolves an intercepted packet.
|
||||||
|
Resolve(res *FwdResolution) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForwardInterceptor is a function that is invoked from the switch for every
|
// ForwardInterceptor is a function that is invoked from the switch for every
|
||||||
@ -242,7 +245,7 @@ type InterceptableHtlcForwarder interface {
|
|||||||
// to resolve it manually later in case it is held.
|
// to resolve it manually later in case it is held.
|
||||||
// The return value indicates if this handler will take control of this forward
|
// The return value indicates if this handler will take control of this forward
|
||||||
// and resolve it later or let the switch execute its default behavior.
|
// and resolve it later or let the switch execute its default behavior.
|
||||||
type ForwardInterceptor func(InterceptedForward) bool
|
type ForwardInterceptor func(InterceptedPacket) error
|
||||||
|
|
||||||
// InterceptedPacket contains the relevant information for the interceptor about
|
// InterceptedPacket contains the relevant information for the interceptor about
|
||||||
// an htlc.
|
// an htlc.
|
||||||
|
@ -3140,32 +3140,29 @@ func getThreeHopEvents(channels *clusterChannels, htlcID uint64,
|
|||||||
}
|
}
|
||||||
|
|
||||||
type mockForwardInterceptor struct {
|
type mockForwardInterceptor struct {
|
||||||
intercepted InterceptedForward
|
t *testing.T
|
||||||
|
|
||||||
|
interceptedChan chan InterceptedPacket
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockForwardInterceptor) InterceptForwardHtlc(
|
func (m *mockForwardInterceptor) InterceptForwardHtlc(
|
||||||
intercepted InterceptedForward) bool {
|
intercepted InterceptedPacket) error {
|
||||||
|
|
||||||
m.intercepted = intercepted
|
m.interceptedChan <- intercepted
|
||||||
return true
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockForwardInterceptor) settle(preimage lntypes.Preimage) error {
|
func (m *mockForwardInterceptor) getIntercepted() InterceptedPacket {
|
||||||
return m.intercepted.Settle(preimage)
|
select {
|
||||||
}
|
case p := <-m.interceptedChan:
|
||||||
|
return p
|
||||||
|
|
||||||
func (m *mockForwardInterceptor) fail(reason []byte) error {
|
case <-time.After(time.Second):
|
||||||
return m.intercepted.Fail(reason)
|
require.Fail(m.t, "timeout")
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockForwardInterceptor) failWithCode(
|
return InterceptedPacket{}
|
||||||
code lnwire.FailCode) error {
|
}
|
||||||
|
|
||||||
return m.intercepted.FailWithCode(code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockForwardInterceptor) resume() error {
|
|
||||||
return m.intercepted.Resume()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertNumCircuits(t *testing.T, s *Switch, pending, opened int) {
|
func assertNumCircuits(t *testing.T, s *Switch, pending, opened int) {
|
||||||
@ -3272,12 +3269,17 @@ func TestSwitchHoldForward(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
forwardInterceptor := &mockForwardInterceptor{}
|
forwardInterceptor := &mockForwardInterceptor{
|
||||||
|
t: t,
|
||||||
|
interceptedChan: make(chan InterceptedPacket),
|
||||||
|
}
|
||||||
switchForwardInterceptor := NewInterceptableSwitch(s)
|
switchForwardInterceptor := NewInterceptableSwitch(s)
|
||||||
|
require.NoError(t, switchForwardInterceptor.Start())
|
||||||
|
|
||||||
switchForwardInterceptor.SetInterceptor(forwardInterceptor.InterceptForwardHtlc)
|
switchForwardInterceptor.SetInterceptor(forwardInterceptor.InterceptForwardHtlc)
|
||||||
linkQuit := make(chan struct{})
|
linkQuit := make(chan struct{})
|
||||||
|
|
||||||
// Test resume a hold forward
|
// Test resume a hold forward.
|
||||||
assertNumCircuits(t, s, 0, 0)
|
assertNumCircuits(t, s, 0, 0)
|
||||||
if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil {
|
if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil {
|
||||||
t.Fatalf("can't forward htlc packet: %v", err)
|
t.Fatalf("can't forward htlc packet: %v", err)
|
||||||
@ -3285,9 +3287,10 @@ func TestSwitchHoldForward(t *testing.T) {
|
|||||||
assertNumCircuits(t, s, 0, 0)
|
assertNumCircuits(t, s, 0, 0)
|
||||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||||
|
|
||||||
if err := forwardInterceptor.resume(); err != nil {
|
require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{
|
||||||
t.Fatalf("failed to resume forward")
|
Action: FwdActionResume,
|
||||||
}
|
Key: forwardInterceptor.getIntercepted().IncomingCircuit,
|
||||||
|
}))
|
||||||
assertOutgoingLinkReceive(t, bobChannelLink, true)
|
assertOutgoingLinkReceive(t, bobChannelLink, true)
|
||||||
assertNumCircuits(t, s, 1, 1)
|
assertNumCircuits(t, s, 1, 1)
|
||||||
|
|
||||||
@ -3306,16 +3309,46 @@ func TestSwitchHoldForward(t *testing.T) {
|
|||||||
assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
||||||
assertNumCircuits(t, s, 0, 0)
|
assertNumCircuits(t, s, 0, 0)
|
||||||
|
|
||||||
|
// Test resume a hold forward after disconnection.
|
||||||
|
err = switchForwardInterceptor.ForwardPackets(nil, ogPacket)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Wait until the packet is offered to the interceptor.
|
||||||
|
_ = forwardInterceptor.getIntercepted()
|
||||||
|
|
||||||
|
// No forward expected yet.
|
||||||
|
assertNumCircuits(t, s, 0, 0)
|
||||||
|
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||||
|
|
||||||
|
// Disconnect should resume the forwarding.
|
||||||
|
switchForwardInterceptor.SetInterceptor(nil)
|
||||||
|
|
||||||
|
assertOutgoingLinkReceive(t, bobChannelLink, true)
|
||||||
|
assertNumCircuits(t, s, 1, 1)
|
||||||
|
|
||||||
|
// Settle the htlc to close the circuit.
|
||||||
|
settle.outgoingHTLCID = 1
|
||||||
|
require.NoError(t, switchForwardInterceptor.ForwardPackets(nil, settle))
|
||||||
|
|
||||||
|
assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
||||||
|
assertNumCircuits(t, s, 0, 0)
|
||||||
|
|
||||||
// Test failing a hold forward
|
// Test failing a hold forward
|
||||||
|
switchForwardInterceptor.SetInterceptor(
|
||||||
|
forwardInterceptor.InterceptForwardHtlc,
|
||||||
|
)
|
||||||
|
|
||||||
if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil {
|
if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil {
|
||||||
t.Fatalf("can't forward htlc packet: %v", err)
|
t.Fatalf("can't forward htlc packet: %v", err)
|
||||||
}
|
}
|
||||||
assertNumCircuits(t, s, 0, 0)
|
assertNumCircuits(t, s, 0, 0)
|
||||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||||
|
|
||||||
if err := forwardInterceptor.fail(nil); err != nil {
|
require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{
|
||||||
t.Fatalf("failed to cancel forward %v", err)
|
Action: FwdActionFail,
|
||||||
}
|
Key: forwardInterceptor.getIntercepted().IncomingCircuit,
|
||||||
|
FailureCode: lnwire.CodeTemporaryChannelFailure,
|
||||||
|
}))
|
||||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||||
assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
||||||
assertNumCircuits(t, s, 0, 0)
|
assertNumCircuits(t, s, 0, 0)
|
||||||
@ -3328,7 +3361,11 @@ func TestSwitchHoldForward(t *testing.T) {
|
|||||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||||
|
|
||||||
reason := lnwire.OpaqueReason([]byte{1, 2, 3})
|
reason := lnwire.OpaqueReason([]byte{1, 2, 3})
|
||||||
require.NoError(t, forwardInterceptor.fail(reason))
|
require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{
|
||||||
|
Action: FwdActionFail,
|
||||||
|
Key: forwardInterceptor.getIntercepted().IncomingCircuit,
|
||||||
|
FailureMessage: reason,
|
||||||
|
}))
|
||||||
|
|
||||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||||
packet := assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
packet := assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
||||||
@ -3345,7 +3382,11 @@ func TestSwitchHoldForward(t *testing.T) {
|
|||||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||||
|
|
||||||
code := lnwire.CodeInvalidOnionKey
|
code := lnwire.CodeInvalidOnionKey
|
||||||
require.NoError(t, forwardInterceptor.failWithCode(code))
|
require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{
|
||||||
|
Action: FwdActionFail,
|
||||||
|
Key: forwardInterceptor.getIntercepted().IncomingCircuit,
|
||||||
|
FailureCode: code,
|
||||||
|
}))
|
||||||
|
|
||||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||||
packet = assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
packet = assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
||||||
@ -3369,12 +3410,16 @@ func TestSwitchHoldForward(t *testing.T) {
|
|||||||
assertNumCircuits(t, s, 0, 0)
|
assertNumCircuits(t, s, 0, 0)
|
||||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||||
|
|
||||||
if err := forwardInterceptor.settle(preimage); err != nil {
|
require.NoError(t, switchForwardInterceptor.resolve(&FwdResolution{
|
||||||
t.Fatal("failed to cancel forward")
|
Key: forwardInterceptor.getIntercepted().IncomingCircuit,
|
||||||
}
|
Action: FwdActionSettle,
|
||||||
|
Preimage: preimage,
|
||||||
|
}))
|
||||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||||
assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
||||||
assertNumCircuits(t, s, 0, 0)
|
assertNumCircuits(t, s, 0, 0)
|
||||||
|
|
||||||
|
require.NoError(t, switchForwardInterceptor.Stop())
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestSwitchDustForwarding tests that the switch properly fails HTLC's which
|
// TestSwitchDustForwarding tests that the switch properly fails HTLC's which
|
||||||
|
@ -2,7 +2,6 @@ package routerrpc
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/lightningnetwork/lnd/channeldb"
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
"github.com/lightningnetwork/lnd/htlcswitch"
|
"github.com/lightningnetwork/lnd/htlcswitch"
|
||||||
@ -27,36 +26,19 @@ var (
|
|||||||
// interceptor streaming session.
|
// interceptor streaming session.
|
||||||
// It is created when the stream opens and disconnects when the stream closes.
|
// It is created when the stream opens and disconnects when the stream closes.
|
||||||
type forwardInterceptor struct {
|
type forwardInterceptor struct {
|
||||||
// server is the Server reference
|
|
||||||
server *Server
|
|
||||||
|
|
||||||
// holdForwards is a map of current hold forwards and their corresponding
|
|
||||||
// ForwardResolver.
|
|
||||||
holdForwards map[channeldb.CircuitKey]htlcswitch.InterceptedForward
|
|
||||||
|
|
||||||
// stream is the bidirectional RPC stream
|
// stream is the bidirectional RPC stream
|
||||||
stream Router_HtlcInterceptorServer
|
stream Router_HtlcInterceptorServer
|
||||||
|
|
||||||
// quit is a channel that is closed when this forwardInterceptor is shutting
|
htlcSwitch htlcswitch.InterceptableHtlcForwarder
|
||||||
// down.
|
|
||||||
quit chan struct{}
|
|
||||||
|
|
||||||
// intercepted is where we stream all intercepted packets coming from
|
|
||||||
// the switch.
|
|
||||||
intercepted chan htlcswitch.InterceptedForward
|
|
||||||
|
|
||||||
wg sync.WaitGroup
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// newForwardInterceptor creates a new forwardInterceptor.
|
// newForwardInterceptor creates a new forwardInterceptor.
|
||||||
func newForwardInterceptor(server *Server, stream Router_HtlcInterceptorServer) *forwardInterceptor {
|
func newForwardInterceptor(htlcSwitch htlcswitch.InterceptableHtlcForwarder,
|
||||||
|
stream Router_HtlcInterceptorServer) *forwardInterceptor {
|
||||||
|
|
||||||
return &forwardInterceptor{
|
return &forwardInterceptor{
|
||||||
server: server,
|
htlcSwitch: htlcSwitch,
|
||||||
stream: stream,
|
stream: stream,
|
||||||
holdForwards: make(
|
|
||||||
map[channeldb.CircuitKey]htlcswitch.InterceptedForward),
|
|
||||||
quit: make(chan struct{}),
|
|
||||||
intercepted: make(chan htlcswitch.InterceptedForward),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,42 +49,18 @@ func newForwardInterceptor(server *Server, stream Router_HtlcInterceptorServer)
|
|||||||
// To coordinate all this and make sure it is safe for concurrent access all
|
// To coordinate all this and make sure it is safe for concurrent access all
|
||||||
// packets are sent to the main where they are handled.
|
// packets are sent to the main where they are handled.
|
||||||
func (r *forwardInterceptor) run() error {
|
func (r *forwardInterceptor) run() error {
|
||||||
// make sure we disconnect and resolves all remaining packets if any.
|
|
||||||
defer r.onDisconnect()
|
|
||||||
|
|
||||||
// Register our interceptor so we receive all forwarded packets.
|
// Register our interceptor so we receive all forwarded packets.
|
||||||
interceptableForwarder := r.server.cfg.RouterBackend.InterceptableForwarder
|
r.htlcSwitch.SetInterceptor(r.onIntercept)
|
||||||
interceptableForwarder.SetInterceptor(r.onIntercept)
|
defer r.htlcSwitch.SetInterceptor(nil)
|
||||||
defer interceptableForwarder.SetInterceptor(nil)
|
|
||||||
|
|
||||||
// start a go routine that reads client resolutions.
|
|
||||||
errChan := make(chan error)
|
|
||||||
resolutionRequests := make(chan *ForwardHtlcInterceptResponse)
|
|
||||||
r.wg.Add(1)
|
|
||||||
go r.readClientResponses(resolutionRequests, errChan)
|
|
||||||
|
|
||||||
// run the main loop that synchronizes both sides input into one go routine.
|
|
||||||
for {
|
for {
|
||||||
select {
|
resp, err := r.stream.Recv()
|
||||||
case intercepted := <-r.intercepted:
|
if err != nil {
|
||||||
log.Tracef("sending intercepted packet to client %v", intercepted)
|
|
||||||
// in case we couldn't forward we exit the loop and drain the
|
|
||||||
// current interceptor as this indicates on a connection problem.
|
|
||||||
if err := r.holdAndForwardToClient(intercepted); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
case resolution := <-resolutionRequests:
|
|
||||||
log.Tracef("resolving intercepted packet %v", resolution)
|
if err := r.resolveFromClient(resp); err != nil {
|
||||||
// in case we couldn't resolve we just add a log line since this
|
|
||||||
// does not indicate on any connection problem.
|
|
||||||
if err := r.resolveFromClient(resolution); err != nil {
|
|
||||||
log.Warnf("client resolution of intercepted "+
|
|
||||||
"packet failed %v", err)
|
|
||||||
}
|
|
||||||
case err := <-errChan:
|
|
||||||
return err
|
return err
|
||||||
case <-r.server.quit:
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -111,54 +69,14 @@ func (r *forwardInterceptor) run() error {
|
|||||||
// packet. Our interceptor makes sure we hold the packet and then signal to the
|
// packet. Our interceptor makes sure we hold the packet and then signal to the
|
||||||
// main loop to handle the packet. We only return true if we were able
|
// main loop to handle the packet. We only return true if we were able
|
||||||
// to deliver the packet to the main loop.
|
// to deliver the packet to the main loop.
|
||||||
func (r *forwardInterceptor) onIntercept(p htlcswitch.InterceptedForward) bool {
|
func (r *forwardInterceptor) onIntercept(
|
||||||
select {
|
htlc htlcswitch.InterceptedPacket) error {
|
||||||
case r.intercepted <- p:
|
|
||||||
return true
|
|
||||||
case <-r.quit:
|
|
||||||
return false
|
|
||||||
case <-r.server.quit:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *forwardInterceptor) readClientResponses(
|
log.Tracef("Sending intercepted packet to client %v", htlc)
|
||||||
resolutionChan chan *ForwardHtlcInterceptResponse, errChan chan error) {
|
|
||||||
|
|
||||||
defer r.wg.Done()
|
|
||||||
for {
|
|
||||||
resp, err := r.stream.Recv()
|
|
||||||
if err != nil {
|
|
||||||
errChan <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that we have the response from the RPC client, send it to
|
|
||||||
// the responses chan.
|
|
||||||
select {
|
|
||||||
case resolutionChan <- resp:
|
|
||||||
case <-r.quit:
|
|
||||||
return
|
|
||||||
case <-r.server.quit:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// holdAndForwardToClient forwards the intercepted htlc to the client.
|
|
||||||
func (r *forwardInterceptor) holdAndForwardToClient(
|
|
||||||
forward htlcswitch.InterceptedForward) error {
|
|
||||||
|
|
||||||
htlc := forward.Packet()
|
|
||||||
inKey := htlc.IncomingCircuit
|
inKey := htlc.IncomingCircuit
|
||||||
|
|
||||||
// Ignore already held htlcs.
|
|
||||||
if _, ok := r.holdForwards[inKey]; ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// First hold the forward, then send to client.
|
// First hold the forward, then send to client.
|
||||||
r.holdForwards[inKey] = forward
|
|
||||||
interceptionRequest := &ForwardHtlcInterceptRequest{
|
interceptionRequest := &ForwardHtlcInterceptRequest{
|
||||||
IncomingCircuitKey: &CircuitKey{
|
IncomingCircuitKey: &CircuitKey{
|
||||||
ChanId: inKey.ChanID.ToUint64(),
|
ChanId: inKey.ChanID.ToUint64(),
|
||||||
@ -181,20 +99,19 @@ func (r *forwardInterceptor) holdAndForwardToClient(
|
|||||||
func (r *forwardInterceptor) resolveFromClient(
|
func (r *forwardInterceptor) resolveFromClient(
|
||||||
in *ForwardHtlcInterceptResponse) error {
|
in *ForwardHtlcInterceptResponse) error {
|
||||||
|
|
||||||
|
log.Tracef("Resolving intercepted packet %v", in)
|
||||||
|
|
||||||
circuitKey := channeldb.CircuitKey{
|
circuitKey := channeldb.CircuitKey{
|
||||||
ChanID: lnwire.NewShortChanIDFromInt(in.IncomingCircuitKey.ChanId),
|
ChanID: lnwire.NewShortChanIDFromInt(in.IncomingCircuitKey.ChanId),
|
||||||
HtlcID: in.IncomingCircuitKey.HtlcId,
|
HtlcID: in.IncomingCircuitKey.HtlcId,
|
||||||
}
|
}
|
||||||
var interceptedForward htlcswitch.InterceptedForward
|
|
||||||
interceptedForward, ok := r.holdForwards[circuitKey]
|
|
||||||
if !ok {
|
|
||||||
return ErrFwdNotExists
|
|
||||||
}
|
|
||||||
delete(r.holdForwards, circuitKey)
|
|
||||||
|
|
||||||
switch in.Action {
|
switch in.Action {
|
||||||
case ResolveHoldForwardAction_RESUME:
|
case ResolveHoldForwardAction_RESUME:
|
||||||
return interceptedForward.Resume()
|
return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{
|
||||||
|
Key: circuitKey,
|
||||||
|
Action: htlcswitch.FwdActionResume,
|
||||||
|
})
|
||||||
|
|
||||||
case ResolveHoldForwardAction_FAIL:
|
case ResolveHoldForwardAction_FAIL:
|
||||||
// Fail with an encrypted reason.
|
// Fail with an encrypted reason.
|
||||||
@ -219,7 +136,11 @@ func (r *forwardInterceptor) resolveFromClient(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return interceptedForward.Fail(in.FailureMessage)
|
return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{
|
||||||
|
Key: circuitKey,
|
||||||
|
Action: htlcswitch.FwdActionFail,
|
||||||
|
FailureMessage: in.FailureMessage,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
var code lnwire.FailCode
|
var code lnwire.FailCode
|
||||||
@ -244,14 +165,11 @@ func (r *forwardInterceptor) resolveFromClient(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := interceptedForward.FailWithCode(code)
|
return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{
|
||||||
if err == htlcswitch.ErrUnsupportedFailureCode {
|
Key: circuitKey,
|
||||||
return status.Errorf(
|
Action: htlcswitch.FwdActionFail,
|
||||||
codes.InvalidArgument, err.Error(),
|
FailureCode: code,
|
||||||
)
|
})
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
|
|
||||||
case ResolveHoldForwardAction_SETTLE:
|
case ResolveHoldForwardAction_SETTLE:
|
||||||
if in.Preimage == nil {
|
if in.Preimage == nil {
|
||||||
@ -261,7 +179,12 @@ func (r *forwardInterceptor) resolveFromClient(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return interceptedForward.Settle(preimage)
|
|
||||||
|
return r.htlcSwitch.Resolve(&htlcswitch.FwdResolution{
|
||||||
|
Key: circuitKey,
|
||||||
|
Action: htlcswitch.FwdActionSettle,
|
||||||
|
Preimage: preimage,
|
||||||
|
})
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return status.Errorf(
|
return status.Errorf(
|
||||||
@ -270,20 +193,3 @@ func (r *forwardInterceptor) resolveFromClient(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// onDisconnect removes all previousely held forwards from
|
|
||||||
// the store. Before they are removed it ensure to resume as the default
|
|
||||||
// behavior.
|
|
||||||
func (r *forwardInterceptor) onDisconnect() {
|
|
||||||
// Then close the channel so all go routine will exit.
|
|
||||||
close(r.quit)
|
|
||||||
|
|
||||||
log.Infof("RPC interceptor disconnected, resolving held packets")
|
|
||||||
for key, forward := range r.holdForwards {
|
|
||||||
if err := forward.Resume(); err != nil {
|
|
||||||
log.Errorf("failed to resume hold forward %v", err)
|
|
||||||
}
|
|
||||||
delete(r.holdForwards, key)
|
|
||||||
}
|
|
||||||
r.wg.Wait()
|
|
||||||
}
|
|
||||||
|
@ -890,7 +890,9 @@ func (s *Server) HtlcInterceptor(stream Router_HtlcInterceptorServer) error {
|
|||||||
defer atomic.CompareAndSwapInt32(&s.forwardInterceptorActive, 1, 0)
|
defer atomic.CompareAndSwapInt32(&s.forwardInterceptorActive, 1, 0)
|
||||||
|
|
||||||
// run the forward interceptor.
|
// run the forward interceptor.
|
||||||
return newForwardInterceptor(s, stream).run()
|
return newForwardInterceptor(
|
||||||
|
s.cfg.RouterBackend.InterceptableForwarder, stream,
|
||||||
|
).run()
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractOutPoint(req *UpdateChanStatusRequest) (*wire.OutPoint, error) {
|
func extractOutPoint(req *UpdateChanStatusRequest) (*wire.OutPoint, error) {
|
||||||
|
@ -1786,6 +1786,12 @@ func (s *server) Start() error {
|
|||||||
}
|
}
|
||||||
cleanup = cleanup.add(s.htlcSwitch.Stop)
|
cleanup = cleanup.add(s.htlcSwitch.Stop)
|
||||||
|
|
||||||
|
if err := s.interceptableSwitch.Start(); err != nil {
|
||||||
|
startErr = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cleanup = cleanup.add(s.interceptableSwitch.Stop)
|
||||||
|
|
||||||
if err := s.chainArb.Start(); err != nil {
|
if err := s.chainArb.Start(); err != nil {
|
||||||
startErr = err
|
startErr = err
|
||||||
return
|
return
|
||||||
|
Loading…
x
Reference in New Issue
Block a user