mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-08 14:57:38 +02:00
htlcswitch: interceptor expiry check
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/btcsuite/btcd/btcutil"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/go-errors/errors"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/htlcswitch/hodl"
|
||||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||
@@ -3167,10 +3168,12 @@ func (m *mockForwardInterceptor) getIntercepted() InterceptedPacket {
|
||||
|
||||
func assertNumCircuits(t *testing.T, s *Switch, pending, opened int) {
|
||||
if s.circuits.NumPending() != pending {
|
||||
t.Fatal("wrong amount of half circuits")
|
||||
t.Fatalf("wrong amount of half circuits, expected %v but "+
|
||||
"got %v", pending, s.circuits.NumPending())
|
||||
}
|
||||
if s.circuits.NumOpen() != opened {
|
||||
t.Fatal("wrong amount of circuits")
|
||||
t.Fatalf("wrong amount of circuits, expected %v but got %v",
|
||||
opened, s.circuits.NumOpen())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3197,6 +3200,16 @@ func assertOutgoingLinkReceive(t *testing.T, targetLink *mockChannelLink,
|
||||
return nil
|
||||
}
|
||||
|
||||
func assertOutgoingLinkReceiveIntercepted(t *testing.T,
|
||||
targetLink *mockChannelLink) {
|
||||
|
||||
select {
|
||||
case <-targetLink.packets:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("request was not propagated to destination")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchHoldForward(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -3259,14 +3272,17 @@ func TestSwitchHoldForward(t *testing.T) {
|
||||
onionBlob := [1366]byte{4, 5, 6}
|
||||
incomingHtlcID := uint64(0)
|
||||
|
||||
const cltvRejectDelta = 13
|
||||
|
||||
createTestPacket := func() *htlcPacket {
|
||||
incomingHtlcID++
|
||||
|
||||
return &htlcPacket{
|
||||
incomingChanID: aliceChannelLink.ShortChanID(),
|
||||
incomingHTLCID: incomingHtlcID,
|
||||
outgoingChanID: bobChannelLink.ShortChanID(),
|
||||
obfuscator: NewMockObfuscator(),
|
||||
incomingChanID: aliceChannelLink.ShortChanID(),
|
||||
incomingHTLCID: incomingHtlcID,
|
||||
incomingTimeout: testStartingHeight + cltvRejectDelta + 1,
|
||||
outgoingChanID: bobChannelLink.ShortChanID(),
|
||||
obfuscator: NewMockObfuscator(),
|
||||
htlc: &lnwire.UpdateAddHTLC{
|
||||
PaymentHash: rhash,
|
||||
Amount: 1,
|
||||
@@ -3290,12 +3306,55 @@ func TestSwitchHoldForward(t *testing.T) {
|
||||
t: t,
|
||||
interceptedChan: make(chan InterceptedPacket),
|
||||
}
|
||||
switchForwardInterceptor := NewInterceptableSwitch(s, false)
|
||||
switchForwardInterceptor := NewInterceptableSwitch(
|
||||
s, cltvRejectDelta, false,
|
||||
)
|
||||
require.NoError(t, switchForwardInterceptor.Start())
|
||||
|
||||
switchForwardInterceptor.SetInterceptor(forwardInterceptor.InterceptForwardHtlc)
|
||||
linkQuit := make(chan struct{})
|
||||
|
||||
// Test a forward that expires too soon.
|
||||
packet := createTestPacket()
|
||||
packet.incomingTimeout = testStartingHeight + cltvRejectDelta - 1
|
||||
|
||||
err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet)
|
||||
if err != nil {
|
||||
t.Fatalf("can't forward htlc packet: %v", err)
|
||||
}
|
||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||
assertOutgoingLinkReceiveIntercepted(t, aliceChannelLink)
|
||||
assertNumCircuits(t, s, 0, 0)
|
||||
|
||||
// Test a forward that expires too soon and can't be failed.
|
||||
packet = createTestPacket()
|
||||
packet.incomingTimeout = testStartingHeight + cltvRejectDelta - 1
|
||||
|
||||
// Simulate an error during the composition of the failure message.
|
||||
currentCallback := s.cfg.FetchLastChannelUpdate
|
||||
s.cfg.FetchLastChannelUpdate = func(
|
||||
lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {
|
||||
|
||||
return nil, errors.New("cannot fetch update")
|
||||
}
|
||||
|
||||
err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet)
|
||||
if err != nil {
|
||||
t.Fatalf("can't forward htlc packet: %v", err)
|
||||
}
|
||||
receivedPkt := assertOutgoingLinkReceive(t, bobChannelLink, true)
|
||||
assertNumCircuits(t, s, 1, 1)
|
||||
|
||||
require.NoError(t, switchForwardInterceptor.ForwardPackets(
|
||||
linkQuit, false,
|
||||
createSettlePacket(receivedPkt.outgoingHTLCID),
|
||||
))
|
||||
|
||||
assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
||||
assertNumCircuits(t, s, 0, 0)
|
||||
|
||||
s.cfg.FetchLastChannelUpdate = currentCallback
|
||||
|
||||
// Test resume a hold forward.
|
||||
assertNumCircuits(t, s, 0, 0)
|
||||
err = switchForwardInterceptor.ForwardPackets(
|
||||
@@ -3310,7 +3369,7 @@ func TestSwitchHoldForward(t *testing.T) {
|
||||
Action: FwdActionResume,
|
||||
Key: forwardInterceptor.getIntercepted().IncomingCircuit,
|
||||
}))
|
||||
receivedPkt := assertOutgoingLinkReceive(t, bobChannelLink, true)
|
||||
receivedPkt = assertOutgoingLinkReceive(t, bobChannelLink, true)
|
||||
assertNumCircuits(t, s, 1, 1)
|
||||
|
||||
// settling the htlc to close the circuit.
|
||||
@@ -3387,7 +3446,7 @@ func TestSwitchHoldForward(t *testing.T) {
|
||||
}))
|
||||
|
||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||
packet := assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
||||
packet = assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
||||
|
||||
require.Equal(t, reason, packet.htlc.(*lnwire.UpdateFailHTLC).Reason)
|
||||
|
||||
@@ -3443,7 +3502,7 @@ func TestSwitchHoldForward(t *testing.T) {
|
||||
require.NoError(t, switchForwardInterceptor.Stop())
|
||||
|
||||
// Test always-on interception.
|
||||
switchForwardInterceptor = NewInterceptableSwitch(s, true)
|
||||
switchForwardInterceptor = NewInterceptableSwitch(s, cltvRejectDelta, true)
|
||||
require.NoError(t, switchForwardInterceptor.Start())
|
||||
|
||||
// Forward a fresh packet. It is expected to be failed immediately,
|
||||
|
Reference in New Issue
Block a user