mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-05-12 04:40:05 +02:00
multi: move payment state handling into MPPayment
This commit moves the struct `paymentState` used in `routing` into `channeldb` and replaces it with `MPPaymentState`. In the following commit we'd see the benefit, that we don't need to pass variables back and forth between the two packages. More importantly, this state is put closer to its origin, and is strictly updated whenever a payment is read from disk. This approach is less error-prone comparing to the previous one, which both the `payment` and `paymentState` need to be updated at the same time to make sure the data stay consistant in a parallel environment.
This commit is contained in:
parent
bf99e42f8e
commit
52c00e8cc4
@ -3,6 +3,7 @@ package channeldb
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"time"
|
||||
@ -147,6 +148,27 @@ type HTLCFailInfo struct {
|
||||
FailureSourceIndex uint32
|
||||
}
|
||||
|
||||
// MPPaymentState wraps a series of info needed for a given payment, which is
|
||||
// used by both MPP and AMP. This is a memory representation of the payment's
|
||||
// current state and is updated whenever the payment is read from disk.
|
||||
type MPPaymentState struct {
|
||||
// NumAttemptsInFlight specifies the number of HTLCs the payment is
|
||||
// waiting results for.
|
||||
NumAttemptsInFlight int
|
||||
|
||||
// RemainingAmt specifies how much more money to be sent.
|
||||
RemainingAmt lnwire.MilliSatoshi
|
||||
|
||||
// FeesPaid specifies the total fees paid so far that can be used to
|
||||
// calculate remaining fee budget.
|
||||
FeesPaid lnwire.MilliSatoshi
|
||||
|
||||
// Terminate indicates the payment is in its final stage and no more
|
||||
// shards should be launched. This value is true if we have an HTLC
|
||||
// settled or the payment has an error.
|
||||
Terminate bool
|
||||
}
|
||||
|
||||
// MPPayment is a wrapper around a payment's PaymentCreationInfo and
|
||||
// HTLCAttempts. All payments will have the PaymentCreationInfo set, any
|
||||
// HTLCs made in attempts to be completed will populated in the HTLCs slice.
|
||||
@ -175,6 +197,11 @@ type MPPayment struct {
|
||||
|
||||
// Status is the current PaymentStatus of this payment.
|
||||
Status PaymentStatus
|
||||
|
||||
// State is the current state of the payment that holds a number of key
|
||||
// insights and is used to determine what to do on each payment loop
|
||||
// iteration.
|
||||
State *MPPaymentState
|
||||
}
|
||||
|
||||
// Terminated returns a bool to specify whether the payment is in a terminal
|
||||
@ -280,6 +307,55 @@ func (m *MPPayment) Registrable() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setState creates and attaches a new MPPaymentState to the payment. It also
|
||||
// updates the payment's status based on its current state.
|
||||
func (m *MPPayment) setState() error {
|
||||
// Fetch the total amount and fees that has already been sent in
|
||||
// settled and still in-flight shards.
|
||||
sentAmt, fees := m.SentAmt()
|
||||
|
||||
// Sanity check we haven't sent a value larger than the payment amount.
|
||||
totalAmt := m.Info.Value
|
||||
if sentAmt > totalAmt {
|
||||
return fmt.Errorf("%w: sent=%v, total=%v", ErrSentExceedsTotal,
|
||||
sentAmt, totalAmt)
|
||||
}
|
||||
|
||||
// Get any terminal info for this payment.
|
||||
settle, failure := m.TerminalInfo()
|
||||
|
||||
// If either an HTLC settled, or the payment has a payment level
|
||||
// failure recorded, it means we should terminate the moment all shards
|
||||
// have returned with a result.
|
||||
terminate := settle != nil || failure != nil
|
||||
|
||||
// Now determine the payment's status.
|
||||
status, err := decidePaymentStatus(m.HTLCs, m.FailureReason)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the payment state and status.
|
||||
m.State = &MPPaymentState{
|
||||
NumAttemptsInFlight: len(m.InFlightHTLCs()),
|
||||
RemainingAmt: totalAmt - sentAmt,
|
||||
FeesPaid: fees,
|
||||
Terminate: terminate,
|
||||
}
|
||||
m.Status = status
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetState calls the internal method setState. This is a temporary method
|
||||
// to be used by the tests in routing. Once the tests are updated to use mocks,
|
||||
// this method can be removed.
|
||||
//
|
||||
// TODO(yy): delete.
|
||||
func (m *MPPayment) SetState() error {
|
||||
return m.setState()
|
||||
}
|
||||
|
||||
// serializeHTLCSettleInfo serializes the details of a settled htlc.
|
||||
func serializeHTLCSettleInfo(w io.Writer, s *HTLCSettleInfo) error {
|
||||
if _, err := w.Write(s.Preimage[:]); err != nil {
|
||||
|
@ -5,6 +5,9 @@ import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -119,3 +122,153 @@ func TestRegistrable(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPaymentSetState checks that the method setState creates the
|
||||
// MPPaymentState as expected.
|
||||
func TestPaymentSetState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a test preimage and failure reason.
|
||||
preimage := lntypes.Preimage{1}
|
||||
failureReasonError := FailureReasonError
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
payment *MPPayment
|
||||
totalAmt int
|
||||
|
||||
expectedState *MPPaymentState
|
||||
errExpected error
|
||||
}{
|
||||
{
|
||||
// Test that when the sentAmt exceeds totalAmount, the
|
||||
// error is returned.
|
||||
name: "amount exceeded error",
|
||||
// SentAmt returns 90, 10
|
||||
// TerminalInfo returns non-nil, nil
|
||||
// InFlightHTLCs returns 0
|
||||
payment: &MPPayment{
|
||||
HTLCs: []HTLCAttempt{
|
||||
makeSettledAttempt(100, 10, preimage),
|
||||
},
|
||||
},
|
||||
totalAmt: 1,
|
||||
errExpected: ErrSentExceedsTotal,
|
||||
},
|
||||
{
|
||||
// Test that when the htlc is failed, the fee is not
|
||||
// used.
|
||||
name: "fee excluded for failed htlc",
|
||||
payment: &MPPayment{
|
||||
// SentAmt returns 90, 10
|
||||
// TerminalInfo returns nil, nil
|
||||
// InFlightHTLCs returns 1
|
||||
HTLCs: []HTLCAttempt{
|
||||
makeActiveAttempt(100, 10),
|
||||
makeFailedAttempt(100, 10),
|
||||
},
|
||||
},
|
||||
totalAmt: 1000,
|
||||
expectedState: &MPPaymentState{
|
||||
NumAttemptsInFlight: 1,
|
||||
RemainingAmt: 1000 - 90,
|
||||
FeesPaid: 10,
|
||||
Terminate: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Test when the payment is settled, the state should
|
||||
// be marked as terminated.
|
||||
name: "payment settled",
|
||||
// SentAmt returns 90, 10
|
||||
// TerminalInfo returns non-nil, nil
|
||||
// InFlightHTLCs returns 0
|
||||
payment: &MPPayment{
|
||||
HTLCs: []HTLCAttempt{
|
||||
makeSettledAttempt(100, 10, preimage),
|
||||
},
|
||||
},
|
||||
totalAmt: 1000,
|
||||
expectedState: &MPPaymentState{
|
||||
NumAttemptsInFlight: 0,
|
||||
RemainingAmt: 1000 - 90,
|
||||
FeesPaid: 10,
|
||||
Terminate: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Test when the payment is failed, the state should be
|
||||
// marked as terminated.
|
||||
name: "payment failed",
|
||||
// SentAmt returns 0, 0
|
||||
// TerminalInfo returns nil, non-nil
|
||||
// InFlightHTLCs returns 0
|
||||
payment: &MPPayment{
|
||||
FailureReason: &failureReasonError,
|
||||
},
|
||||
totalAmt: 1000,
|
||||
expectedState: &MPPaymentState{
|
||||
NumAttemptsInFlight: 0,
|
||||
RemainingAmt: 1000,
|
||||
FeesPaid: 0,
|
||||
Terminate: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Attach the payment info.
|
||||
info := &PaymentCreationInfo{
|
||||
Value: lnwire.MilliSatoshi(tc.totalAmt),
|
||||
}
|
||||
tc.payment.Info = info
|
||||
|
||||
// Call the method that updates the payment state.
|
||||
err := tc.payment.setState()
|
||||
require.ErrorIs(t, err, tc.errExpected)
|
||||
|
||||
require.Equal(
|
||||
t, tc.expectedState, tc.payment.State,
|
||||
"state not updated as expected",
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func makeActiveAttempt(total, fee int) HTLCAttempt {
|
||||
return HTLCAttempt{
|
||||
HTLCAttemptInfo: makeAttemptInfo(total, total-fee),
|
||||
}
|
||||
}
|
||||
|
||||
func makeSettledAttempt(total, fee int,
|
||||
preimage lntypes.Preimage) HTLCAttempt {
|
||||
|
||||
return HTLCAttempt{
|
||||
HTLCAttemptInfo: makeAttemptInfo(total, total-fee),
|
||||
Settle: &HTLCSettleInfo{Preimage: preimage},
|
||||
}
|
||||
}
|
||||
|
||||
func makeFailedAttempt(total, fee int) HTLCAttempt {
|
||||
return HTLCAttempt{
|
||||
HTLCAttemptInfo: makeAttemptInfo(total, total-fee),
|
||||
Failure: &HTLCFailInfo{
|
||||
Reason: HTLCFailInternal,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func makeAttemptInfo(total, amtForwarded int) HTLCAttemptInfo {
|
||||
hop := &route.Hop{AmtToForward: lnwire.MilliSatoshi(amtForwarded)}
|
||||
return HTLCAttemptInfo{
|
||||
Route: route.Route{
|
||||
TotalAmount: lnwire.MilliSatoshi(total),
|
||||
Hops: []*route.Hop{hop},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -93,6 +93,10 @@ var (
|
||||
// to a payment that already has a failure reason.
|
||||
ErrPaymentPendingFailed = errors.New("payment has failure reason")
|
||||
|
||||
// ErrSentExceedsTotal is returned if the payment's current total sent
|
||||
// amount exceed the total amount.
|
||||
ErrSentExceedsTotal = errors.New("total sent exceeds total amount")
|
||||
|
||||
// errNoAttemptInfo is returned when no attempt info is stored yet.
|
||||
errNoAttemptInfo = errors.New("unable to find attempt info for " +
|
||||
"inflight payment")
|
||||
|
@ -298,19 +298,20 @@ func fetchPayment(bucket kvdb.RBucket) (*MPPayment, error) {
|
||||
failureReason = &reason
|
||||
}
|
||||
|
||||
// Now determine the payment's status.
|
||||
paymentStatus, err := decidePaymentStatus(htlcs, failureReason)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MPPayment{
|
||||
// Create a new payment.
|
||||
payment := &MPPayment{
|
||||
SequenceNum: sequenceNum,
|
||||
Info: creationInfo,
|
||||
HTLCs: htlcs,
|
||||
FailureReason: failureReason,
|
||||
Status: paymentStatus,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Set its state and status.
|
||||
if err := payment.setState(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return payment, nil
|
||||
}
|
||||
|
||||
// fetchHtlcAttempts retrieves all htlc attempts made for the payment found in
|
||||
|
@ -518,6 +518,11 @@ func (m *mockControlTowerOld) fetchPayment(phash lntypes.Hash) (
|
||||
|
||||
// Return a copy of the current attempts.
|
||||
mp.HTLCs = append(mp.HTLCs, p.attempts...)
|
||||
|
||||
if err := mp.SetState(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mp, nil
|
||||
}
|
||||
|
||||
|
@ -31,97 +31,44 @@ type paymentLifecycle struct {
|
||||
currentHeight int32
|
||||
}
|
||||
|
||||
// paymentState holds a number of key insights learned from a given MPPayment
|
||||
// that we use to determine what to do on each payment loop iteration.
|
||||
type paymentState struct {
|
||||
// numAttemptsInFlight specifies the number of HTLCs the payment is
|
||||
// waiting results for.
|
||||
numAttemptsInFlight int
|
||||
|
||||
// remainingAmt specifies how much more money to be sent.
|
||||
remainingAmt lnwire.MilliSatoshi
|
||||
|
||||
// remainingFees specifies the remaining budget that can be used as
|
||||
// fees.
|
||||
remainingFees lnwire.MilliSatoshi
|
||||
|
||||
// terminate indicates the payment is in its final stage and no more
|
||||
// shards should be launched. This value is true if we have an HTLC
|
||||
// settled or the payment has an error.
|
||||
terminate bool
|
||||
}
|
||||
|
||||
// terminated returns a bool to indicate there are no further actions needed
|
||||
// and we should return what we have, either the payment preimage or the
|
||||
// payment error.
|
||||
func (ps paymentState) terminated() bool {
|
||||
func terminated(ps *channeldb.MPPaymentState) bool {
|
||||
// If the payment is in final stage and we have no in flight shards to
|
||||
// wait result for, we consider the whole action terminated.
|
||||
return ps.terminate && ps.numAttemptsInFlight == 0
|
||||
return ps.Terminate && ps.NumAttemptsInFlight == 0
|
||||
}
|
||||
|
||||
// needWaitForShards returns a bool to specify whether we need to wait for the
|
||||
// outcome of the shardHandler.
|
||||
func (ps paymentState) needWaitForShards() bool {
|
||||
func needWaitForShards(ps *channeldb.MPPaymentState) bool {
|
||||
// If we have in flight shards and the payment is in final stage, we
|
||||
// need to wait for the outcomes from the shards. Or if we have no more
|
||||
// money to be sent, we need to wait for the already launched shards.
|
||||
if ps.numAttemptsInFlight == 0 {
|
||||
if ps.NumAttemptsInFlight == 0 {
|
||||
return false
|
||||
}
|
||||
return ps.terminate || ps.remainingAmt == 0
|
||||
return ps.Terminate || ps.RemainingAmt == 0
|
||||
}
|
||||
|
||||
// fetchPaymentState will query the db for the latest payment state information
|
||||
// we need to act on every iteration of the payment loop and update the
|
||||
// paymentState.
|
||||
func (p *paymentLifecycle) fetchPaymentState() (*channeldb.MPPayment,
|
||||
*paymentState, error) {
|
||||
// calcFeeBudget returns the available fee to be used for sending HTLC
|
||||
// attempts.
|
||||
func (p *paymentLifecycle) calcFeeBudget(
|
||||
feesPaid lnwire.MilliSatoshi) lnwire.MilliSatoshi {
|
||||
|
||||
// Fetch the latest payment from db.
|
||||
payment, err := p.router.cfg.Control.FetchPayment(p.identifier)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
budget := p.feeLimit
|
||||
|
||||
// Fetch the total amount and fees that has already been sent in
|
||||
// settled and still in-flight shards.
|
||||
sentAmt, fees := payment.SentAmt()
|
||||
|
||||
// Sanity check we haven't sent a value larger than the payment amount.
|
||||
totalAmt := payment.Info.Value
|
||||
if sentAmt > totalAmt {
|
||||
return nil, nil, fmt.Errorf("amount sent %v exceeds total "+
|
||||
"amount %v", sentAmt, totalAmt)
|
||||
}
|
||||
|
||||
// We'll subtract the used fee from our fee budget, but allow the fees
|
||||
// of the already sent shards to exceed our budget (can happen after
|
||||
// restarts).
|
||||
feeBudget := p.feeLimit
|
||||
if fees <= feeBudget {
|
||||
feeBudget -= fees
|
||||
// We'll subtract the used fee from our fee budget. In case of
|
||||
// overflow, we need to check whether feesPaid exceeds our budget
|
||||
// already.
|
||||
if feesPaid <= budget {
|
||||
budget -= feesPaid
|
||||
} else {
|
||||
feeBudget = 0
|
||||
budget = 0
|
||||
}
|
||||
|
||||
// Get any terminal info for this payment.
|
||||
settle, failure := payment.TerminalInfo()
|
||||
|
||||
// If either an HTLC settled, or the payment has a payment level
|
||||
// failure recorded, it means we should terminate the moment all shards
|
||||
// have returned with a result.
|
||||
terminate := settle != nil || failure != nil
|
||||
|
||||
// Update the payment state.
|
||||
state := &paymentState{
|
||||
numAttemptsInFlight: len(payment.InFlightHTLCs()),
|
||||
remainingAmt: totalAmt - sentAmt,
|
||||
remainingFees: feeBudget,
|
||||
terminate: terminate,
|
||||
}
|
||||
|
||||
return payment, state, nil
|
||||
return budget
|
||||
}
|
||||
|
||||
// resumePayment resumes the paymentLifecycle from the current state.
|
||||
@ -143,7 +90,8 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
|
||||
// If we had any existing attempts outstanding, we'll start by spinning
|
||||
// up goroutines that'll collect their results and deliver them to the
|
||||
// lifecycle loop below.
|
||||
payment, _, err := p.fetchPaymentState()
|
||||
// Fetch the latest payment from db.
|
||||
payment, err := p.router.cfg.Control.FetchPayment(p.identifier)
|
||||
if err != nil {
|
||||
return [32]byte{}, nil, err
|
||||
}
|
||||
@ -172,23 +120,27 @@ lifecycle:
|
||||
// collectResultAsync), it is NOT guaranteed that we always
|
||||
// have the latest state here. This is fine as long as the
|
||||
// state is consistent as a whole.
|
||||
payment, ps, err := p.fetchPaymentState()
|
||||
|
||||
// Fetch the latest payment from db.
|
||||
payment, err := p.router.cfg.Control.FetchPayment(p.identifier)
|
||||
if err != nil {
|
||||
return [32]byte{}, nil, err
|
||||
}
|
||||
|
||||
ps := payment.State
|
||||
remainingFees := p.calcFeeBudget(ps.FeesPaid)
|
||||
|
||||
log.Debugf("Payment %v in state terminate=%v, "+
|
||||
"active_shards=%v, rem_value=%v, fee_limit=%v",
|
||||
p.identifier, ps.terminate, ps.numAttemptsInFlight,
|
||||
ps.remainingAmt, ps.remainingFees,
|
||||
)
|
||||
p.identifier, ps.Terminate, ps.NumAttemptsInFlight,
|
||||
ps.RemainingAmt, remainingFees)
|
||||
|
||||
// TODO(yy): sanity check all the states to make sure
|
||||
// everything is expected.
|
||||
switch {
|
||||
// We have a terminal condition and no active shards, we are
|
||||
// ready to exit.
|
||||
case ps.terminated():
|
||||
case payment.Terminated():
|
||||
// Find the first successful shard and return
|
||||
// the preimage and route.
|
||||
for _, a := range payment.HTLCs {
|
||||
@ -215,7 +167,7 @@ lifecycle:
|
||||
// If we either reached a terminal error condition (but had
|
||||
// active shards still) or there is no remaining value to send,
|
||||
// we'll wait for a shard outcome.
|
||||
case ps.needWaitForShards():
|
||||
case needWaitForShards(ps):
|
||||
// We still have outstanding shards, so wait for a new
|
||||
// outcome to be available before re-evaluating our
|
||||
// state.
|
||||
@ -257,8 +209,8 @@ lifecycle:
|
||||
|
||||
// Create a new payment attempt from the given payment session.
|
||||
rt, err := p.paySession.RequestRoute(
|
||||
ps.remainingAmt, ps.remainingFees,
|
||||
uint32(ps.numAttemptsInFlight),
|
||||
ps.RemainingAmt, remainingFees,
|
||||
uint32(ps.NumAttemptsInFlight),
|
||||
uint32(p.currentHeight),
|
||||
)
|
||||
if err != nil {
|
||||
@ -273,7 +225,7 @@ lifecycle:
|
||||
// There is no route to try, and we have no active
|
||||
// shards. This means that there is no way for us to
|
||||
// send the payment, so mark it failed with no route.
|
||||
if ps.numAttemptsInFlight == 0 {
|
||||
if ps.NumAttemptsInFlight == 0 {
|
||||
failureCode := routeErr.FailureReason()
|
||||
log.Debugf("Marking payment %v permanently "+
|
||||
"failed with no route: %v",
|
||||
@ -301,7 +253,7 @@ lifecycle:
|
||||
|
||||
// If this route will consume the last remaining amount to send
|
||||
// to the receiver, this will be our last shard (for now).
|
||||
lastShard := rt.ReceiverAmt() == ps.remainingAmt
|
||||
lastShard := rt.ReceiverAmt() == ps.RemainingAmt
|
||||
|
||||
// We found a route to try, launch a new shard.
|
||||
attempt, outcome, err := shardHandler.launchShard(rt, lastShard)
|
||||
|
@ -791,314 +791,6 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase,
|
||||
}
|
||||
}
|
||||
|
||||
// TestPaymentState tests that the logics implemented on paymentState struct
|
||||
// are as expected. In particular, that the method terminated and
|
||||
// needWaitForShards return the right values.
|
||||
func TestPaymentState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
// Use the following three params, each is equivalent to a bool
|
||||
// statement, to construct 8 test cases so that we can
|
||||
// exhaustively catch all possible states.
|
||||
numAttemptsInFlight int
|
||||
remainingAmt lnwire.MilliSatoshi
|
||||
terminate bool
|
||||
|
||||
expectedTerminated bool
|
||||
expectedNeedWaitForShards bool
|
||||
}{
|
||||
{
|
||||
// If we have active shards and terminate is marked
|
||||
// false, the state is not terminated. Since the
|
||||
// remaining amount is zero, we need to wait for shards
|
||||
// to be finished and launch no more shards.
|
||||
name: "state 100",
|
||||
numAttemptsInFlight: 1,
|
||||
remainingAmt: lnwire.MilliSatoshi(0),
|
||||
terminate: false,
|
||||
expectedTerminated: false,
|
||||
expectedNeedWaitForShards: true,
|
||||
},
|
||||
{
|
||||
// If we have active shards while terminate is marked
|
||||
// true, the state is not terminated, and we need to
|
||||
// wait for shards to be finished and launch no more
|
||||
// shards.
|
||||
name: "state 101",
|
||||
numAttemptsInFlight: 1,
|
||||
remainingAmt: lnwire.MilliSatoshi(0),
|
||||
terminate: true,
|
||||
expectedTerminated: false,
|
||||
expectedNeedWaitForShards: true,
|
||||
},
|
||||
|
||||
{
|
||||
// If we have active shards and terminate is marked
|
||||
// false, the state is not terminated. Since the
|
||||
// remaining amount is not zero, we don't need to wait
|
||||
// for shards outcomes and should launch more shards.
|
||||
name: "state 110",
|
||||
numAttemptsInFlight: 1,
|
||||
remainingAmt: lnwire.MilliSatoshi(1),
|
||||
terminate: false,
|
||||
expectedTerminated: false,
|
||||
expectedNeedWaitForShards: false,
|
||||
},
|
||||
{
|
||||
// If we have active shards and terminate is marked
|
||||
// true, the state is not terminated. Even the
|
||||
// remaining amount is not zero, we need to wait for
|
||||
// shards outcomes because state is terminated.
|
||||
name: "state 111",
|
||||
numAttemptsInFlight: 1,
|
||||
remainingAmt: lnwire.MilliSatoshi(1),
|
||||
terminate: true,
|
||||
expectedTerminated: false,
|
||||
expectedNeedWaitForShards: true,
|
||||
},
|
||||
{
|
||||
// If we have no active shards while terminate is marked
|
||||
// false, the state is not terminated, and we don't
|
||||
// need to wait for more shard outcomes because there
|
||||
// are no active shards.
|
||||
name: "state 000",
|
||||
numAttemptsInFlight: 0,
|
||||
remainingAmt: lnwire.MilliSatoshi(0),
|
||||
terminate: false,
|
||||
expectedTerminated: false,
|
||||
expectedNeedWaitForShards: false,
|
||||
},
|
||||
{
|
||||
// If we have no active shards while terminate is marked
|
||||
// true, the state is terminated, and we don't need to
|
||||
// wait for shards to be finished.
|
||||
name: "state 001",
|
||||
numAttemptsInFlight: 0,
|
||||
remainingAmt: lnwire.MilliSatoshi(0),
|
||||
terminate: true,
|
||||
expectedTerminated: true,
|
||||
expectedNeedWaitForShards: false,
|
||||
},
|
||||
{
|
||||
// If we have no active shards while terminate is marked
|
||||
// false, the state is not terminated. Since the
|
||||
// remaining amount is not zero, we don't need to wait
|
||||
// for shards outcomes and should launch more shards.
|
||||
name: "state 010",
|
||||
numAttemptsInFlight: 0,
|
||||
remainingAmt: lnwire.MilliSatoshi(1),
|
||||
terminate: false,
|
||||
expectedTerminated: false,
|
||||
expectedNeedWaitForShards: false,
|
||||
},
|
||||
{
|
||||
// If we have no active shards while terminate is marked
|
||||
// true, the state is terminated, and we don't need to
|
||||
// wait for shards outcomes.
|
||||
name: "state 011",
|
||||
numAttemptsInFlight: 0,
|
||||
remainingAmt: lnwire.MilliSatoshi(1),
|
||||
terminate: true,
|
||||
expectedTerminated: true,
|
||||
expectedNeedWaitForShards: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := &paymentState{
|
||||
numAttemptsInFlight: tc.numAttemptsInFlight,
|
||||
remainingAmt: tc.remainingAmt,
|
||||
terminate: tc.terminate,
|
||||
}
|
||||
|
||||
require.Equal(
|
||||
t, tc.expectedTerminated, ps.terminated(),
|
||||
"terminated returned wrong value",
|
||||
)
|
||||
require.Equal(
|
||||
t, tc.expectedNeedWaitForShards,
|
||||
ps.needWaitForShards(),
|
||||
"needWaitForShards returned wrong value",
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdatePaymentState checks that the method updatePaymentState updates the
|
||||
// paymentState as expected.
|
||||
func TestUpdatePaymentState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// paymentHash is the identifier on paymentLifecycle.
|
||||
paymentHash := lntypes.Hash{}
|
||||
preimage := lntypes.Preimage{}
|
||||
failureReasonError := channeldb.FailureReasonError
|
||||
|
||||
// TODO(yy): make MPPayment into an interface so we can mock it. The
|
||||
// current design implicitly tests the methods SendAmt, TerminalInfo,
|
||||
// and InFlightHTLCs on channeldb.MPPayment, which is not good. Once
|
||||
// MPPayment becomes an interface, we can then mock these methods here.
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
payment *channeldb.MPPayment
|
||||
totalAmt int
|
||||
feeLimit int
|
||||
|
||||
expectedState *paymentState
|
||||
shouldReturnError bool
|
||||
}{
|
||||
{
|
||||
// Test that the error returned from FetchPayment is
|
||||
// handled properly. We use a nil payment to indicate
|
||||
// we want to return an error.
|
||||
name: "fetch payment error",
|
||||
payment: nil,
|
||||
shouldReturnError: true,
|
||||
},
|
||||
{
|
||||
// Test that when the sentAmt exceeds totalAmount, the
|
||||
// error is returned.
|
||||
name: "amount exceeded error",
|
||||
// SentAmt returns 90, 10
|
||||
// TerminalInfo returns non-nil, nil
|
||||
// InFlightHTLCs returns 0
|
||||
payment: &channeldb.MPPayment{
|
||||
HTLCs: []channeldb.HTLCAttempt{
|
||||
makeSettledAttempt(100, 10, preimage),
|
||||
},
|
||||
},
|
||||
totalAmt: 1,
|
||||
shouldReturnError: true,
|
||||
},
|
||||
{
|
||||
// Test that when the fee budget is reached, the
|
||||
// remaining fee should be zero.
|
||||
name: "fee budget reached",
|
||||
payment: &channeldb.MPPayment{
|
||||
// SentAmt returns 90, 10
|
||||
// TerminalInfo returns nil, nil
|
||||
// InFlightHTLCs returns 1
|
||||
HTLCs: []channeldb.HTLCAttempt{
|
||||
makeActiveAttempt(100, 10),
|
||||
makeFailedAttempt(100, 10),
|
||||
},
|
||||
},
|
||||
totalAmt: 1000,
|
||||
feeLimit: 1,
|
||||
expectedState: &paymentState{
|
||||
numAttemptsInFlight: 1,
|
||||
remainingAmt: 1000 - 90,
|
||||
remainingFees: 0,
|
||||
terminate: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Test when the payment is settled, the state should
|
||||
// be marked as terminated.
|
||||
name: "payment settled",
|
||||
// SentAmt returns 90, 10
|
||||
// TerminalInfo returns non-nil, nil
|
||||
// InFlightHTLCs returns 0
|
||||
payment: &channeldb.MPPayment{
|
||||
HTLCs: []channeldb.HTLCAttempt{
|
||||
makeSettledAttempt(100, 10, preimage),
|
||||
},
|
||||
},
|
||||
totalAmt: 1000,
|
||||
feeLimit: 100,
|
||||
expectedState: &paymentState{
|
||||
numAttemptsInFlight: 0,
|
||||
remainingAmt: 1000 - 90,
|
||||
remainingFees: 100 - 10,
|
||||
terminate: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Test when the payment is failed, the state should be
|
||||
// marked as terminated.
|
||||
name: "payment failed",
|
||||
// SentAmt returns 0, 0
|
||||
// TerminalInfo returns nil, non-nil
|
||||
// InFlightHTLCs returns 0
|
||||
payment: &channeldb.MPPayment{
|
||||
FailureReason: &failureReasonError,
|
||||
},
|
||||
totalAmt: 1000,
|
||||
feeLimit: 100,
|
||||
expectedState: &paymentState{
|
||||
numAttemptsInFlight: 0,
|
||||
remainingAmt: 1000,
|
||||
remainingFees: 100,
|
||||
terminate: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create mock control tower and assign it to router.
|
||||
// We will then use the router and the paymentHash
|
||||
// above to create our paymentLifecycle for this test.
|
||||
ct := &mockControlTower{}
|
||||
rt := &ChannelRouter{cfg: &Config{Control: ct}}
|
||||
pl := &paymentLifecycle{
|
||||
router: rt,
|
||||
identifier: paymentHash,
|
||||
feeLimit: lnwire.MilliSatoshi(tc.feeLimit),
|
||||
}
|
||||
|
||||
if tc.payment == nil {
|
||||
// A nil payment indicates we want to test an
|
||||
// error returned from FetchPayment.
|
||||
dummyErr := errors.New("dummy")
|
||||
ct.On("FetchPayment", paymentHash).Return(
|
||||
nil, dummyErr,
|
||||
)
|
||||
} else {
|
||||
// Attach the payment info.
|
||||
info := &channeldb.PaymentCreationInfo{
|
||||
Value: lnwire.MilliSatoshi(tc.totalAmt),
|
||||
}
|
||||
tc.payment.Info = info
|
||||
|
||||
// Otherwise we will return the payment.
|
||||
ct.On("FetchPayment", paymentHash).Return(
|
||||
tc.payment, nil,
|
||||
)
|
||||
}
|
||||
|
||||
// Call the method that updates the payment state.
|
||||
_, state, err := pl.fetchPaymentState()
|
||||
|
||||
// Assert that the mock method is called as
|
||||
// intended.
|
||||
ct.AssertExpectations(t)
|
||||
|
||||
if tc.shouldReturnError {
|
||||
require.Error(t, err, "expect an error")
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "unexpected error")
|
||||
require.Equal(
|
||||
t, tc.expectedState, state,
|
||||
"state not updated as expected",
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func makeActiveAttempt(total, fee int) channeldb.HTLCAttempt {
|
||||
return channeldb.HTLCAttempt{
|
||||
HTLCAttemptInfo: makeAttemptInfo(total, total-fee),
|
||||
@ -1114,15 +806,6 @@ func makeSettledAttempt(total, fee int,
|
||||
}
|
||||
}
|
||||
|
||||
func makeFailedAttempt(total, fee int) channeldb.HTLCAttempt {
|
||||
return channeldb.HTLCAttempt{
|
||||
HTLCAttemptInfo: makeAttemptInfo(total, total-fee),
|
||||
Failure: &channeldb.HTLCFailInfo{
|
||||
Reason: channeldb.HTLCFailInternal,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func makeAttemptInfo(total, amtForwarded int) channeldb.HTLCAttemptInfo {
|
||||
hop := &route.Hop{AmtToForward: lnwire.MilliSatoshi(amtForwarded)}
|
||||
return channeldb.HTLCAttemptInfo{
|
||||
|
Loading…
x
Reference in New Issue
Block a user