mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-06-29 10:09:08 +02:00
routing: refactor update payment state tests
This commit refactors the resumePayment to extract some logics back to paymentState so that the code is more testable. It also adds unit tests for paymentState, and breaks the original MPPayment tests into independent tests so that it's easier to maintain and debug. All the new tests are built using mock so that the control flow is eaiser to setup and change.
This commit is contained in:
@ -38,21 +38,53 @@ type paymentState struct {
|
||||
numShardsInFlight int
|
||||
remainingAmt lnwire.MilliSatoshi
|
||||
remainingFees lnwire.MilliSatoshi
|
||||
terminate bool
|
||||
|
||||
// terminate indicates the payment is in its final stage and no more
|
||||
// shards should be launched. This value is true if we have an HTLC
|
||||
// settled or the payment has an error.
|
||||
terminate bool
|
||||
}
|
||||
|
||||
// paymentState uses the passed payment to find the latest information we need
|
||||
// to act on every iteration of the payment loop.
|
||||
func (p *paymentLifecycle) paymentState(payment *channeldb.MPPayment) (
|
||||
// terminated returns a bool to indicate there are no further actions needed
|
||||
// and we should return what we have, either the payment preimage or the
|
||||
// payment error.
|
||||
func (ps paymentState) terminated() bool {
|
||||
// If the payment is in final stage and we have no in flight shards to
|
||||
// wait result for, we consider the whole action terminated.
|
||||
return ps.terminate && ps.numShardsInFlight == 0
|
||||
}
|
||||
|
||||
// needWaitForShards returns a bool to specify whether we need to wait for the
|
||||
// outcome of the shanrdHandler.
|
||||
func (ps paymentState) needWaitForShards() bool {
|
||||
// If we have in flight shards and the payment is in final stage, we
|
||||
// need to wait for the outcomes from the shards. Or if we have no more
|
||||
// money to be sent, we need to wait for the already launched shards.
|
||||
if ps.numShardsInFlight == 0 {
|
||||
return false
|
||||
}
|
||||
return ps.terminate || ps.remainingAmt == 0
|
||||
}
|
||||
|
||||
// updatePaymentState will fetch db for the payment to find the latest
|
||||
// information we need to act on every iteration of the payment loop and update
|
||||
// the paymentState.
|
||||
func (p *paymentLifecycle) updatePaymentState() (*channeldb.MPPayment,
|
||||
*paymentState, error) {
|
||||
|
||||
// Fetch the latest payment from db.
|
||||
payment, err := p.router.cfg.Control.FetchPayment(p.identifier)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Fetch the total amount and fees that has already been sent in
|
||||
// settled and still in-flight shards.
|
||||
sentAmt, fees := payment.SentAmt()
|
||||
|
||||
// Sanity check we haven't sent a value larger than the payment amount.
|
||||
if sentAmt > p.totalAmount {
|
||||
return nil, fmt.Errorf("amount sent %v exceeds "+
|
||||
return nil, nil, fmt.Errorf("amount sent %v exceeds "+
|
||||
"total amount %v", sentAmt, p.totalAmount)
|
||||
}
|
||||
|
||||
@ -74,13 +106,15 @@ func (p *paymentLifecycle) paymentState(payment *channeldb.MPPayment) (
|
||||
// have returned with a result.
|
||||
terminate := settle != nil || failure != nil
|
||||
|
||||
activeShards := payment.InFlightHTLCs()
|
||||
return &paymentState{
|
||||
numShardsInFlight: len(activeShards),
|
||||
// Update the payment state.
|
||||
state := &paymentState{
|
||||
numShardsInFlight: len(payment.InFlightHTLCs()),
|
||||
remainingAmt: p.totalAmount - sentAmt,
|
||||
remainingFees: feeBudget,
|
||||
terminate: terminate,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return payment, state, nil
|
||||
}
|
||||
|
||||
// resumePayment resumes the paymentLifecycle from the current state.
|
||||
@ -102,9 +136,7 @@ func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) {
|
||||
// If we had any existing attempts outstanding, we'll start by spinning
|
||||
// up goroutines that'll collect their results and deliver them to the
|
||||
// lifecycle loop below.
|
||||
payment, err := p.router.cfg.Control.FetchPayment(
|
||||
p.identifier,
|
||||
)
|
||||
payment, _, err := p.updatePaymentState()
|
||||
if err != nil {
|
||||
return [32]byte{}, nil, err
|
||||
}
|
||||
@ -128,34 +160,30 @@ lifecycle:
|
||||
return [32]byte{}, nil, err
|
||||
}
|
||||
|
||||
// We start every iteration by fetching the lastest state of
|
||||
// the payment from the ControlTower. This ensures that we will
|
||||
// act on the latest available information, whether we are
|
||||
// resuming an existing payment or just sent a new attempt.
|
||||
payment, err := p.router.cfg.Control.FetchPayment(
|
||||
p.identifier,
|
||||
)
|
||||
if err != nil {
|
||||
return [32]byte{}, nil, err
|
||||
}
|
||||
|
||||
// Using this latest state of the payment, calculate
|
||||
// information about our active shards and terminal conditions.
|
||||
state, err := p.paymentState(payment)
|
||||
// We update the payment state on every iteration. Since the
|
||||
// payment state is affected by multiple goroutines (ie,
|
||||
// collectResultAsync), it is NOT guaranteed that we always
|
||||
// have the latest state here. This is fine as long as the
|
||||
// state is consistent as a whole.
|
||||
payment, currentState, err := p.updatePaymentState()
|
||||
if err != nil {
|
||||
return [32]byte{}, nil, err
|
||||
}
|
||||
|
||||
log.Debugf("Payment %v in state terminate=%v, "+
|
||||
"active_shards=%v, rem_value=%v, fee_limit=%v",
|
||||
p.identifier, state.terminate, state.numShardsInFlight,
|
||||
state.remainingAmt, state.remainingFees)
|
||||
p.identifier, currentState.terminate,
|
||||
currentState.numShardsInFlight,
|
||||
currentState.remainingAmt, currentState.remainingFees,
|
||||
)
|
||||
|
||||
// TODO(yy): sanity check all the states to make sure
|
||||
// everything is expected.
|
||||
switch {
|
||||
|
||||
// We have a terminal condition and no active shards, we are
|
||||
// ready to exit.
|
||||
case state.terminate && state.numShardsInFlight == 0:
|
||||
case currentState.terminated():
|
||||
// Find the first successful shard and return
|
||||
// the preimage and route.
|
||||
for _, a := range payment.HTLCs {
|
||||
@ -170,7 +198,7 @@ lifecycle:
|
||||
// If we either reached a terminal error condition (but had
|
||||
// active shards still) or there is no remaining value to send,
|
||||
// we'll wait for a shard outcome.
|
||||
case state.terminate || state.remainingAmt == 0:
|
||||
case currentState.needWaitForShards():
|
||||
// We still have outstanding shards, so wait for a new
|
||||
// outcome to be available before re-evaluating our
|
||||
// state.
|
||||
@ -212,8 +240,9 @@ lifecycle:
|
||||
|
||||
// Create a new payment attempt from the given payment session.
|
||||
rt, err := p.paySession.RequestRoute(
|
||||
state.remainingAmt, state.remainingFees,
|
||||
uint32(state.numShardsInFlight), uint32(p.currentHeight),
|
||||
currentState.remainingAmt, currentState.remainingFees,
|
||||
uint32(currentState.numShardsInFlight),
|
||||
uint32(p.currentHeight),
|
||||
)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to find route for payment %v: %v",
|
||||
@ -227,7 +256,7 @@ lifecycle:
|
||||
// There is no route to try, and we have no active
|
||||
// shards. This means that there is no way for us to
|
||||
// send the payment, so mark it failed with no route.
|
||||
if state.numShardsInFlight == 0 {
|
||||
if currentState.numShardsInFlight == 0 {
|
||||
failureCode := routeErr.FailureReason()
|
||||
log.Debugf("Marking payment %v permanently "+
|
||||
"failed with no route: %v",
|
||||
@ -253,22 +282,11 @@ lifecycle:
|
||||
|
||||
// If this route will consume the last remeining amount to send
|
||||
// to the receiver, this will be our last shard (for now).
|
||||
lastShard := rt.ReceiverAmt() == state.remainingAmt
|
||||
lastShard := rt.ReceiverAmt() == currentState.remainingAmt
|
||||
|
||||
// We found a route to try, launch a new shard.
|
||||
attempt, outcome, err := shardHandler.launchShard(rt, lastShard)
|
||||
switch {
|
||||
// We may get a terminal error if we've processed a shard with
|
||||
// a terminal state (settled or permanent failure), while we
|
||||
// were pathfinding. We know we're in a terminal state here,
|
||||
// so we can continue and wait for our last shards to return.
|
||||
case err == channeldb.ErrPaymentTerminal:
|
||||
log.Infof("Payment %v in terminal state, abandoning "+
|
||||
"shard", p.identifier)
|
||||
|
||||
continue lifecycle
|
||||
|
||||
case err != nil:
|
||||
if err != nil {
|
||||
return [32]byte{}, nil, err
|
||||
}
|
||||
|
||||
@ -297,6 +315,7 @@ lifecycle:
|
||||
// Now that the shard was successfully sent, launch a go
|
||||
// routine that will handle its result when its back.
|
||||
shardHandler.collectResultAsync(attempt)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@ -437,12 +456,30 @@ type shardResult struct {
|
||||
}
|
||||
|
||||
// collectResultAsync launches a goroutine that will wait for the result of the
|
||||
// given HTLC attempt to be available then handle its result. Note that it will
|
||||
// fail the payment with the control tower if a terminal error is encountered.
|
||||
// given HTLC attempt to be available then handle its result. It will fail the
|
||||
// payment with the control tower if a terminal error is encountered.
|
||||
func (p *shardHandler) collectResultAsync(attempt *channeldb.HTLCAttemptInfo) {
|
||||
|
||||
// errToSend is the error to be sent to sh.shardErrors.
|
||||
var errToSend error
|
||||
|
||||
// handleResultErr is a function closure must be called using defer. It
|
||||
// finishes collecting result by updating the payment state and send
|
||||
// the error (or nil) to sh.shardErrors.
|
||||
handleResultErr := func() {
|
||||
// Send the error or quit.
|
||||
select {
|
||||
case p.shardErrors <- errToSend:
|
||||
case <-p.router.quit:
|
||||
case <-p.quit:
|
||||
}
|
||||
|
||||
p.wg.Done()
|
||||
}
|
||||
|
||||
p.wg.Add(1)
|
||||
go func() {
|
||||
defer p.wg.Done()
|
||||
defer handleResultErr()
|
||||
|
||||
// Block until the result is available.
|
||||
result, err := p.collectResult(attempt)
|
||||
@ -456,32 +493,18 @@ func (p *shardHandler) collectResultAsync(attempt *channeldb.HTLCAttemptInfo) {
|
||||
attempt.AttemptID, p.identifier, err)
|
||||
}
|
||||
|
||||
select {
|
||||
case p.shardErrors <- err:
|
||||
case <-p.router.quit:
|
||||
case <-p.quit:
|
||||
}
|
||||
// Overwrite errToSend and return.
|
||||
errToSend = err
|
||||
return
|
||||
}
|
||||
|
||||
// If a non-critical error was encountered handle it and mark
|
||||
// the payment failed if the failure was terminal.
|
||||
if result.err != nil {
|
||||
err := p.handleSendError(attempt, result.err)
|
||||
if err != nil {
|
||||
select {
|
||||
case p.shardErrors <- err:
|
||||
case <-p.router.quit:
|
||||
case <-p.quit:
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case p.shardErrors <- nil:
|
||||
case <-p.router.quit:
|
||||
case <-p.quit:
|
||||
// Overwrite errToSend and return. Notice that the
|
||||
// errToSend could be nil here.
|
||||
errToSend = p.handleSendError(attempt, result.err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
Reference in New Issue
Block a user