diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index 291006310..5272ec79a 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -1,7 +1,6 @@ package channeldb import ( - "bytes" "crypto/rand" "fmt" "io" @@ -11,7 +10,6 @@ import ( "time" "github.com/btcsuite/fastsha256" - "github.com/coreos/bbolt" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/lntypes" ) @@ -85,9 +83,9 @@ func TestPaymentControlSwitchFail(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( - t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + t, pControl, info.PaymentHash, info, nil, lntypes.Preimage{}, nil, ) @@ -99,9 +97,9 @@ func TestPaymentControlSwitchFail(t *testing.T) { } // Verify the status is indeed Failed. - assertPaymentStatus(t, db, info.PaymentHash, StatusFailed) + assertPaymentStatus(t, pControl, info.PaymentHash, StatusFailed) assertPaymentInfo( - t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + t, pControl, info.PaymentHash, info, nil, lntypes.Preimage{}, &failReason, ) @@ -112,9 +110,9 @@ func TestPaymentControlSwitchFail(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( - t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + t, pControl, info.PaymentHash, info, nil, lntypes.Preimage{}, nil, ) @@ -124,9 +122,9 @@ func TestPaymentControlSwitchFail(t *testing.T) { if err != nil { t.Fatalf("unable to send htlc message: %v", err) } - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( - t, db, info.PaymentHash, info, attempt, lntypes.Preimage{}, + t, pControl, info.PaymentHash, info, attempt, lntypes.Preimage{}, nil, ) @@ -149,8 +147,8 @@ func TestPaymentControlSwitchFail(t *testing.T) { spew.Sdump(payment.HTLCs[0].Route), err) } - assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) - assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil) + assertPaymentStatus(t, pControl, info.PaymentHash, StatusSucceeded) + assertPaymentInfo(t, pControl, info.PaymentHash, info, attempt, preimg, nil) // Attempt a final payment, which should now fail since the prior // payment succeed. @@ -184,9 +182,9 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) { t.Fatalf("unable to send htlc message: %v", err) } - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( - t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + t, pControl, info.PaymentHash, info, nil, lntypes.Preimage{}, nil, ) @@ -204,9 +202,9 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) { if err != nil { t.Fatalf("unable to send htlc message: %v", err) } - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( - t, db, info.PaymentHash, info, attempt, lntypes.Preimage{}, + t, pControl, info.PaymentHash, info, attempt, lntypes.Preimage{}, nil, ) @@ -221,8 +219,8 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) { if _, err := pControl.Success(info.PaymentHash, preimg); err != nil { t.Fatalf("error shouldn't have been received, got: %v", err) } - assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) - assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil) + assertPaymentStatus(t, pControl, info.PaymentHash, StatusSucceeded) + assertPaymentInfo(t, pControl, info.PaymentHash, info, attempt, preimg, nil) err = pControl.InitPayment(info.PaymentHash, info) if err != ErrAlreadyPaid { @@ -253,11 +251,7 @@ func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) { t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) } - assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown) - assertPaymentInfo( - t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}, - nil, - ) + assertPaymentStatus(t, pControl, info.PaymentHash, StatusUnknown) } // TestPaymentControlFailsWithoutInFlight checks that a strict payment @@ -283,10 +277,7 @@ func TestPaymentControlFailsWithoutInFlight(t *testing.T) { t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) } - assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown) - assertPaymentInfo( - t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}, nil, - ) + assertPaymentStatus(t, pControl, info.PaymentHash, StatusUnknown) } // TestPaymentControlDeleteNonInFlight checks that calling DeletaPayments only @@ -344,9 +335,9 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { } // Verify the status is indeed Failed. - assertPaymentStatus(t, db, info.PaymentHash, StatusFailed) + assertPaymentStatus(t, pControl, info.PaymentHash, StatusFailed) assertPaymentInfo( - t, db, info.PaymentHash, info, attempt, + t, pControl, info.PaymentHash, info, attempt, lntypes.Preimage{}, &failReason, ) } else if p.success { @@ -356,14 +347,14 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { t.Fatalf("error shouldn't have been received, got: %v", err) } - assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) + assertPaymentStatus(t, pControl, info.PaymentHash, StatusSucceeded) assertPaymentInfo( - t, db, info.PaymentHash, info, attempt, preimg, nil, + t, pControl, info.PaymentHash, info, attempt, preimg, nil, ) } else { - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight) assertPaymentInfo( - t, db, info.PaymentHash, info, attempt, + t, pControl, info.PaymentHash, info, attempt, lntypes.Preimage{}, nil, ) } @@ -390,166 +381,76 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { } } -func assertPaymentStatus(t *testing.T, db *DB, - hash [32]byte, expStatus PaymentStatus) { +// assertPaymentStatus retrieves the status of the payment referred to by hash +// and compares it with the expected state. +func assertPaymentStatus(t *testing.T, p *PaymentControl, + hash lntypes.Hash, expStatus PaymentStatus) { t.Helper() - var paymentStatus = StatusUnknown - err := db.View(func(tx *bbolt.Tx) error { - payments := tx.Bucket(paymentsRootBucket) - if payments == nil { - return nil - } - - bucket := payments.Bucket(hash[:]) - if bucket == nil { - return nil - } - - // Get the existing status of this payment, if any. - paymentStatus = fetchPaymentStatus(bucket) - return nil - }) + payment, err := p.FetchPayment(hash) + if expStatus == StatusUnknown && err == ErrPaymentNotInitiated { + return + } if err != nil { - t.Fatalf("unable to fetch payment status: %v", err) + t.Fatal(err) } - if paymentStatus != expStatus { + if payment.Status != expStatus { t.Fatalf("payment status mismatch: expected %v, got %v", - expStatus, paymentStatus) + expStatus, payment.Status) } } -func checkPaymentCreationInfo(bucket *bbolt.Bucket, c *PaymentCreationInfo) error { - b := bucket.Get(paymentCreationInfoKey) - switch { - case b == nil && c == nil: - return nil - case b == nil: - return fmt.Errorf("expected creation info not found") - case c == nil: - return fmt.Errorf("unexpected creation info found") - } - - r := bytes.NewReader(b) - c2, err := deserializePaymentCreationInfo(r) - if err != nil { - return err - } - if !reflect.DeepEqual(c, c2) { - return fmt.Errorf("PaymentCreationInfos don't match: %v vs %v", - spew.Sdump(c), spew.Sdump(c2)) - } - - return nil -} - -func checkHTLCAttemptInfo(bucket *bbolt.Bucket, a *HTLCAttemptInfo) error { - b := bucket.Get(paymentAttemptInfoKey) - switch { - case b == nil && a == nil: - return nil - case b == nil: - return fmt.Errorf("expected attempt info not found") - case a == nil: - return fmt.Errorf("unexpected attempt info found") - } - - r := bytes.NewReader(b) - a2, err := deserializeHTLCAttemptInfo(r) - if err != nil { - return err - } - - return assertRouteEqual(&a.Route, &a2.Route) -} - -func checkSettleInfo(bucket *bbolt.Bucket, preimg lntypes.Preimage) error { - zero := lntypes.Preimage{} - b := bucket.Get(paymentSettleInfoKey) - switch { - case b == nil && preimg == zero: - return nil - case b == nil: - return fmt.Errorf("expected preimage not found") - case preimg == zero: - return fmt.Errorf("unexpected preimage found") - } - - var pre2 lntypes.Preimage - copy(pre2[:], b[:]) - if preimg != pre2 { - return fmt.Errorf("Preimages don't match: %x vs %x", - preimg, pre2) - } - - return nil -} - -func checkFailInfo(bucket *bbolt.Bucket, failReason *FailureReason) error { - b := bucket.Get(paymentFailInfoKey) - switch { - case b == nil && failReason == nil: - return nil - case b == nil: - return fmt.Errorf("expected fail info not found") - case failReason == nil: - return fmt.Errorf("unexpected fail info found") - } - - failReason2 := FailureReason(b[0]) - if *failReason != failReason2 { - return fmt.Errorf("Failure infos don't match: %v vs %v", - *failReason, failReason2) - } - - return nil -} - -func assertPaymentInfo(t *testing.T, db *DB, hash lntypes.Hash, +// assertPaymentInfo retrieves the payment referred to by hash and verifies the +// expected values. +func assertPaymentInfo(t *testing.T, p *PaymentControl, hash lntypes.Hash, c *PaymentCreationInfo, a *HTLCAttemptInfo, s lntypes.Preimage, f *FailureReason) { t.Helper() - err := db.View(func(tx *bbolt.Tx) error { - payments := tx.Bucket(paymentsRootBucket) - if payments == nil && c == nil { - return nil - } - if payments == nil { - return fmt.Errorf("sent payments not found") - } - - bucket := payments.Bucket(hash[:]) - if bucket == nil && c == nil { - return nil - } - - if bucket == nil { - return fmt.Errorf("payment not found") - } - - if err := checkPaymentCreationInfo(bucket, c); err != nil { - return err - } - - if err := checkHTLCAttemptInfo(bucket, a); err != nil { - return err - } - - if err := checkSettleInfo(bucket, s); err != nil { - return err - } - - if err := checkFailInfo(bucket, f); err != nil { - return err - } - return nil - }) + payment, err := p.FetchPayment(hash) if err != nil { - t.Fatalf("assert payment info failed: %v", err) + t.Fatal(err) } + if !reflect.DeepEqual(payment.Info, c) { + t.Fatalf("PaymentCreationInfos don't match: %v vs %v", + spew.Sdump(payment.Info), spew.Sdump(c)) + } + + if f != nil { + if *payment.FailureReason != *f { + t.Fatal("unexpected failure reason") + } + } else { + if payment.FailureReason != nil { + t.Fatal("unexpected failure reason") + } + } + + if a == nil { + if len(payment.HTLCs) > 0 { + t.Fatal("expected no htlcs") + } + return + } + + htlc := payment.HTLCs[0] + if err := assertRouteEqual(&htlc.Route, &a.Route); err != nil { + t.Fatal("routes do not match") + } + + var zeroPreimage = lntypes.Preimage{} + if s != zeroPreimage { + if htlc.Settle.Preimage != s { + t.Fatalf("Preimages don't match: %x vs %x", + htlc.Settle.Preimage, s) + } + } else { + if htlc.Settle != nil { + t.Fatal("expected no settle info") + } + } }