routing/payment_lifecycle: use ShardTracker to track shards

We'll let the payment's lifecycle register each shard it's sending with
the ShardTracker, canceling failed shards. This will be the foundation
for correct AMP derivation for each shard we'll send.
This commit is contained in:
Johan T. Halseth
2021-04-12 15:21:59 +02:00
parent 6474b253d6
commit 41ae3530a3
4 changed files with 137 additions and 46 deletions

View File

@@ -29,6 +29,7 @@ import (
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/chainview"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/routing/shards"
"github.com/lightningnetwork/lnd/ticker"
"github.com/lightningnetwork/lnd/zpay32"
)
@@ -603,19 +604,40 @@ func (r *ChannelRouter) Start() error {
go func(payment *channeldb.MPPayment) {
defer r.wg.Done()
// Get the hashes used for the outstanding HTLCs.
htlcs := make(map[uint64]lntypes.Hash)
for _, a := range payment.HTLCs {
a := a
hash := payment.Info.PaymentHash
htlcs[a.AttemptID] = hash
}
// Since we are not supporting creating more shards
// after a restart (only receiving the result of the
// shards already outstanding), we create a simple
// shard tracker that will map the attempt IDs to
// hashes used for the HTLCs. This will be enough also
// for AMP payments, since we only need the hashes for
// the individual HTLCs to regenerate the circuits, and
// we don't currently persist the root share necessary
// to re-derive them.
shardTracker := shards.NewSimpleShardTracker(
payment.Info.PaymentHash, htlcs,
)
// We create a dummy, empty payment session such that
// we won't make another payment attempt when the
// result for the in-flight attempt is received.
paySession := r.cfg.SessionSource.NewPaymentSessionEmpty()
// We pass in a zero timeout value, to indicate we
// don't need it to timeout. It will stop immediately
// after the existing attempt has finished anyway. We
// also set a zero fee limit, as no more routes should
// be tried.
_, _, err := r.sendPayment(
payment.Info.Value, 0,
payment.Info.PaymentHash, 0, paySession,
payment.Info.Value, 0, payment.Info.PaymentHash,
0, paySession, shardTracker,
)
if err != nil {
log.Errorf("Resuming payment with hash %v "+
@@ -1770,7 +1792,7 @@ type LightningPayment struct {
func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
*route.Route, error) {
paySession, err := r.preparePayment(payment)
paySession, shardTracker, err := r.preparePayment(payment)
if err != nil {
return [32]byte{}, nil, err
}
@@ -1782,14 +1804,14 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
// for the existing attempt.
return r.sendPayment(
payment.Amount, payment.FeeLimit, payment.PaymentHash,
payment.PayAttemptTimeout, paySession,
payment.PayAttemptTimeout, paySession, shardTracker,
)
}
// SendPaymentAsync is the non-blocking version of SendPayment. The payment
// result needs to be retrieved via the control tower.
func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error {
paySession, err := r.preparePayment(payment)
paySession, shardTracker, err := r.preparePayment(payment)
if err != nil {
return err
}
@@ -1805,7 +1827,7 @@ func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error {
_, _, err := r.sendPayment(
payment.Amount, payment.FeeLimit, payment.PaymentHash,
payment.PayAttemptTimeout, paySession,
payment.PayAttemptTimeout, paySession, shardTracker,
)
if err != nil {
log.Errorf("Payment with hash %x failed: %v",
@@ -1841,14 +1863,14 @@ func spewPayment(payment *LightningPayment) logClosure {
// preparePayment creates the payment session and registers the payment with the
// control tower.
func (r *ChannelRouter) preparePayment(payment *LightningPayment) (
PaymentSession, error) {
PaymentSession, shards.ShardTracker, error) {
// Before starting the HTLC routing attempt, we'll create a fresh
// payment session which will report our errors back to mission
// control.
paySession, err := r.cfg.SessionSource.NewPaymentSession(payment)
if err != nil {
return nil, err
return nil, nil, err
}
// Record this payment hash with the ControlTower, ensuring it is not
@@ -1862,12 +1884,18 @@ func (r *ChannelRouter) preparePayment(payment *LightningPayment) (
PaymentRequest: payment.PaymentRequest,
}
// Create a new ShardTracker that we'll use during the life cycle of
// this payment.
shardTracker := shards.NewSimpleShardTracker(
payment.PaymentHash, nil,
)
err = r.cfg.Control.InitPayment(payment.PaymentHash, info)
if err != nil {
return nil, err
return nil, nil, err
}
return paySession, nil
return paySession, shardTracker, nil
}
// SendToRoute attempts to send a payment with the given hash through the
@@ -1915,14 +1943,22 @@ func (r *ChannelRouter) SendToRoute(hash lntypes.Hash, rt *route.Route) (
}),
)
// Since the HTLC hashes and preimages are specified manually over the
// RPC for SendToRoute requests, we don't have to worry about creating
// a ShardTracker that can generate hashes for AMP payments. Instead we
// create a simple tracker that can just return the hash for the single
// shard we'll now launch.
shardTracker := shards.NewSimpleShardTracker(hash, nil)
// Launch a shard along the given route.
sh := &shardHandler{
router: r,
paymentHash: hash,
router: r,
paymentHash: hash,
shardTracker: shardTracker,
}
var shardError error
attempt, outcome, err := sh.launchShard(rt)
attempt, outcome, err := sh.launchShard(rt, false)
// With SendToRoute, it can happen that the route exceeds protocol
// constraints. Mark the payment as failed with an internal error.
@@ -2007,8 +2043,8 @@ func (r *ChannelRouter) SendToRoute(hash lntypes.Hash, rt *route.Route) (
// the ControlTower.
func (r *ChannelRouter) sendPayment(
totalAmt, feeLimit lnwire.MilliSatoshi, paymentHash lntypes.Hash,
timeout time.Duration,
paySession PaymentSession) ([32]byte, *route.Route, error) {
timeout time.Duration, paySession PaymentSession,
shardTracker shards.ShardTracker) ([32]byte, *route.Route, error) {
// We'll also fetch the current block height so we can properly
// calculate the required HTLC time locks within the route.
@@ -2025,6 +2061,7 @@ func (r *ChannelRouter) sendPayment(
feeLimit: feeLimit,
paymentHash: paymentHash,
paySession: paySession,
shardTracker: shardTracker,
currentHeight: currentHeight,
}