mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-05-15 22:30:33 +02:00
channeldb: return updated payment on attempt update
Similar to what is done for SettleAttempt. Co-authored-by: Johan T. Halseth <johanth@gmail.com>
This commit is contained in:
parent
351d8e174c
commit
278915e598
@ -186,38 +186,39 @@ func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash,
|
|||||||
// RegisterAttempt atomically records the provided HTLCAttemptInfo to the
|
// RegisterAttempt atomically records the provided HTLCAttemptInfo to the
|
||||||
// DB.
|
// DB.
|
||||||
func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash,
|
func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash,
|
||||||
attempt *HTLCAttemptInfo) error {
|
attempt *HTLCAttemptInfo) (*MPPayment, error) {
|
||||||
|
|
||||||
// Serialize the information before opening the db transaction.
|
// Serialize the information before opening the db transaction.
|
||||||
var a bytes.Buffer
|
var a bytes.Buffer
|
||||||
err := serializeHTLCAttemptInfo(&a, attempt)
|
err := serializeHTLCAttemptInfo(&a, attempt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
htlcInfoBytes := a.Bytes()
|
htlcInfoBytes := a.Bytes()
|
||||||
|
|
||||||
htlcIDBytes := make([]byte, 8)
|
htlcIDBytes := make([]byte, 8)
|
||||||
binary.BigEndian.PutUint64(htlcIDBytes, attempt.AttemptID)
|
binary.BigEndian.PutUint64(htlcIDBytes, attempt.AttemptID)
|
||||||
|
|
||||||
return kvdb.Batch(p.db.Backend, func(tx kvdb.RwTx) error {
|
var payment *MPPayment
|
||||||
|
err = kvdb.Batch(p.db.Backend, func(tx kvdb.RwTx) error {
|
||||||
bucket, err := fetchPaymentBucketUpdate(tx, paymentHash)
|
bucket, err := fetchPaymentBucketUpdate(tx, paymentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
payment, err := fetchPayment(bucket)
|
p, err := fetchPayment(bucket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the payment is in-flight.
|
// Ensure the payment is in-flight.
|
||||||
if err := ensureInFlight(payment); err != nil {
|
if err := ensureInFlight(p); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// We cannot register a new attempt if the payment already has
|
// We cannot register a new attempt if the payment already has
|
||||||
// reached a terminal condition:
|
// reached a terminal condition:
|
||||||
settle, fail := payment.TerminalInfo()
|
settle, fail := p.TerminalInfo()
|
||||||
if settle != nil || fail != nil {
|
if settle != nil || fail != nil {
|
||||||
return ErrPaymentTerminal
|
return ErrPaymentTerminal
|
||||||
}
|
}
|
||||||
@ -225,7 +226,7 @@ func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash,
|
|||||||
// Make sure any existing shards match the new one with regards
|
// Make sure any existing shards match the new one with regards
|
||||||
// to MPP options.
|
// to MPP options.
|
||||||
mpp := attempt.Route.FinalHop().MPP
|
mpp := attempt.Route.FinalHop().MPP
|
||||||
for _, h := range payment.InFlightHTLCs() {
|
for _, h := range p.InFlightHTLCs() {
|
||||||
hMpp := h.Route.FinalHop().MPP
|
hMpp := h.Route.FinalHop().MPP
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
@ -258,13 +259,13 @@ func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash,
|
|||||||
// If this is a non-MPP attempt, it must match the total amount
|
// If this is a non-MPP attempt, it must match the total amount
|
||||||
// exactly.
|
// exactly.
|
||||||
amt := attempt.Route.ReceiverAmt()
|
amt := attempt.Route.ReceiverAmt()
|
||||||
if mpp == nil && amt != payment.Info.Value {
|
if mpp == nil && amt != p.Info.Value {
|
||||||
return ErrValueMismatch
|
return ErrValueMismatch
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure we aren't sending more than the total payment amount.
|
// Ensure we aren't sending more than the total payment amount.
|
||||||
sentAmt, _ := payment.SentAmt()
|
sentAmt, _ := p.SentAmt()
|
||||||
if sentAmt+amt > payment.Info.Value {
|
if sentAmt+amt > p.Info.Value {
|
||||||
return ErrValueExceedsAmt
|
return ErrValueExceedsAmt
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -282,8 +283,20 @@ func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return htlcBucket.Put(htlcAttemptInfoKey, htlcInfoBytes)
|
err = htlcBucket.Put(htlcAttemptInfoKey, htlcInfoBytes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve attempt info for the notification.
|
||||||
|
payment, err = fetchPayment(bucket)
|
||||||
|
return err
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return payment, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SettleAttempt marks the given attempt settled with the preimage. If this is
|
// SettleAttempt marks the given attempt settled with the preimage. If this is
|
||||||
@ -307,16 +320,15 @@ func (p *PaymentControl) SettleAttempt(hash lntypes.Hash,
|
|||||||
|
|
||||||
// FailAttempt marks the given payment attempt failed.
|
// FailAttempt marks the given payment attempt failed.
|
||||||
func (p *PaymentControl) FailAttempt(hash lntypes.Hash,
|
func (p *PaymentControl) FailAttempt(hash lntypes.Hash,
|
||||||
attemptID uint64, failInfo *HTLCFailInfo) error {
|
attemptID uint64, failInfo *HTLCFailInfo) (*MPPayment, error) {
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := serializeHTLCFailInfo(&b, failInfo); err != nil {
|
if err := serializeHTLCFailInfo(&b, failInfo); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
failBytes := b.Bytes()
|
failBytes := b.Bytes()
|
||||||
|
|
||||||
_, err := p.updateHtlcKey(hash, attemptID, htlcFailInfoKey, failBytes)
|
return p.updateHtlcKey(hash, attemptID, htlcFailInfoKey, failBytes)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateHtlcKey updates a database key for the specified htlc.
|
// updateHtlcKey updates a database key for the specified htlc.
|
||||||
|
@ -117,13 +117,13 @@ func TestPaymentControlSwitchFail(t *testing.T) {
|
|||||||
// Record a new attempt. In this test scenario, the attempt fails.
|
// Record a new attempt. In this test scenario, the attempt fails.
|
||||||
// However, this is not communicated to control tower in the current
|
// However, this is not communicated to control tower in the current
|
||||||
// implementation. It only registers the initiation of the attempt.
|
// implementation. It only registers the initiation of the attempt.
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
_, err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to register attempt: %v", err)
|
t.Fatalf("unable to register attempt: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
htlcReason := HTLCFailUnreadable
|
htlcReason := HTLCFailUnreadable
|
||||||
err = pControl.FailAttempt(
|
_, err = pControl.FailAttempt(
|
||||||
info.PaymentHash, attempt.AttemptID,
|
info.PaymentHash, attempt.AttemptID,
|
||||||
&HTLCFailInfo{
|
&HTLCFailInfo{
|
||||||
Reason: htlcReason,
|
Reason: htlcReason,
|
||||||
@ -143,7 +143,7 @@ func TestPaymentControlSwitchFail(t *testing.T) {
|
|||||||
|
|
||||||
// Record another attempt.
|
// Record another attempt.
|
||||||
attempt.AttemptID = 1
|
attempt.AttemptID = 1
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
_, err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to send htlc message: %v", err)
|
t.Fatalf("unable to send htlc message: %v", err)
|
||||||
}
|
}
|
||||||
@ -236,7 +236,7 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Record an attempt.
|
// Record an attempt.
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
_, err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to send htlc message: %v", err)
|
t.Fatalf("unable to send htlc message: %v", err)
|
||||||
}
|
}
|
||||||
@ -375,7 +375,7 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to send htlc message: %v", err)
|
t.Fatalf("unable to send htlc message: %v", err)
|
||||||
}
|
}
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
_, err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to send htlc message: %v", err)
|
t.Fatalf("unable to send htlc message: %v", err)
|
||||||
}
|
}
|
||||||
@ -387,7 +387,7 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) {
|
|||||||
if p.failed {
|
if p.failed {
|
||||||
// Fail the payment attempt.
|
// Fail the payment attempt.
|
||||||
htlcFailure := HTLCFailUnreadable
|
htlcFailure := HTLCFailUnreadable
|
||||||
err := pControl.FailAttempt(
|
_, err := pControl.FailAttempt(
|
||||||
info.PaymentHash, attempt.AttemptID,
|
info.PaymentHash, attempt.AttemptID,
|
||||||
&HTLCFailInfo{
|
&HTLCFailInfo{
|
||||||
Reason: htlcFailure,
|
Reason: htlcFailure,
|
||||||
@ -520,7 +520,7 @@ func TestPaymentControlMultiShard(t *testing.T) {
|
|||||||
a.AttemptID = i
|
a.AttemptID = i
|
||||||
attempts = append(attempts, &a)
|
attempts = append(attempts, &a)
|
||||||
|
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, &a)
|
_, err = pControl.RegisterAttempt(info.PaymentHash, &a)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to send htlc message: %v", err)
|
t.Fatalf("unable to send htlc message: %v", err)
|
||||||
}
|
}
|
||||||
@ -541,7 +541,7 @@ func TestPaymentControlMultiShard(t *testing.T) {
|
|||||||
// will be too large.
|
// will be too large.
|
||||||
b := *attempt
|
b := *attempt
|
||||||
b.AttemptID = 3
|
b.AttemptID = 3
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
_, err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||||
if err != ErrValueExceedsAmt {
|
if err != ErrValueExceedsAmt {
|
||||||
t.Fatalf("expected ErrValueExceedsAmt, got: %v",
|
t.Fatalf("expected ErrValueExceedsAmt, got: %v",
|
||||||
err)
|
err)
|
||||||
@ -550,7 +550,7 @@ func TestPaymentControlMultiShard(t *testing.T) {
|
|||||||
// Fail the second attempt.
|
// Fail the second attempt.
|
||||||
a := attempts[1]
|
a := attempts[1]
|
||||||
htlcFail := HTLCFailUnreadable
|
htlcFail := HTLCFailUnreadable
|
||||||
err = pControl.FailAttempt(
|
_, err = pControl.FailAttempt(
|
||||||
info.PaymentHash, a.AttemptID,
|
info.PaymentHash, a.AttemptID,
|
||||||
&HTLCFailInfo{
|
&HTLCFailInfo{
|
||||||
Reason: htlcFail,
|
Reason: htlcFail,
|
||||||
@ -596,7 +596,7 @@ func TestPaymentControlMultiShard(t *testing.T) {
|
|||||||
t, pControl, info.PaymentHash, info, nil, htlc,
|
t, pControl, info.PaymentHash, info, nil, htlc,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
err := pControl.FailAttempt(
|
_, err := pControl.FailAttempt(
|
||||||
info.PaymentHash, a.AttemptID,
|
info.PaymentHash, a.AttemptID,
|
||||||
&HTLCFailInfo{
|
&HTLCFailInfo{
|
||||||
Reason: htlcFail,
|
Reason: htlcFail,
|
||||||
@ -634,7 +634,7 @@ func TestPaymentControlMultiShard(t *testing.T) {
|
|||||||
// that the payment has reached a terminal condition.
|
// that the payment has reached a terminal condition.
|
||||||
b = *attempt
|
b = *attempt
|
||||||
b.AttemptID = 3
|
b.AttemptID = 3
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
_, err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||||
if err != ErrPaymentTerminal {
|
if err != ErrPaymentTerminal {
|
||||||
t.Fatalf("expected ErrPaymentTerminal, got: %v", err)
|
t.Fatalf("expected ErrPaymentTerminal, got: %v", err)
|
||||||
}
|
}
|
||||||
@ -666,7 +666,7 @@ func TestPaymentControlMultiShard(t *testing.T) {
|
|||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
// Fail the attempt.
|
// Fail the attempt.
|
||||||
err := pControl.FailAttempt(
|
_, err := pControl.FailAttempt(
|
||||||
info.PaymentHash, a.AttemptID,
|
info.PaymentHash, a.AttemptID,
|
||||||
&HTLCFailInfo{
|
&HTLCFailInfo{
|
||||||
Reason: htlcFail,
|
Reason: htlcFail,
|
||||||
@ -708,7 +708,7 @@ func TestPaymentControlMultiShard(t *testing.T) {
|
|||||||
assertPaymentStatus(t, pControl, info.PaymentHash, finalStatus)
|
assertPaymentStatus(t, pControl, info.PaymentHash, finalStatus)
|
||||||
|
|
||||||
// Finally assert we cannot register more attempts.
|
// Finally assert we cannot register more attempts.
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
_, err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||||
if err != expRegErr {
|
if err != expRegErr {
|
||||||
t.Fatalf("expected error %v, got: %v", expRegErr, err)
|
t.Fatalf("expected error %v, got: %v", expRegErr, err)
|
||||||
}
|
}
|
||||||
@ -756,7 +756,7 @@ func TestPaymentControlMPPRecordValidation(t *testing.T) {
|
|||||||
info.Value, [32]byte{1},
|
info.Value, [32]byte{1},
|
||||||
)
|
)
|
||||||
|
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
_, err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to send htlc message: %v", err)
|
t.Fatalf("unable to send htlc message: %v", err)
|
||||||
}
|
}
|
||||||
@ -765,7 +765,7 @@ func TestPaymentControlMPPRecordValidation(t *testing.T) {
|
|||||||
b := *attempt
|
b := *attempt
|
||||||
b.AttemptID = 1
|
b.AttemptID = 1
|
||||||
b.Route.FinalHop().MPP = nil
|
b.Route.FinalHop().MPP = nil
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
_, err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||||
if err != ErrMPPayment {
|
if err != ErrMPPayment {
|
||||||
t.Fatalf("expected ErrMPPayment, got: %v", err)
|
t.Fatalf("expected ErrMPPayment, got: %v", err)
|
||||||
}
|
}
|
||||||
@ -774,7 +774,7 @@ func TestPaymentControlMPPRecordValidation(t *testing.T) {
|
|||||||
b.Route.FinalHop().MPP = record.NewMPP(
|
b.Route.FinalHop().MPP = record.NewMPP(
|
||||||
info.Value, [32]byte{2},
|
info.Value, [32]byte{2},
|
||||||
)
|
)
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
_, err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||||
if err != ErrMPPPaymentAddrMismatch {
|
if err != ErrMPPPaymentAddrMismatch {
|
||||||
t.Fatalf("expected ErrMPPPaymentAddrMismatch, got: %v", err)
|
t.Fatalf("expected ErrMPPPaymentAddrMismatch, got: %v", err)
|
||||||
}
|
}
|
||||||
@ -783,7 +783,7 @@ func TestPaymentControlMPPRecordValidation(t *testing.T) {
|
|||||||
b.Route.FinalHop().MPP = record.NewMPP(
|
b.Route.FinalHop().MPP = record.NewMPP(
|
||||||
info.Value/2, [32]byte{1},
|
info.Value/2, [32]byte{1},
|
||||||
)
|
)
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
_, err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||||
if err != ErrMPPTotalAmountMismatch {
|
if err != ErrMPPTotalAmountMismatch {
|
||||||
t.Fatalf("expected ErrMPPTotalAmountMismatch, got: %v", err)
|
t.Fatalf("expected ErrMPPTotalAmountMismatch, got: %v", err)
|
||||||
}
|
}
|
||||||
@ -801,7 +801,7 @@ func TestPaymentControlMPPRecordValidation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
attempt.Route.FinalHop().MPP = nil
|
attempt.Route.FinalHop().MPP = nil
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
_, err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to send htlc message: %v", err)
|
t.Fatalf("unable to send htlc message: %v", err)
|
||||||
}
|
}
|
||||||
@ -813,7 +813,7 @@ func TestPaymentControlMPPRecordValidation(t *testing.T) {
|
|||||||
info.Value, [32]byte{1},
|
info.Value, [32]byte{1},
|
||||||
)
|
)
|
||||||
|
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
_, err = pControl.RegisterAttempt(info.PaymentHash, &b)
|
||||||
if err != ErrNonMPPayment {
|
if err != ErrNonMPPayment {
|
||||||
t.Fatalf("expected ErrNonMPPayment, got: %v", err)
|
t.Fatalf("expected ErrNonMPPayment, got: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -107,7 +107,8 @@ func (p *controlTower) InitPayment(paymentHash lntypes.Hash,
|
|||||||
func (p *controlTower) RegisterAttempt(paymentHash lntypes.Hash,
|
func (p *controlTower) RegisterAttempt(paymentHash lntypes.Hash,
|
||||||
attempt *channeldb.HTLCAttemptInfo) error {
|
attempt *channeldb.HTLCAttemptInfo) error {
|
||||||
|
|
||||||
return p.db.RegisterAttempt(paymentHash, attempt)
|
_, err := p.db.RegisterAttempt(paymentHash, attempt)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SettleAttempt marks the given attempt settled with the preimage. If
|
// SettleAttempt marks the given attempt settled with the preimage. If
|
||||||
@ -133,7 +134,8 @@ func (p *controlTower) SettleAttempt(paymentHash lntypes.Hash,
|
|||||||
func (p *controlTower) FailAttempt(paymentHash lntypes.Hash,
|
func (p *controlTower) FailAttempt(paymentHash lntypes.Hash,
|
||||||
attemptID uint64, failInfo *channeldb.HTLCFailInfo) error {
|
attemptID uint64, failInfo *channeldb.HTLCFailInfo) error {
|
||||||
|
|
||||||
return p.db.FailAttempt(paymentHash, attemptID, failInfo)
|
_, err := p.db.FailAttempt(paymentHash, attemptID, failInfo)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchPayment fetches the payment corresponding to the given payment hash.
|
// FetchPayment fetches the payment corresponding to the given payment hash.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user