htlcswitch: interceptor expiry check

This commit is contained in:
Joost Jager
2022-01-28 12:56:17 +01:00
parent bae0b6bdf9
commit 13dff2fb3e
6 changed files with 143 additions and 14 deletions

View File

@@ -55,6 +55,10 @@ type InterceptableSwitch struct {
// holdForwards keeps track of outstanding intercepted forwards.
holdForwards map[channeldb.CircuitKey]InterceptedForward
// cltvRejectDelta defines the number of blocks before the expiry of the
// htlc where we no longer intercept it and instead cancel it back.
cltvRejectDelta uint32
wg sync.WaitGroup
quit chan struct{}
}
@@ -106,7 +110,7 @@ type fwdResolution struct {
}
// NewInterceptableSwitch returns an instance of InterceptableSwitch.
func NewInterceptableSwitch(s *Switch,
func NewInterceptableSwitch(s *Switch, cltvRejectDelta uint32,
requireInterceptor bool) *InterceptableSwitch {
return &InterceptableSwitch{
@@ -116,6 +120,7 @@ func NewInterceptableSwitch(s *Switch,
holdForwards: make(map[channeldb.CircuitKey]InterceptedForward),
resolutionChan: make(chan *fwdResolution),
requireInterceptor: requireInterceptor,
cltvRejectDelta: cltvRejectDelta,
quit: make(chan struct{}),
}
@@ -337,6 +342,27 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
htlcSwitch: s.htlcSwitch,
}
// Handle forwards that are too close to expiry.
handled, err := s.handleExpired(intercepted)
if err != nil {
log.Errorf("Error handling intercepted htlc "+
"that expires too soon: circuit=%v, "+
"incoming_timeout=%v, err=%v",
packet.inKey(), packet.incomingTimeout, err)
// Return false so that the packet is offered as normal
// to the switch. This isn't ideal because interception
// may be configured as always-on and is skipped now.
// Returning true isn't great either, because the htlc
// will remain stuck and potentially force-close the
// channel. But in the end, we should never get here, so
// the actual return value doesn't matter that much.
return false
}
if handled {
return true
}
if s.interceptor == nil && !isReplay {
// There is no interceptor registered, we are in
// interceptor-required mode, and this is a new packet
@@ -370,6 +396,32 @@ func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
}
}
// handleExpired checks that the htlc isn't too close to the channel
// force-close broadcast height. If it is, it is cancelled back.
func (s *InterceptableSwitch) handleExpired(fwd *interceptedForward) (
bool, error) {
height := s.htlcSwitch.BestHeight()
if fwd.packet.incomingTimeout >= height+s.cltvRejectDelta {
return false, nil
}
log.Debugf("Interception rejected because htlc "+
"expires too soon: circuit=%v, "+
"height=%v, incoming_timeout=%v",
fwd.packet.inKey(), height,
fwd.packet.incomingTimeout)
err := fwd.FailWithCode(
lnwire.CodeExpiryTooSoon,
)
if err != nil {
return false, err
}
return true, nil
}
// interceptedForward implements the InterceptedForward interface.
// It is passed from the switch to external interceptors that are interested
// in holding forwards and resolve them manually.
@@ -450,6 +502,16 @@ func (f *interceptedForward) FailWithCode(code lnwire.FailCode) error {
failureMsg = lnwire.NewTemporaryChannelFailure(update)
case lnwire.CodeExpiryTooSoon:
update, err := f.htlcSwitch.cfg.FetchLastChannelUpdate(
f.packet.incomingChanID,
)
if err != nil {
return err
}
failureMsg = lnwire.NewExpiryTooSoon(*update)
default:
return ErrUnsupportedFailureCode
}

View File

@@ -185,7 +185,7 @@ func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error)
events: make(map[time.Time]channeldb.ForwardingEvent),
},
FetchLastChannelUpdate: func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {
return nil, nil
return &lnwire.ChannelUpdate{}, nil
},
Notifier: &mock.ChainNotifier{
SpendChan: make(chan *chainntnfs.SpendDetail),

View File

@@ -13,6 +13,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/htlcswitch/hodl"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
@@ -3167,10 +3168,12 @@ func (m *mockForwardInterceptor) getIntercepted() InterceptedPacket {
func assertNumCircuits(t *testing.T, s *Switch, pending, opened int) {
if s.circuits.NumPending() != pending {
t.Fatal("wrong amount of half circuits")
t.Fatalf("wrong amount of half circuits, expected %v but "+
"got %v", pending, s.circuits.NumPending())
}
if s.circuits.NumOpen() != opened {
t.Fatal("wrong amount of circuits")
t.Fatalf("wrong amount of circuits, expected %v but got %v",
opened, s.circuits.NumOpen())
}
}
@@ -3197,6 +3200,16 @@ func assertOutgoingLinkReceive(t *testing.T, targetLink *mockChannelLink,
return nil
}
func assertOutgoingLinkReceiveIntercepted(t *testing.T,
targetLink *mockChannelLink) {
select {
case <-targetLink.packets:
case <-time.After(time.Second):
t.Fatal("request was not propagated to destination")
}
}
func TestSwitchHoldForward(t *testing.T) {
t.Parallel()
@@ -3259,14 +3272,17 @@ func TestSwitchHoldForward(t *testing.T) {
onionBlob := [1366]byte{4, 5, 6}
incomingHtlcID := uint64(0)
const cltvRejectDelta = 13
createTestPacket := func() *htlcPacket {
incomingHtlcID++
return &htlcPacket{
incomingChanID: aliceChannelLink.ShortChanID(),
incomingHTLCID: incomingHtlcID,
outgoingChanID: bobChannelLink.ShortChanID(),
obfuscator: NewMockObfuscator(),
incomingChanID: aliceChannelLink.ShortChanID(),
incomingHTLCID: incomingHtlcID,
incomingTimeout: testStartingHeight + cltvRejectDelta + 1,
outgoingChanID: bobChannelLink.ShortChanID(),
obfuscator: NewMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: 1,
@@ -3290,12 +3306,55 @@ func TestSwitchHoldForward(t *testing.T) {
t: t,
interceptedChan: make(chan InterceptedPacket),
}
switchForwardInterceptor := NewInterceptableSwitch(s, false)
switchForwardInterceptor := NewInterceptableSwitch(
s, cltvRejectDelta, false,
)
require.NoError(t, switchForwardInterceptor.Start())
switchForwardInterceptor.SetInterceptor(forwardInterceptor.InterceptForwardHtlc)
linkQuit := make(chan struct{})
// Test a forward that expires too soon.
packet := createTestPacket()
packet.incomingTimeout = testStartingHeight + cltvRejectDelta - 1
err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet)
if err != nil {
t.Fatalf("can't forward htlc packet: %v", err)
}
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertOutgoingLinkReceiveIntercepted(t, aliceChannelLink)
assertNumCircuits(t, s, 0, 0)
// Test a forward that expires too soon and can't be failed.
packet = createTestPacket()
packet.incomingTimeout = testStartingHeight + cltvRejectDelta - 1
// Simulate an error during the composition of the failure message.
currentCallback := s.cfg.FetchLastChannelUpdate
s.cfg.FetchLastChannelUpdate = func(
lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {
return nil, errors.New("cannot fetch update")
}
err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet)
if err != nil {
t.Fatalf("can't forward htlc packet: %v", err)
}
receivedPkt := assertOutgoingLinkReceive(t, bobChannelLink, true)
assertNumCircuits(t, s, 1, 1)
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false,
createSettlePacket(receivedPkt.outgoingHTLCID),
))
assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertNumCircuits(t, s, 0, 0)
s.cfg.FetchLastChannelUpdate = currentCallback
// Test resume a hold forward.
assertNumCircuits(t, s, 0, 0)
err = switchForwardInterceptor.ForwardPackets(
@@ -3310,7 +3369,7 @@ func TestSwitchHoldForward(t *testing.T) {
Action: FwdActionResume,
Key: forwardInterceptor.getIntercepted().IncomingCircuit,
}))
receivedPkt := assertOutgoingLinkReceive(t, bobChannelLink, true)
receivedPkt = assertOutgoingLinkReceive(t, bobChannelLink, true)
assertNumCircuits(t, s, 1, 1)
// settling the htlc to close the circuit.
@@ -3387,7 +3446,7 @@ func TestSwitchHoldForward(t *testing.T) {
}))
assertOutgoingLinkReceive(t, bobChannelLink, false)
packet := assertOutgoingLinkReceive(t, aliceChannelLink, true)
packet = assertOutgoingLinkReceive(t, aliceChannelLink, true)
require.Equal(t, reason, packet.htlc.(*lnwire.UpdateFailHTLC).Reason)
@@ -3443,7 +3502,7 @@ func TestSwitchHoldForward(t *testing.T) {
require.NoError(t, switchForwardInterceptor.Stop())
// Test always-on interception.
switchForwardInterceptor = NewInterceptableSwitch(s, true)
switchForwardInterceptor = NewInterceptableSwitch(s, cltvRejectDelta, true)
require.NoError(t, switchForwardInterceptor.Start())
// Forward a fresh packet. It is expected to be failed immediately,