mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-28 02:06:57 +02:00
routing+lnrpc: subscribe payment stream before sending it
This commit moves the subscription of a given payment before it's been sent so we won't miss any events.
This commit is contained in:
@@ -318,27 +318,46 @@ func (s *Server) SendPaymentV2(req *SendPaymentRequest,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.cfg.Router.SendPaymentAsync(payment)
|
// Get the payment hash.
|
||||||
|
payHash := payment.Identifier()
|
||||||
|
|
||||||
|
// Init the payment in db.
|
||||||
|
paySession, shardTracker, err := s.cfg.Router.PreparePayment(payment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Transform user errors to grpc code.
|
|
||||||
if err == channeldb.ErrPaymentInFlight ||
|
|
||||||
err == channeldb.ErrAlreadyPaid {
|
|
||||||
|
|
||||||
log.Debugf("SendPayment async result for payment %x: %v",
|
|
||||||
payment.Identifier(), err)
|
|
||||||
|
|
||||||
return status.Error(
|
|
||||||
codes.AlreadyExists, err.Error(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Errorf("SendPayment async error for payment %x: %v",
|
|
||||||
payment.Identifier(), err)
|
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.trackPayment(payment.Identifier(), stream, req.NoInflightUpdates)
|
// Subscribe to the payment before sending it to make sure we won't
|
||||||
|
// miss events.
|
||||||
|
sub, err := s.subscribePayment(payHash)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the payment.
|
||||||
|
err = s.cfg.Router.SendPaymentAsync(payment, paySession, shardTracker)
|
||||||
|
if err == nil {
|
||||||
|
// If the payment was sent successfully, we can start tracking
|
||||||
|
// the events.
|
||||||
|
return s.trackPayment(
|
||||||
|
sub, payHash, stream, req.NoInflightUpdates,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, transform user errors to grpc code.
|
||||||
|
if errors.Is(err, channeldb.ErrPaymentInFlight) ||
|
||||||
|
errors.Is(err, channeldb.ErrAlreadyPaid) {
|
||||||
|
|
||||||
|
log.Debugf("SendPayment async result for payment %x: %v",
|
||||||
|
payment.Identifier(), err)
|
||||||
|
|
||||||
|
return status.Error(codes.AlreadyExists, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Errorf("SendPayment async error for payment %x: %v",
|
||||||
|
payment.Identifier(), err)
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// EstimateRouteFee allows callers to obtain a lower bound w.r.t how much it
|
// EstimateRouteFee allows callers to obtain a lower bound w.r.t how much it
|
||||||
@@ -800,34 +819,48 @@ func getMsatPairValue(msatValue lnwire.MilliSatoshi,
|
|||||||
func (s *Server) TrackPaymentV2(request *TrackPaymentRequest,
|
func (s *Server) TrackPaymentV2(request *TrackPaymentRequest,
|
||||||
stream Router_TrackPaymentV2Server) error {
|
stream Router_TrackPaymentV2Server) error {
|
||||||
|
|
||||||
paymentHash, err := lntypes.MakeHash(request.PaymentHash)
|
payHash, err := lntypes.MakeHash(request.PaymentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("TrackPayment called for payment %v", paymentHash)
|
log.Debugf("TrackPayment called for payment %v", payHash)
|
||||||
|
|
||||||
return s.trackPayment(paymentHash, stream, request.NoInflightUpdates)
|
// Make the subscription.
|
||||||
}
|
sub, err := s.subscribePayment(payHash)
|
||||||
|
if err != nil {
|
||||||
// trackPayment writes payment status updates to the provided stream.
|
|
||||||
func (s *Server) trackPayment(identifier lntypes.Hash,
|
|
||||||
stream Router_TrackPaymentV2Server, noInflightUpdates bool) error {
|
|
||||||
|
|
||||||
router := s.cfg.RouterBackend
|
|
||||||
|
|
||||||
// Subscribe to the outcome of this payment.
|
|
||||||
subscription, err := router.Tower.SubscribePayment(identifier)
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case err == channeldb.ErrPaymentNotInitiated:
|
|
||||||
return status.Error(codes.NotFound, err.Error())
|
|
||||||
case err != nil:
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return s.trackPayment(sub, payHash, stream, request.NoInflightUpdates)
|
||||||
|
}
|
||||||
|
|
||||||
|
// subscribePayment subscribes to the payment updates for the given payment
|
||||||
|
// hash.
|
||||||
|
func (s *Server) subscribePayment(identifier lntypes.Hash) (
|
||||||
|
routing.ControlTowerSubscriber, error) {
|
||||||
|
|
||||||
|
// Make the subscription.
|
||||||
|
router := s.cfg.RouterBackend
|
||||||
|
sub, err := router.Tower.SubscribePayment(identifier)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case err == channeldb.ErrPaymentNotInitiated:
|
||||||
|
return nil, status.Error(codes.NotFound, err.Error())
|
||||||
|
case err != nil:
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return sub, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// trackPayment writes payment status updates to the provided stream.
|
||||||
|
func (s *Server) trackPayment(subscription routing.ControlTowerSubscriber,
|
||||||
|
identifier lntypes.Hash, stream Router_TrackPaymentV2Server,
|
||||||
|
noInflightUpdates bool) error {
|
||||||
|
|
||||||
// Stream updates to the client.
|
// Stream updates to the client.
|
||||||
err = s.trackPaymentStream(
|
err := s.trackPaymentStream(
|
||||||
stream.Context(), subscription, noInflightUpdates, stream.Send,
|
stream.Context(), subscription, noInflightUpdates, stream.Send,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -2044,7 +2044,7 @@ func (l *LightningPayment) Identifier() [32]byte {
|
|||||||
func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
|
func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
|
||||||
*route.Route, error) {
|
*route.Route, error) {
|
||||||
|
|
||||||
paySession, shardTracker, err := r.preparePayment(payment)
|
paySession, shardTracker, err := r.PreparePayment(payment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return [32]byte{}, nil, err
|
return [32]byte{}, nil, err
|
||||||
}
|
}
|
||||||
@@ -2062,11 +2062,8 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte,
|
|||||||
|
|
||||||
// SendPaymentAsync is the non-blocking version of SendPayment. The payment
|
// SendPaymentAsync is the non-blocking version of SendPayment. The payment
|
||||||
// result needs to be retrieved via the control tower.
|
// result needs to be retrieved via the control tower.
|
||||||
func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error {
|
func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment,
|
||||||
paySession, shardTracker, err := r.preparePayment(payment)
|
ps PaymentSession, st shards.ShardTracker) error {
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Since this is the first time this payment is being made, we pass nil
|
// Since this is the first time this payment is being made, we pass nil
|
||||||
// for the existing attempt.
|
// for the existing attempt.
|
||||||
@@ -2079,7 +2076,7 @@ func (r *ChannelRouter) SendPaymentAsync(payment *LightningPayment) error {
|
|||||||
|
|
||||||
_, _, err := r.sendPayment(
|
_, _, err := r.sendPayment(
|
||||||
payment.FeeLimit, payment.Identifier(),
|
payment.FeeLimit, payment.Identifier(),
|
||||||
payment.PayAttemptTimeout, paySession, shardTracker,
|
payment.PayAttemptTimeout, ps, st,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Payment %x failed: %v",
|
log.Errorf("Payment %x failed: %v",
|
||||||
@@ -2111,9 +2108,9 @@ func spewPayment(payment *LightningPayment) logClosure {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// preparePayment creates the payment session and registers the payment with the
|
// PreparePayment creates the payment session and registers the payment with the
|
||||||
// control tower.
|
// control tower.
|
||||||
func (r *ChannelRouter) preparePayment(payment *LightningPayment) (
|
func (r *ChannelRouter) PreparePayment(payment *LightningPayment) (
|
||||||
PaymentSession, shards.ShardTracker, error) {
|
PaymentSession, shards.ShardTracker, error) {
|
||||||
|
|
||||||
// Before starting the HTLC routing attempt, we'll create a fresh
|
// Before starting the HTLC routing attempt, we'll create a fresh
|
||||||
|
Reference in New Issue
Block a user