diff --git a/routing/router.go b/routing/router.go index 4e8e0846b..e5c55e609 100644 --- a/routing/router.go +++ b/routing/router.go @@ -336,88 +336,8 @@ func (r *ChannelRouter) Start() error { // If any payments are still in flight, we resume, to make sure their // results are properly handled. - payments, err := r.cfg.Control.FetchInFlightPayments() - if err != nil { - return err - } - - // Before we restart existing payments and start accepting more - // payments to be made, we clean the network result store of the - // Switch. We do this here at startup to ensure no more payments can be - // made concurrently, so we know the toKeep map will be up-to-date - // until the cleaning has finished. - toKeep := make(map[uint64]struct{}) - for _, p := range payments { - for _, a := range p.HTLCs { - toKeep[a.AttemptID] = struct{}{} - } - } - - log.Debugf("Cleaning network result store.") - if err := r.cfg.Payer.CleanStore(toKeep); err != nil { - return err - } - - for _, payment := range payments { - log.Infof("Resuming payment %v", payment.Info.PaymentIdentifier) - r.wg.Add(1) - go func(payment *channeldb.MPPayment) { - defer r.wg.Done() - - // Get the hashes used for the outstanding HTLCs. - htlcs := make(map[uint64]lntypes.Hash) - for _, a := range payment.HTLCs { - a := a - - // We check whether the individual attempts - // have their HTLC hash set, if not we'll fall - // back to the overall payment hash. - hash := payment.Info.PaymentIdentifier - if a.Hash != nil { - hash = *a.Hash - } - - htlcs[a.AttemptID] = hash - } - - // Since we are not supporting creating more shards - // after a restart (only receiving the result of the - // shards already outstanding), we create a simple - // shard tracker that will map the attempt IDs to - // hashes used for the HTLCs. This will be enough also - // for AMP payments, since we only need the hashes for - // the individual HTLCs to regenerate the circuits, and - // we don't currently persist the root share necessary - // to re-derive them. - shardTracker := shards.NewSimpleShardTracker( - payment.Info.PaymentIdentifier, htlcs, - ) - - // We create a dummy, empty payment session such that - // we won't make another payment attempt when the - // result for the in-flight attempt is received. - paySession := r.cfg.SessionSource.NewPaymentSessionEmpty() - - // We pass in a non-timeout context, to indicate we - // don't need it to timeout. It will stop immediately - // after the existing attempt has finished anyway. We - // also set a zero fee limit, as no more routes should - // be tried. - noTimeout := time.Duration(0) - _, _, err := r.sendPayment( - context.Background(), 0, - payment.Info.PaymentIdentifier, noTimeout, - paySession, shardTracker, - ) - if err != nil { - log.Errorf("Resuming payment %v failed: %v.", - payment.Info.PaymentIdentifier, err) - return - } - - log.Infof("Resumed payment %v completed.", - payment.Info.PaymentIdentifier) - }(payment) + if err := r.resumePayments(); err != nil { + log.Error("Failed to resume payments during startup") } return nil @@ -1451,6 +1371,98 @@ func (r *ChannelRouter) BuildRoute(amt fn.Option[lnwire.MilliSatoshi], ) } +// resumePayments fetches inflight payments and resumes their payment +// lifecycles. +func (r *ChannelRouter) resumePayments() error { + // Get all payments that are inflight. + payments, err := r.cfg.Control.FetchInFlightPayments() + if err != nil { + return err + } + + // Before we restart existing payments and start accepting more + // payments to be made, we clean the network result store of the + // Switch. We do this here at startup to ensure no more payments can be + // made concurrently, so we know the toKeep map will be up-to-date + // until the cleaning has finished. + toKeep := make(map[uint64]struct{}) + for _, p := range payments { + for _, a := range p.HTLCs { + toKeep[a.AttemptID] = struct{}{} + } + } + + log.Debugf("Cleaning network result store.") + if err := r.cfg.Payer.CleanStore(toKeep); err != nil { + return err + } + + // launchPayment is a helper closure that handles resuming the payment. + launchPayment := func(payment *channeldb.MPPayment) { + defer r.wg.Done() + + // Get the hashes used for the outstanding HTLCs. + htlcs := make(map[uint64]lntypes.Hash) + for _, a := range payment.HTLCs { + a := a + + // We check whether the individual attempts have their + // HTLC hash set, if not we'll fall back to the overall + // payment hash. + hash := payment.Info.PaymentIdentifier + if a.Hash != nil { + hash = *a.Hash + } + + htlcs[a.AttemptID] = hash + } + + payHash := payment.Info.PaymentIdentifier + + // Since we are not supporting creating more shards after a + // restart (only receiving the result of the shards already + // outstanding), we create a simple shard tracker that will map + // the attempt IDs to hashes used for the HTLCs. This will be + // enough also for AMP payments, since we only need the hashes + // for the individual HTLCs to regenerate the circuits, and we + // don't currently persist the root share necessary to + // re-derive them. + shardTracker := shards.NewSimpleShardTracker(payHash, htlcs) + + // We create a dummy, empty payment session such that we won't + // make another payment attempt when the result for the + // in-flight attempt is received. + paySession := r.cfg.SessionSource.NewPaymentSessionEmpty() + + // We pass in a non-timeout context, to indicate we don't need + // it to timeout. It will stop immediately after the existing + // attempt has finished anyway. We also set a zero fee limit, + // as no more routes should be tried. + noTimeout := time.Duration(0) + _, _, err := r.sendPayment( + context.Background(), 0, payHash, noTimeout, paySession, + shardTracker, + ) + if err != nil { + log.Errorf("Resuming payment %v failed: %v", payHash, + err) + + return + } + + log.Infof("Resumed payment %v completed", payHash) + } + + for _, payment := range payments { + log.Infof("Resuming payment %v", payment.Info.PaymentIdentifier) + + r.wg.Add(1) + go launchPayment(payment) + } + + return nil +} + // getEdgeUnifiers returns a list of edge unifiers for the given route. func getEdgeUnifiers(source route.Vertex, hops []route.Vertex, outgoingChans map[uint64]struct{},