diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index b04b04d4b..9df02ec32 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -1809,3 +1809,61 @@ func TestHandleAttemptResultSuccess(t *testing.T) { require.NoError(t, err, "expected no error") require.Equal(t, attempt, attemptResult.attempt) } + +// TestReloadInflightAttemptsLegacy checks that when handling a legacy HTLC +// attempt, `collectResult` behaves as expected. +func TestReloadInflightAttemptsLegacy(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + // Mount the resultCollector to check the full call path. + p.resultCollector = p.collectResultAsync + + // Create testing params. + paymentAmt := 10_000 + preimage := lntypes.Preimage{1} + attempt := makeSettledAttempt(t, paymentAmt, preimage) + + // Make the attempt.Hash to be nil to mock a legacy payment. + attempt.Hash = nil + + // Create a mock result returned from the switch. + result := &htlcswitch.PaymentResult{ + Preimage: preimage, + } + + // 1. calls `FetchPayment` and return the payment. + m.control.On("FetchPayment", p.identifier).Return(m.payment, nil).Once() + + // 2. calls `InFlightHTLCs` and return the attempt. + attempts := []channeldb.HTLCAttempt{*attempt} + m.payment.On("InFlightHTLCs").Return(attempts).Once() + + // 3. Mock the htlcswitch to return a the result chan. + resultChan := make(chan *htlcswitch.PaymentResult, 1) + m.payer.On("GetAttemptResult", + attempt.AttemptID, p.identifier, mock.Anything, + ).Return(resultChan, nil).Once().Run(func(args mock.Arguments) { + // Send the preimage to the result chan. + resultChan <- result + }) + + // Now call the method under test. + payment, err := p.reloadInflightAttempts() + require.NoError(t, err) + require.Equal(t, m.payment, payment) + + var r *switchResult + + // Assert the result is returned within testTimeout. + waitErr := wait.NoError(func() error { + r = <-p.resultCollected + return nil + }, testTimeout) + require.NoError(t, waitErr, "timeout waiting for result") + + // Assert the result is received as expected. + require.Equal(t, result, r.result) +}