diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index e74bfdec2..6af268d56 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -690,9 +690,11 @@ func (f *mockChannelLink) completeCircuit(pkt *htlcPacket) error { f.htlcID++ case *lnwire.UpdateFulfillHTLC, *lnwire.UpdateFailHTLC: - err := f.htlcSwitch.teardownCircuit(pkt) - if err != nil { - return err + if pkt.circuit != nil { + err := f.htlcSwitch.teardownCircuit(pkt) + if err != nil { + return err + } } } diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index b50f8dbe6..9c10cca44 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -3257,16 +3257,33 @@ func TestSwitchHoldForward(t *testing.T) { preimage := [sha256.Size]byte{1} rhash := sha256.Sum256(preimage[:]) onionBlob := [1366]byte{4, 5, 6} - ogPacket := &htlcPacket{ - incomingChanID: aliceChannelLink.ShortChanID(), - incomingHTLCID: 0, - outgoingChanID: bobChannelLink.ShortChanID(), - obfuscator: NewMockObfuscator(), - htlc: &lnwire.UpdateAddHTLC{ - PaymentHash: rhash, - Amount: 1, - OnionBlob: onionBlob, - }, + incomingHtlcID := uint64(0) + + createTestPacket := func() *htlcPacket { + incomingHtlcID++ + + return &htlcPacket{ + incomingChanID: aliceChannelLink.ShortChanID(), + incomingHTLCID: incomingHtlcID, + outgoingChanID: bobChannelLink.ShortChanID(), + obfuscator: NewMockObfuscator(), + htlc: &lnwire.UpdateAddHTLC{ + PaymentHash: rhash, + Amount: 1, + OnionBlob: onionBlob, + }, + } + } + + createSettlePacket := func(outgoingHTLCID uint64) *htlcPacket { + return &htlcPacket{ + outgoingChanID: bobChannelLink.ShortChanID(), + outgoingHTLCID: outgoingHTLCID, + amount: 1, + htlc: &lnwire.UpdateFulfillHTLC{ + PaymentPreimage: preimage, + }, + } } forwardInterceptor := &mockForwardInterceptor{ @@ -3281,7 +3298,9 @@ func TestSwitchHoldForward(t *testing.T) { // Test resume a hold forward. assertNumCircuits(t, s, 0, 0) - err = switchForwardInterceptor.ForwardPackets(linkQuit, false, ogPacket) + err = switchForwardInterceptor.ForwardPackets( + linkQuit, false, createTestPacket(), + ) require.NoError(t, err) assertNumCircuits(t, s, 0, 0) @@ -3291,19 +3310,14 @@ func TestSwitchHoldForward(t *testing.T) { Action: FwdActionResume, Key: forwardInterceptor.getIntercepted().IncomingCircuit, })) - assertOutgoingLinkReceive(t, bobChannelLink, true) + receivedPkt := assertOutgoingLinkReceive(t, bobChannelLink, true) assertNumCircuits(t, s, 1, 1) // settling the htlc to close the circuit. - settle := &htlcPacket{ - outgoingChanID: bobChannelLink.ShortChanID(), - outgoingHTLCID: 0, - amount: 1, - htlc: &lnwire.UpdateFulfillHTLC{ - PaymentPreimage: preimage, - }, - } - err = switchForwardInterceptor.ForwardPackets(linkQuit, false, settle) + err = switchForwardInterceptor.ForwardPackets( + linkQuit, false, + createSettlePacket(receivedPkt.outgoingHTLCID), + ) require.NoError(t, err) assertOutgoingLinkReceive(t, aliceChannelLink, true) @@ -3311,7 +3325,7 @@ func TestSwitchHoldForward(t *testing.T) { // Test resume a hold forward after disconnection. require.NoError(t, switchForwardInterceptor.ForwardPackets( - linkQuit, false, ogPacket, + linkQuit, false, createTestPacket(), )) // Wait until the packet is offered to the interceptor. @@ -3324,13 +3338,13 @@ func TestSwitchHoldForward(t *testing.T) { // Disconnect should resume the forwarding. switchForwardInterceptor.SetInterceptor(nil) - assertOutgoingLinkReceive(t, bobChannelLink, true) + receivedPkt = assertOutgoingLinkReceive(t, bobChannelLink, true) assertNumCircuits(t, s, 1, 1) // Settle the htlc to close the circuit. - settle.outgoingHTLCID = 1 require.NoError(t, switchForwardInterceptor.ForwardPackets( - linkQuit, false, settle, + linkQuit, false, + createSettlePacket(receivedPkt.outgoingHTLCID), )) assertOutgoingLinkReceive(t, aliceChannelLink, true) @@ -3342,7 +3356,7 @@ func TestSwitchHoldForward(t *testing.T) { ) require.NoError(t, switchForwardInterceptor.ForwardPackets( - linkQuit, false, ogPacket, + linkQuit, false, createTestPacket(), )) assertNumCircuits(t, s, 0, 0) assertOutgoingLinkReceive(t, bobChannelLink, false) @@ -3358,7 +3372,9 @@ func TestSwitchHoldForward(t *testing.T) { // Test failing a hold forward with a failure message. require.NoError(t, - switchForwardInterceptor.ForwardPackets(linkQuit, false, ogPacket), + switchForwardInterceptor.ForwardPackets( + linkQuit, false, createTestPacket(), + ), ) assertNumCircuits(t, s, 0, 0) assertOutgoingLinkReceive(t, bobChannelLink, false) @@ -3378,7 +3394,9 @@ func TestSwitchHoldForward(t *testing.T) { assertNumCircuits(t, s, 0, 0) // Test failing a hold forward with a malformed htlc failure. - err = switchForwardInterceptor.ForwardPackets(linkQuit, false, ogPacket) + err = switchForwardInterceptor.ForwardPackets( + linkQuit, false, createTestPacket(), + ) require.NoError(t, err) assertNumCircuits(t, s, 0, 0) @@ -3408,7 +3426,7 @@ func TestSwitchHoldForward(t *testing.T) { // Test settling a hold forward require.NoError(t, switchForwardInterceptor.ForwardPackets( - linkQuit, false, ogPacket, + linkQuit, false, createTestPacket(), )) assertNumCircuits(t, s, 0, 0) assertOutgoingLinkReceive(t, bobChannelLink, false) @@ -3431,7 +3449,7 @@ func TestSwitchHoldForward(t *testing.T) { // Forward a fresh packet. It is expected to be failed immediately, // because there is no interceptor registered. require.NoError(t, switchForwardInterceptor.ForwardPackets( - linkQuit, false, ogPacket, + linkQuit, false, createTestPacket(), )) assertOutgoingLinkReceive(t, bobChannelLink, false) @@ -3444,7 +3462,7 @@ func TestSwitchHoldForward(t *testing.T) { errChan := make(chan error) go func() { errChan <- switchForwardInterceptor.ForwardPackets( - linkQuit, true, ogPacket, + linkQuit, true, createTestPacket(), ) }()