htlcswitch/test: create interceptableSwitchTestContext

Refactor to prepare for adding more tests.
This commit is contained in:
Joost Jager
2022-08-16 14:04:22 +02:00
parent 9c063db698
commit 93a7cab46e

View File

@@ -3761,8 +3761,23 @@ func assertOutgoingLinkReceiveIntercepted(t *testing.T,
}
}
func TestSwitchHoldForward(t *testing.T) {
t.Parallel()
type interceptableSwitchTestContext struct {
t *testing.T
preimage [sha256.Size]byte
rhash [32]byte
onionBlob [1366]byte
incomingHtlcID uint64
cltvRejectDelta uint32
forwardInterceptor *mockForwardInterceptor
aliceChannelLink *mockChannelLink
bobChannelLink *mockChannelLink
s *Switch
}
func newInterceptableSwitchTestContext(
t *testing.T) *interceptableSwitchTestContext {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
@@ -3787,12 +3802,6 @@ func TestSwitchHoldForward(t *testing.T) {
t.Fatalf("unable to start switch: %v", err)
}
defer func() {
if err := s.Stop(); err != nil {
t.Fatalf(err.Error())
}
}()
aliceChannelLink := newMockChannelLink(
s, chanID1, aliceChanID, emptyScid, alicePeer, true, false,
false, false,
@@ -3808,48 +3817,69 @@ func TestSwitchHoldForward(t *testing.T) {
t.Fatalf("unable to add bob link: %v", err)
}
// Create request which should be forwarded from Alice channel link to
// bob channel link.
preimage := [sha256.Size]byte{1}
rhash := sha256.Sum256(preimage[:])
onionBlob := [1366]byte{4, 5, 6}
incomingHtlcID := uint64(0)
const cltvRejectDelta = 13
createTestPacket := func() *htlcPacket {
incomingHtlcID++
return &htlcPacket{
incomingChanID: aliceChannelLink.ShortChanID(),
incomingHTLCID: incomingHtlcID,
incomingTimeout: testStartingHeight + cltvRejectDelta + 1,
outgoingChanID: bobChannelLink.ShortChanID(),
obfuscator: NewMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: 1,
OnionBlob: onionBlob,
},
}
}
createSettlePacket := func(outgoingHTLCID uint64) *htlcPacket {
return &htlcPacket{
outgoingChanID: bobChannelLink.ShortChanID(),
outgoingHTLCID: outgoingHTLCID,
amount: 1,
htlc: &lnwire.UpdateFulfillHTLC{
PaymentPreimage: preimage,
},
}
}
forwardInterceptor := &mockForwardInterceptor{
ctx := &interceptableSwitchTestContext{
t: t,
interceptedChan: make(chan InterceptedPacket),
preimage: preimage,
rhash: sha256.Sum256(preimage[:]),
onionBlob: [1366]byte{4, 5, 6},
incomingHtlcID: uint64(0),
cltvRejectDelta: 13,
forwardInterceptor: &mockForwardInterceptor{
t: t,
interceptedChan: make(chan InterceptedPacket),
},
aliceChannelLink: aliceChannelLink,
bobChannelLink: bobChannelLink,
s: s,
}
return ctx
}
func (c *interceptableSwitchTestContext) createTestPacket() *htlcPacket {
c.incomingHtlcID++
return &htlcPacket{
incomingChanID: c.aliceChannelLink.ShortChanID(),
incomingHTLCID: c.incomingHtlcID,
incomingTimeout: testStartingHeight + c.cltvRejectDelta + 1,
outgoingChanID: c.bobChannelLink.ShortChanID(),
obfuscator: NewMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: c.rhash,
Amount: 1,
OnionBlob: c.onionBlob,
},
}
}
func (c *interceptableSwitchTestContext) finish() {
if err := c.s.Stop(); err != nil {
c.t.Fatalf(err.Error())
}
}
func (c *interceptableSwitchTestContext) createSettlePacket(
outgoingHTLCID uint64) *htlcPacket {
return &htlcPacket{
outgoingChanID: c.bobChannelLink.ShortChanID(),
outgoingHTLCID: outgoingHTLCID,
amount: 1,
htlc: &lnwire.UpdateFulfillHTLC{
PaymentPreimage: c.preimage,
},
}
}
func TestSwitchHoldForward(t *testing.T) {
t.Parallel()
c := newInterceptableSwitchTestContext(t)
defer c.finish()
notifier := &mock.ChainNotifier{
EpochChan: make(chan *chainntnfs.BlockEpoch, 1),
}
@@ -3857,33 +3887,33 @@ func TestSwitchHoldForward(t *testing.T) {
switchForwardInterceptor := NewInterceptableSwitch(
&InterceptableSwitchConfig{
Switch: s,
CltvRejectDelta: cltvRejectDelta,
Switch: c.s,
CltvRejectDelta: c.cltvRejectDelta,
Notifier: notifier,
},
)
require.NoError(t, switchForwardInterceptor.Start())
switchForwardInterceptor.SetInterceptor(forwardInterceptor.InterceptForwardHtlc)
switchForwardInterceptor.SetInterceptor(c.forwardInterceptor.InterceptForwardHtlc)
linkQuit := make(chan struct{})
// Test a forward that expires too soon.
packet := createTestPacket()
packet.incomingTimeout = testStartingHeight + cltvRejectDelta - 1
packet := c.createTestPacket()
packet.incomingTimeout = testStartingHeight + c.cltvRejectDelta - 1
err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet)
err := switchForwardInterceptor.ForwardPackets(linkQuit, false, packet)
require.NoError(t, err, "can't forward htlc packet")
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertOutgoingLinkReceiveIntercepted(t, aliceChannelLink)
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
assertOutgoingLinkReceiveIntercepted(t, c.aliceChannelLink)
assertNumCircuits(t, c.s, 0, 0)
// Test a forward that expires too soon and can't be failed.
packet = createTestPacket()
packet.incomingTimeout = testStartingHeight + cltvRejectDelta - 1
packet = c.createTestPacket()
packet.incomingTimeout = testStartingHeight + c.cltvRejectDelta - 1
// Simulate an error during the composition of the failure message.
currentCallback := s.cfg.FetchLastChannelUpdate
s.cfg.FetchLastChannelUpdate = func(
currentCallback := c.s.cfg.FetchLastChannelUpdate
c.s.cfg.FetchLastChannelUpdate = func(
lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {
return nil, errors.New("cannot fetch update")
@@ -3891,137 +3921,137 @@ func TestSwitchHoldForward(t *testing.T) {
err = switchForwardInterceptor.ForwardPackets(linkQuit, false, packet)
require.NoError(t, err, "can't forward htlc packet")
receivedPkt := assertOutgoingLinkReceive(t, bobChannelLink, true)
assertNumCircuits(t, s, 1, 1)
receivedPkt := assertOutgoingLinkReceive(t, c.bobChannelLink, true)
assertNumCircuits(t, c.s, 1, 1)
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false,
createSettlePacket(receivedPkt.outgoingHTLCID),
c.createSettlePacket(receivedPkt.outgoingHTLCID),
))
assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
assertNumCircuits(t, c.s, 0, 0)
s.cfg.FetchLastChannelUpdate = currentCallback
c.s.cfg.FetchLastChannelUpdate = currentCallback
// Test resume a hold forward.
assertNumCircuits(t, s, 0, 0)
assertNumCircuits(t, c.s, 0, 0)
err = switchForwardInterceptor.ForwardPackets(
linkQuit, false, createTestPacket(),
linkQuit, false, c.createTestPacket(),
)
require.NoError(t, err)
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertNumCircuits(t, c.s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
Action: FwdActionResume,
Key: forwardInterceptor.getIntercepted().IncomingCircuit,
Key: c.forwardInterceptor.getIntercepted().IncomingCircuit,
}))
receivedPkt = assertOutgoingLinkReceive(t, bobChannelLink, true)
assertNumCircuits(t, s, 1, 1)
receivedPkt = assertOutgoingLinkReceive(t, c.bobChannelLink, true)
assertNumCircuits(t, c.s, 1, 1)
// settling the htlc to close the circuit.
err = switchForwardInterceptor.ForwardPackets(
linkQuit, false,
createSettlePacket(receivedPkt.outgoingHTLCID),
c.createSettlePacket(receivedPkt.outgoingHTLCID),
)
require.NoError(t, err)
assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
assertNumCircuits(t, c.s, 0, 0)
// Test resume a hold forward after disconnection.
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false, createTestPacket(),
linkQuit, false, c.createTestPacket(),
))
// Wait until the packet is offered to the interceptor.
_ = forwardInterceptor.getIntercepted()
_ = c.forwardInterceptor.getIntercepted()
// No forward expected yet.
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertNumCircuits(t, c.s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
// Disconnect should resume the forwarding.
switchForwardInterceptor.SetInterceptor(nil)
receivedPkt = assertOutgoingLinkReceive(t, bobChannelLink, true)
assertNumCircuits(t, s, 1, 1)
receivedPkt = assertOutgoingLinkReceive(t, c.bobChannelLink, true)
assertNumCircuits(t, c.s, 1, 1)
// Settle the htlc to close the circuit.
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false,
createSettlePacket(receivedPkt.outgoingHTLCID),
c.createSettlePacket(receivedPkt.outgoingHTLCID),
))
assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
assertNumCircuits(t, c.s, 0, 0)
// Test failing a hold forward
switchForwardInterceptor.SetInterceptor(
forwardInterceptor.InterceptForwardHtlc,
c.forwardInterceptor.InterceptForwardHtlc,
)
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false, createTestPacket(),
linkQuit, false, c.createTestPacket(),
))
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertNumCircuits(t, c.s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
Action: FwdActionFail,
Key: forwardInterceptor.getIntercepted().IncomingCircuit,
Key: c.forwardInterceptor.getIntercepted().IncomingCircuit,
FailureCode: lnwire.CodeTemporaryChannelFailure,
}))
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
assertNumCircuits(t, c.s, 0, 0)
// Test failing a hold forward with a failure message.
require.NoError(t,
switchForwardInterceptor.ForwardPackets(
linkQuit, false, createTestPacket(),
linkQuit, false, c.createTestPacket(),
),
)
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertNumCircuits(t, c.s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
reason := lnwire.OpaqueReason([]byte{1, 2, 3})
require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
Action: FwdActionFail,
Key: forwardInterceptor.getIntercepted().IncomingCircuit,
Key: c.forwardInterceptor.getIntercepted().IncomingCircuit,
FailureMessage: reason,
}))
assertOutgoingLinkReceive(t, bobChannelLink, false)
packet = assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
packet = assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
require.Equal(t, reason, packet.htlc.(*lnwire.UpdateFailHTLC).Reason)
assertNumCircuits(t, s, 0, 0)
assertNumCircuits(t, c.s, 0, 0)
// Test failing a hold forward with a malformed htlc failure.
err = switchForwardInterceptor.ForwardPackets(
linkQuit, false, createTestPacket(),
linkQuit, false, c.createTestPacket(),
)
require.NoError(t, err)
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertNumCircuits(t, c.s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
code := lnwire.CodeInvalidOnionKey
require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
Action: FwdActionFail,
Key: forwardInterceptor.getIntercepted().IncomingCircuit,
Key: c.forwardInterceptor.getIntercepted().IncomingCircuit,
FailureCode: code,
}))
assertOutgoingLinkReceive(t, bobChannelLink, false)
packet = assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
packet = assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
failPacket := packet.htlc.(*lnwire.UpdateFailHTLC)
shaOnionBlob := sha256.Sum256(onionBlob[:])
shaOnionBlob := sha256.Sum256(c.onionBlob[:])
expectedFailure := &lnwire.FailInvalidOnionKey{
OnionSHA256: shaOnionBlob,
}
@@ -4030,23 +4060,23 @@ func TestSwitchHoldForward(t *testing.T) {
assert.Equal(t, lnwire.OpaqueReason(b.Bytes()), failPacket.Reason)
assertNumCircuits(t, s, 0, 0)
assertNumCircuits(t, c.s, 0, 0)
// Test settling a hold forward
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false, createTestPacket(),
linkQuit, false, c.createTestPacket(),
))
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertNumCircuits(t, c.s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
Key: forwardInterceptor.getIntercepted().IncomingCircuit,
Key: c.forwardInterceptor.getIntercepted().IncomingCircuit,
Action: FwdActionSettle,
Preimage: preimage,
Preimage: c.preimage,
}))
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
assertNumCircuits(t, c.s, 0, 0)
require.NoError(t, switchForwardInterceptor.Stop())
@@ -4058,8 +4088,8 @@ func TestSwitchHoldForward(t *testing.T) {
switchForwardInterceptor = NewInterceptableSwitch(
&InterceptableSwitchConfig{
Switch: s,
CltvRejectDelta: cltvRejectDelta,
Switch: c.s,
CltvRejectDelta: c.cltvRejectDelta,
RequireInterceptor: true,
Notifier: notifier,
},
@@ -4069,12 +4099,12 @@ func TestSwitchHoldForward(t *testing.T) {
// Forward a fresh packet. It is expected to be failed immediately,
// because there is no interceptor registered.
require.NoError(t, switchForwardInterceptor.ForwardPackets(
linkQuit, false, createTestPacket(),
linkQuit, false, c.createTestPacket(),
))
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
assertNumCircuits(t, c.s, 0, 0)
// Forward a replayed packet. It is expected to be held until the
// interceptor connects. To continue the test, it needs to be ran in a
@@ -4082,48 +4112,48 @@ func TestSwitchHoldForward(t *testing.T) {
errChan := make(chan error)
go func() {
errChan <- switchForwardInterceptor.ForwardPackets(
linkQuit, true, createTestPacket(),
linkQuit, true, c.createTestPacket(),
)
}()
// Assert that nothing is forward to the switch.
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
assertNumCircuits(t, c.s, 0, 0)
// Register an interceptor.
switchForwardInterceptor.SetInterceptor(
forwardInterceptor.InterceptForwardHtlc,
c.forwardInterceptor.InterceptForwardHtlc,
)
// Expect the ForwardPackets call to unblock.
require.NoError(t, <-errChan)
// Now expect the queued packet to come through.
forwardInterceptor.getIntercepted()
c.forwardInterceptor.getIntercepted()
// Disconnect and reconnect interceptor.
switchForwardInterceptor.SetInterceptor(nil)
switchForwardInterceptor.SetInterceptor(
forwardInterceptor.InterceptForwardHtlc,
c.forwardInterceptor.InterceptForwardHtlc,
)
// A replay of the held packet is expected.
intercepted := forwardInterceptor.getIntercepted()
intercepted := c.forwardInterceptor.getIntercepted()
// Settle the packet.
require.NoError(t, switchForwardInterceptor.Resolve(&FwdResolution{
Key: intercepted.IncomingCircuit,
Action: FwdActionSettle,
Preimage: preimage,
Preimage: c.preimage,
}))
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, c.bobChannelLink, false)
assertOutgoingLinkReceive(t, c.aliceChannelLink, true)
assertNumCircuits(t, c.s, 0, 0)
require.NoError(t, switchForwardInterceptor.Stop())
select {
case <-forwardInterceptor.interceptedChan:
case <-c.forwardInterceptor.interceptedChan:
require.Fail(t, "unexpected interception")
default: