multi: move payment state handling into MPPayment

This commit moves the struct `paymentState` used in `routing` into
`channeldb` and replaces it with `MPPaymentState`. In the following
commit we'd see the benefit, that we don't need to pass variables back
and forth between the two packages. More importantly, this state is put
closer to its origin, and is strictly updated whenever a payment is read
from disk. This approach is less error-prone comparing to the previous
one, which both the `payment` and `paymentState` need to be updated at
the same time to make sure the data stay consistant in a parallel
environment.
This commit is contained in:
yyforyongyu
2023-02-08 03:10:20 +08:00
committed by Olaoluwa Osuntokun
parent bf99e42f8e
commit 52c00e8cc4
7 changed files with 281 additions and 407 deletions

View File

@ -791,314 +791,6 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase,
}
}
// TestPaymentState tests that the logics implemented on paymentState struct
// are as expected. In particular, that the method terminated and
// needWaitForShards return the right values.
func TestPaymentState(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
// Use the following three params, each is equivalent to a bool
// statement, to construct 8 test cases so that we can
// exhaustively catch all possible states.
numAttemptsInFlight int
remainingAmt lnwire.MilliSatoshi
terminate bool
expectedTerminated bool
expectedNeedWaitForShards bool
}{
{
// If we have active shards and terminate is marked
// false, the state is not terminated. Since the
// remaining amount is zero, we need to wait for shards
// to be finished and launch no more shards.
name: "state 100",
numAttemptsInFlight: 1,
remainingAmt: lnwire.MilliSatoshi(0),
terminate: false,
expectedTerminated: false,
expectedNeedWaitForShards: true,
},
{
// If we have active shards while terminate is marked
// true, the state is not terminated, and we need to
// wait for shards to be finished and launch no more
// shards.
name: "state 101",
numAttemptsInFlight: 1,
remainingAmt: lnwire.MilliSatoshi(0),
terminate: true,
expectedTerminated: false,
expectedNeedWaitForShards: true,
},
{
// If we have active shards and terminate is marked
// false, the state is not terminated. Since the
// remaining amount is not zero, we don't need to wait
// for shards outcomes and should launch more shards.
name: "state 110",
numAttemptsInFlight: 1,
remainingAmt: lnwire.MilliSatoshi(1),
terminate: false,
expectedTerminated: false,
expectedNeedWaitForShards: false,
},
{
// If we have active shards and terminate is marked
// true, the state is not terminated. Even the
// remaining amount is not zero, we need to wait for
// shards outcomes because state is terminated.
name: "state 111",
numAttemptsInFlight: 1,
remainingAmt: lnwire.MilliSatoshi(1),
terminate: true,
expectedTerminated: false,
expectedNeedWaitForShards: true,
},
{
// If we have no active shards while terminate is marked
// false, the state is not terminated, and we don't
// need to wait for more shard outcomes because there
// are no active shards.
name: "state 000",
numAttemptsInFlight: 0,
remainingAmt: lnwire.MilliSatoshi(0),
terminate: false,
expectedTerminated: false,
expectedNeedWaitForShards: false,
},
{
// If we have no active shards while terminate is marked
// true, the state is terminated, and we don't need to
// wait for shards to be finished.
name: "state 001",
numAttemptsInFlight: 0,
remainingAmt: lnwire.MilliSatoshi(0),
terminate: true,
expectedTerminated: true,
expectedNeedWaitForShards: false,
},
{
// If we have no active shards while terminate is marked
// false, the state is not terminated. Since the
// remaining amount is not zero, we don't need to wait
// for shards outcomes and should launch more shards.
name: "state 010",
numAttemptsInFlight: 0,
remainingAmt: lnwire.MilliSatoshi(1),
terminate: false,
expectedTerminated: false,
expectedNeedWaitForShards: false,
},
{
// If we have no active shards while terminate is marked
// true, the state is terminated, and we don't need to
// wait for shards outcomes.
name: "state 011",
numAttemptsInFlight: 0,
remainingAmt: lnwire.MilliSatoshi(1),
terminate: true,
expectedTerminated: true,
expectedNeedWaitForShards: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ps := &paymentState{
numAttemptsInFlight: tc.numAttemptsInFlight,
remainingAmt: tc.remainingAmt,
terminate: tc.terminate,
}
require.Equal(
t, tc.expectedTerminated, ps.terminated(),
"terminated returned wrong value",
)
require.Equal(
t, tc.expectedNeedWaitForShards,
ps.needWaitForShards(),
"needWaitForShards returned wrong value",
)
})
}
}
// TestUpdatePaymentState checks that the method updatePaymentState updates the
// paymentState as expected.
func TestUpdatePaymentState(t *testing.T) {
t.Parallel()
// paymentHash is the identifier on paymentLifecycle.
paymentHash := lntypes.Hash{}
preimage := lntypes.Preimage{}
failureReasonError := channeldb.FailureReasonError
// TODO(yy): make MPPayment into an interface so we can mock it. The
// current design implicitly tests the methods SendAmt, TerminalInfo,
// and InFlightHTLCs on channeldb.MPPayment, which is not good. Once
// MPPayment becomes an interface, we can then mock these methods here.
testCases := []struct {
name string
payment *channeldb.MPPayment
totalAmt int
feeLimit int
expectedState *paymentState
shouldReturnError bool
}{
{
// Test that the error returned from FetchPayment is
// handled properly. We use a nil payment to indicate
// we want to return an error.
name: "fetch payment error",
payment: nil,
shouldReturnError: true,
},
{
// Test that when the sentAmt exceeds totalAmount, the
// error is returned.
name: "amount exceeded error",
// SentAmt returns 90, 10
// TerminalInfo returns non-nil, nil
// InFlightHTLCs returns 0
payment: &channeldb.MPPayment{
HTLCs: []channeldb.HTLCAttempt{
makeSettledAttempt(100, 10, preimage),
},
},
totalAmt: 1,
shouldReturnError: true,
},
{
// Test that when the fee budget is reached, the
// remaining fee should be zero.
name: "fee budget reached",
payment: &channeldb.MPPayment{
// SentAmt returns 90, 10
// TerminalInfo returns nil, nil
// InFlightHTLCs returns 1
HTLCs: []channeldb.HTLCAttempt{
makeActiveAttempt(100, 10),
makeFailedAttempt(100, 10),
},
},
totalAmt: 1000,
feeLimit: 1,
expectedState: &paymentState{
numAttemptsInFlight: 1,
remainingAmt: 1000 - 90,
remainingFees: 0,
terminate: false,
},
},
{
// Test when the payment is settled, the state should
// be marked as terminated.
name: "payment settled",
// SentAmt returns 90, 10
// TerminalInfo returns non-nil, nil
// InFlightHTLCs returns 0
payment: &channeldb.MPPayment{
HTLCs: []channeldb.HTLCAttempt{
makeSettledAttempt(100, 10, preimage),
},
},
totalAmt: 1000,
feeLimit: 100,
expectedState: &paymentState{
numAttemptsInFlight: 0,
remainingAmt: 1000 - 90,
remainingFees: 100 - 10,
terminate: true,
},
},
{
// Test when the payment is failed, the state should be
// marked as terminated.
name: "payment failed",
// SentAmt returns 0, 0
// TerminalInfo returns nil, non-nil
// InFlightHTLCs returns 0
payment: &channeldb.MPPayment{
FailureReason: &failureReasonError,
},
totalAmt: 1000,
feeLimit: 100,
expectedState: &paymentState{
numAttemptsInFlight: 0,
remainingAmt: 1000,
remainingFees: 100,
terminate: true,
},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
// Create mock control tower and assign it to router.
// We will then use the router and the paymentHash
// above to create our paymentLifecycle for this test.
ct := &mockControlTower{}
rt := &ChannelRouter{cfg: &Config{Control: ct}}
pl := &paymentLifecycle{
router: rt,
identifier: paymentHash,
feeLimit: lnwire.MilliSatoshi(tc.feeLimit),
}
if tc.payment == nil {
// A nil payment indicates we want to test an
// error returned from FetchPayment.
dummyErr := errors.New("dummy")
ct.On("FetchPayment", paymentHash).Return(
nil, dummyErr,
)
} else {
// Attach the payment info.
info := &channeldb.PaymentCreationInfo{
Value: lnwire.MilliSatoshi(tc.totalAmt),
}
tc.payment.Info = info
// Otherwise we will return the payment.
ct.On("FetchPayment", paymentHash).Return(
tc.payment, nil,
)
}
// Call the method that updates the payment state.
_, state, err := pl.fetchPaymentState()
// Assert that the mock method is called as
// intended.
ct.AssertExpectations(t)
if tc.shouldReturnError {
require.Error(t, err, "expect an error")
return
}
require.NoError(t, err, "unexpected error")
require.Equal(
t, tc.expectedState, state,
"state not updated as expected",
)
})
}
}
func makeActiveAttempt(total, fee int) channeldb.HTLCAttempt {
return channeldb.HTLCAttempt{
HTLCAttemptInfo: makeAttemptInfo(total, total-fee),
@ -1114,15 +806,6 @@ func makeSettledAttempt(total, fee int,
}
}
func makeFailedAttempt(total, fee int) channeldb.HTLCAttempt {
return channeldb.HTLCAttempt{
HTLCAttemptInfo: makeAttemptInfo(total, total-fee),
Failure: &channeldb.HTLCFailInfo{
Reason: channeldb.HTLCFailInternal,
},
}
}
func makeAttemptInfo(total, amtForwarded int) channeldb.HTLCAttemptInfo {
hop := &route.Hop{AmtToForward: lnwire.MilliSatoshi(amtForwarded)}
return channeldb.HTLCAttemptInfo{