diff --git a/channeldb/db.go b/channeldb/db.go index 0035e9831..939bbdcb9 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -67,6 +67,13 @@ var ( number: 4, migration: migrateEdgePolicies, }, + { + // The DB version where we persist each attempt to send + // an HTLC to a payment hash, and track whether the + // payment is in-flight, succeeded, or failed. + number: 5, + migration: paymentStatusesMigration, + }, } // Big endian is the preferred byte order, due to cursor scans over diff --git a/channeldb/meta_test.go b/channeldb/meta_test.go index dbed0a2be..c98999b47 100644 --- a/channeldb/meta_test.go +++ b/channeldb/meta_test.go @@ -9,6 +9,59 @@ import ( "github.com/go-errors/errors" ) +// applyMigration is a helper test function that encapsulates the general steps +// which are needed to properly check the result of applying migration function. +func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB), + migrationFunc migration, shouldFail bool) { + + cdb, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatal(err) + } + + // beforeMigration usually used for populating the database + // with test data. + beforeMigration(cdb) + + // Create test meta info with zero database version and put it on disk. + // Than creating the version list pretending that new version was added. + meta := &Meta{DbVersionNumber: 0} + if err := cdb.PutMeta(meta); err != nil { + t.Fatalf("unable to store meta data: %v", err) + } + + versions := []version{ + { + number: 0, + migration: nil, + }, + { + number: 1, + migration: migrationFunc, + }, + } + + defer func() { + if r := recover(); r != nil { + err = errors.New(r) + } + + if err == nil && shouldFail { + t.Fatal("error wasn't received on migration stage") + } else if err != nil && !shouldFail { + t.Fatal("error was received on migration stage") + } + + // afterMigration usually used for checking the database state and + // throwing the error if something went wrong. 
+ afterMigration(cdb) + }() + + // Sync with the latest version - applying migration function. + err = cdb.syncVersions(versions) +} + // TestVersionFetchPut checks the propernces of fetch/put methods // and also initialization of meta data in case if don't have any in // database. @@ -118,59 +171,8 @@ func TestGlobalVersionList(t *testing.T) { } } -// applyMigration is a helper test function that encapsulates the general steps -// which are needed to properly check the result of applying migration function. -func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB), - migrationFunc migration, shouldFail bool) { - - cdb, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatal(err) - } - - // beforeMigration usually used for populating the database - // with test data. - beforeMigration(cdb) - - // Create test meta info with zero database version and put it on disk. - // Than creating the version list pretending that new version was added. - meta := &Meta{DbVersionNumber: 0} - if err := cdb.PutMeta(meta); err != nil { - t.Fatalf("unable to store meta data: %v", err) - } - - versions := []version{ - { - number: 0, - migration: nil, - }, - { - number: 1, - migration: migrationFunc, - }, - } - - defer func() { - if r := recover(); r != nil { - err = errors.New(r) - } - - if err == nil && shouldFail { - t.Fatal("error wasn't received on migration stage") - } else if err != nil && !shouldFail { - t.Fatal("error was received on migration stage") - } - - // afterMigration usually used for checking the database state and - // throwing the error if something went wrong. - afterMigration(cdb) - }() - - // Sync with the latest version - applying migration function. - err = cdb.syncVersions(versions) -} - +// TestMigrationWithPanic asserts that if migration logic panics, we will return +// to the original state unaltered. 
func TestMigrationWithPanic(t *testing.T) { t.Parallel() @@ -242,6 +244,8 @@ func TestMigrationWithPanic(t *testing.T) { true) } +// TestMigrationWithFatal asserts that migrations which fail do not modify the +// database. func TestMigrationWithFatal(t *testing.T) { t.Parallel() @@ -312,6 +316,8 @@ func TestMigrationWithFatal(t *testing.T) { true) } +// TestMigrationWithoutErrors asserts that a successful migration has its +// changes applied to the database. func TestMigrationWithoutErrors(t *testing.T) { t.Parallel() diff --git a/channeldb/migrations.go b/channeldb/migrations.go index e8b658ed8..c7beb638a 100644 --- a/channeldb/migrations.go +++ b/channeldb/migrations.go @@ -2,6 +2,8 @@ package channeldb import ( "bytes" + "crypto/sha256" + "encoding/binary" "fmt" "github.com/coreos/bbolt" @@ -373,3 +375,86 @@ func migrateEdgePolicies(tx *bolt.Tx) error { return nil } + +// paymentStatusesMigration is a database migration intended for adding payment +// statuses for each existing payment entity in bucket to be able control +// transitions of statuses and prevent cases such as double payment +func paymentStatusesMigration(tx *bolt.Tx) error { + // Get the bucket dedicated to storing statuses of payments, + // where a key is payment hash, value is payment status. + paymentStatuses, err := tx.CreateBucketIfNotExists(paymentStatusBucket) + if err != nil { + return err + } + + log.Infof("Migrating database to support payment statuses") + + circuitAddKey := []byte("circuit-adds") + circuits := tx.Bucket(circuitAddKey) + if circuits != nil { + log.Infof("Marking all known circuits with status InFlight") + + err = circuits.ForEach(func(k, v []byte) error { + // Parse the first 8 bytes as the short chan ID for the + // circuit. We'll skip all short chan IDs are not + // locally initiated, which includes all non-zero short + // chan ids. 
+ chanID := binary.BigEndian.Uint64(k[:8]) + if chanID != 0 { + return nil + } + + // The payment hash is the third item in the serialized + // payment circuit. The first two items are an AddRef + // (10 bytes) and the incoming circuit key (16 bytes). + const payHashOffset = 10 + 16 + + paymentHash := v[payHashOffset : payHashOffset+32] + + return paymentStatuses.Put( + paymentHash[:], StatusInFlight.Bytes(), + ) + }) + if err != nil { + return err + } + } + + log.Infof("Marking all existing payments with status Completed") + + // Get the bucket dedicated to storing payments + bucket := tx.Bucket(paymentBucket) + if bucket == nil { + return nil + } + + // For each payment in the bucket, deserialize the payment and mark it + // as completed. + err = bucket.ForEach(func(k, v []byte) error { + // Ignores if it is sub-bucket. + if v == nil { + return nil + } + + r := bytes.NewReader(v) + payment, err := deserializeOutgoingPayment(r) + if err != nil { + return err + } + + // Calculate payment hash for current payment. + paymentHash := sha256.Sum256(payment.PaymentPreimage[:]) + + // Update status for current payment to completed. If it fails, + // the migration is aborted and the payment bucket is returned + // to its previous state. + return paymentStatuses.Put(paymentHash[:], StatusCompleted.Bytes()) + }) + if err != nil { + return err + } + + log.Infof("Migration of payment statuses complete!") + + return nil +} diff --git a/channeldb/migrations_test.go b/channeldb/migrations_test.go new file mode 100644 index 000000000..6fbefd0ca --- /dev/null +++ b/channeldb/migrations_test.go @@ -0,0 +1,191 @@ +package channeldb + +import ( + "crypto/sha256" + "encoding/binary" + "testing" + + "github.com/coreos/bbolt" +) + +// TestPaymentStatusesMigration checks that already completed payments will have +// their payment statuses set to Completed after the migration. 
+func TestPaymentStatusesMigration(t *testing.T) { + t.Parallel() + + fakePayment := makeFakePayment() + paymentHash := sha256.Sum256(fakePayment.PaymentPreimage[:]) + + // Add fake payment to test database, verifying that it was created, + // that we have only one payment, and its status is not "Completed". + beforeMigrationFunc := func(d *DB) { + if err := d.AddPayment(fakePayment); err != nil { + t.Fatalf("unable to add payment: %v", err) + } + + payments, err := d.FetchAllPayments() + if err != nil { + t.Fatalf("unable to fetch payments: %v", err) + } + + if len(payments) != 1 { + t.Fatalf("wrong qty of paymets: expected 1, got %v", + len(payments)) + } + + paymentStatus, err := d.FetchPaymentStatus(paymentHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + // We should receive default status if we have any in database. + if paymentStatus != StatusGrounded { + t.Fatalf("wrong payment status: expected %v, got %v", + StatusGrounded.String(), paymentStatus.String()) + } + + // Lastly, we'll add a locally-sourced circuit and + // non-locally-sourced circuit to the circuit map. The + // locally-sourced payment should end up with an InFlight + // status, while the other should remain unchanged, which + // defaults to Grounded. + err = d.Update(func(tx *bolt.Tx) error { + circuits, err := tx.CreateBucketIfNotExists( + []byte("circuit-adds"), + ) + if err != nil { + return err + } + + groundedKey := make([]byte, 16) + binary.BigEndian.PutUint64(groundedKey[:8], 1) + binary.BigEndian.PutUint64(groundedKey[8:], 1) + + // Generated using TestHalfCircuitSerialization with nil + // ErrorEncrypter, which is the case for locally-sourced + // payments. No payment status should end up being set + // for this circuit, since the short channel id of the + // key is non-zero (e.g., a forwarded circuit). This + // will default it to Grounded. 
+ groundedCircuit := []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + // start payment hash + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // end payment hash + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, + 0x42, 0x40, 0x00, + } + + err = circuits.Put(groundedKey, groundedCircuit) + if err != nil { + return err + } + + inFlightKey := make([]byte, 16) + binary.BigEndian.PutUint64(inFlightKey[:8], 0) + binary.BigEndian.PutUint64(inFlightKey[8:], 1) + + // Generated using TestHalfCircuitSerialization with nil + // ErrorEncrypter, which is not the case for forwarded + // payments, but should have no impact on the + // correctness of the test. The payment status for this + // circuit should be set to InFlight, since the short + // channel id in the key is 0 (sourceHop). + inFlightCircuit := []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + // start payment hash + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // end payment hash + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, + 0x42, 0x40, 0x00, + } + + return circuits.Put(inFlightKey, inFlightCircuit) + }) + if err != nil { + t.Fatalf("unable to add circuit map entry: %v", err) + } + } + + // Verify that the created payment status is "Completed" for our one + // fake payment. 
+ afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 1 { + t.Fatal("migration 'paymentStatusesMigration' wasn't applied") + } + + // Check that our completed payments were migrated. + paymentStatus, err := d.FetchPaymentStatus(paymentHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if paymentStatus != StatusCompleted { + t.Fatalf("wrong payment status: expected %v, got %v", + StatusCompleted.String(), paymentStatus.String()) + } + + inFlightHash := [32]byte{ + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + + // Check that the locally sourced payment was transitioned to + // InFlight. + paymentStatus, err = d.FetchPaymentStatus(inFlightHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if paymentStatus != StatusInFlight { + t.Fatalf("wrong payment status: expected %v, got %v", + StatusInFlight.String(), paymentStatus.String()) + } + + groundedHash := [32]byte{ + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + + // Check that non-locally sourced payments remain in the default + // Grounded state. 
+ paymentStatus, err = d.FetchPaymentStatus(groundedHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if paymentStatus != StatusGrounded { + t.Fatalf("wrong payment status: expected %v, got %v", + StatusGrounded.String(), paymentStatus.String()) + } + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + paymentStatusesMigration, + false) +} diff --git a/channeldb/payments.go b/channeldb/payments.go index 0e5f47b9f..7d32f20c9 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -3,6 +3,7 @@ package channeldb import ( "bytes" "encoding/binary" + "errors" "io" "github.com/coreos/bbolt" @@ -17,8 +18,65 @@ var ( // which is a monotonically increasing uint64. BoltDB's sequence // feature is used for generating monotonically increasing id. paymentBucket = []byte("payments") + + // paymentStatusBucket is the name of the bucket within the database that + // stores the status of a payment indexed by the payment's preimage. + paymentStatusBucket = []byte("payment-status") ) +// PaymentStatus represent current status of payment +type PaymentStatus byte + +const ( + // StatusGrounded is the status where a payment has never been + // initiated, or has been initiated and received an intermittent + // failure. + StatusGrounded PaymentStatus = 0 + + // StatusInFlight is the status where a payment has been initiated, but + // a response has not been received. + StatusInFlight PaymentStatus = 1 + + // StatusCompleted is the status where a payment has been initiated and + // the payment was completed successfully. + StatusCompleted PaymentStatus = 2 +) + +// Bytes returns status as slice of bytes. +func (ps PaymentStatus) Bytes() []byte { + return []byte{byte(ps)} +} + +// FromBytes sets status from slice of bytes. 
+func (ps *PaymentStatus) FromBytes(status []byte) error { + if len(status) != 1 { + return errors.New("payment status is empty") + } + + switch PaymentStatus(status[0]) { + case StatusGrounded, StatusInFlight, StatusCompleted: + *ps = PaymentStatus(status[0]) + default: + return errors.New("unknown payment status") + } + + return nil +} + +// String returns readable representation of payment status. +func (ps PaymentStatus) String() string { + switch ps { + case StatusGrounded: + return "Grounded" + case StatusInFlight: + return "In Flight" + case StatusCompleted: + return "Completed" + default: + return "Unknown" + } +} + // OutgoingPayment represents a successful payment between the daemon and a // remote node. Details such as the total fee paid, and the time of the payment // are stored. @@ -129,6 +187,68 @@ func (db *DB) DeleteAllPayments() error { }) } +// UpdatePaymentStatus sets the payment status for outgoing/finished payments in +// local database. +func (db *DB) UpdatePaymentStatus(paymentHash [32]byte, status PaymentStatus) error { + return db.Batch(func(tx *bolt.Tx) error { + return UpdatePaymentStatusTx(tx, paymentHash, status) + }) +} + +// UpdatePaymentStatusTx is a helper method that sets the payment status for +// outgoing/finished payments in the local database. This method accepts a +// boltdb transaction such that the operation can be composed into other +// database transactions. +func UpdatePaymentStatusTx(tx *bolt.Tx, + paymentHash [32]byte, status PaymentStatus) error { + + paymentStatuses, err := tx.CreateBucketIfNotExists(paymentStatusBucket) + if err != nil { + return err + } + + return paymentStatuses.Put(paymentHash[:], status.Bytes()) +} + +// FetchPaymentStatus returns the payment status for outgoing payment. +// If status of the payment isn't found, it will default to "StatusGrounded". 
+func (db *DB) FetchPaymentStatus(paymentHash [32]byte) (PaymentStatus, error) { + var paymentStatus = StatusGrounded + err := db.View(func(tx *bolt.Tx) error { + var err error + paymentStatus, err = FetchPaymentStatusTx(tx, paymentHash) + return err + }) + if err != nil { + return StatusGrounded, err + } + + return paymentStatus, nil +} + +// FetchPaymentStatusTx is a helper method that returns the payment status for +// outgoing payment. If status of the payment isn't found, it will default to +// "StatusGrounded". It accepts the boltdb transactions such that this method +// can be composed into other atomic operations. +func FetchPaymentStatusTx(tx *bolt.Tx, paymentHash [32]byte) (PaymentStatus, error) { + // The default status for all payments that aren't recorded in database. + var paymentStatus = StatusGrounded + + bucket := tx.Bucket(paymentStatusBucket) + if bucket == nil { + return paymentStatus, nil + } + + paymentStatusBytes := bucket.Get(paymentHash[:]) + if paymentStatusBytes == nil { + return paymentStatus, nil + } + + paymentStatus.FromBytes(paymentStatusBytes) + + return paymentStatus, nil +} + func serializeOutgoingPayment(w io.Writer, p *OutgoingPayment) error { var scratch [8]byte diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 450b4acff..d13e039d0 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -40,6 +40,14 @@ func makeFakePayment() *OutgoingPayment { return fakePayment } +func makeFakePaymentHash() [32]byte { + var paymentHash [32]byte + rBytes, _ := randomBytes(0, 32) + copy(paymentHash[:], rBytes) + + return paymentHash +} + // randomBytes creates random []byte with length in range [minLen, maxLen) func randomBytes(minLen, maxLen int) ([]byte, error) { randBuf := make([]byte, minLen+rand.Intn(maxLen-minLen)) @@ -195,3 +203,51 @@ func TestOutgoingPaymentWorkflow(t *testing.T) { len(paymentsAfterDeletion), 0) } } + +func TestPaymentStatusWorkflow(t *testing.T) { + t.Parallel() + + db, 
cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + + testCases := []struct { + paymentHash [32]byte + status PaymentStatus + }{ + { + paymentHash: makeFakePaymentHash(), + status: StatusGrounded, + }, + { + paymentHash: makeFakePaymentHash(), + status: StatusInFlight, + }, + { + paymentHash: makeFakePaymentHash(), + status: StatusCompleted, + }, + } + + for _, testCase := range testCases { + err := db.UpdatePaymentStatus(testCase.paymentHash, testCase.status) + if err != nil { + t.Fatalf("unable to put payment in DB: %v", err) + } + + status, err := db.FetchPaymentStatus(testCase.paymentHash) + if err != nil { + t.Fatalf("unable to fetch payments from DB: %v", err) + } + + if status != testCase.status { + t.Fatalf("Wrong payments status after reading from DB."+ + "Got %v, want %v", + spew.Sdump(status), + spew.Sdump(testCase.status), + ) + } + } +} diff --git a/htlcswitch/control_tower.go b/htlcswitch/control_tower.go new file mode 100644 index 000000000..47b5bd3d9 --- /dev/null +++ b/htlcswitch/control_tower.go @@ -0,0 +1,245 @@ +package htlcswitch + +import ( + "errors" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // ErrAlreadyPaid signals we have already paid this payment hash. + ErrAlreadyPaid = errors.New("invoice is already paid") + + // ErrPaymentInFlight signals that payment for this payment hash is + // already "in flight" on the network. + ErrPaymentInFlight = errors.New("payment is in transition") + + // ErrPaymentNotInitiated is returned if payment wasn't initiated in + // switch. + ErrPaymentNotInitiated = errors.New("payment isn't initiated") + + // ErrPaymentAlreadyCompleted is returned in the event we attempt to + // recomplete a completed payment. 
+ ErrPaymentAlreadyCompleted = errors.New("payment is already completed") + + // ErrUnknownPaymentStatus is returned when we do not recognize the + // existing state of a payment. + ErrUnknownPaymentStatus = errors.New("unknown payment status") +) + +// ControlTower tracks all outgoing payments made by the switch, whose primary +// purpose is to prevent duplicate payments to the same payment hash. In +// production, a persistent implementation is preferred so that tracking can +// survive across restarts. Payments are transition through various payment +// states, and the ControlTower interface provides access to driving the state +// transitions. +type ControlTower interface { + // ClearForTakeoff atomically checks that no inflight or completed + // payments exist for this payment hash. If none are found, this method + // atomically transitions the status for this payment hash as InFlight. + ClearForTakeoff(htlc *lnwire.UpdateAddHTLC) error + + // Success transitions an InFlight payment into a Completed payment. + // After invoking this method, ClearForTakeoff should always return an + // error to prevent us from making duplicate payments to the same + // payment hash. + Success(paymentHash [32]byte) error + + // Fail transitions an InFlight payment into a Grounded Payment. After + // invoking this method, ClearForTakeoff should return nil on its next + // call for this payment hash, allowing the switch to make a subsequent + // payment. + Fail(paymentHash [32]byte) error +} + +// paymentControl is persistent implementation of ControlTower to restrict +// double payment sending. +type paymentControl struct { + strict bool + + db *channeldb.DB +} + +// NewPaymentControl creates a new instance of the paymentControl. The strict +// flag indicates whether the controller should require "strict" state +// transitions, which would be otherwise intolerant to older databases that may +// already have duplicate payments to the same payment hash. 
It should be +// enabled only after sufficient checks have been made to ensure the db does not +// contain such payments. In the meantime, non-strict mode enforces a superset +// of the state transitions that prevent additional payments to a given payment +// hash from being added. +func NewPaymentControl(strict bool, db *channeldb.DB) ControlTower { + return &paymentControl{ + strict: strict, + db: db, + } +} + +// ClearForTakeoff checks that we don't already have an InFlight or Completed +// payment identified by the same payment hash. +func (p *paymentControl) ClearForTakeoff(htlc *lnwire.UpdateAddHTLC) error { + var takeoffErr error + err := p.db.Batch(func(tx *bolt.Tx) error { + // Retrieve current status of payment from local database. + paymentStatus, err := channeldb.FetchPaymentStatusTx( + tx, htlc.PaymentHash, + ) + if err != nil { + return err + } + + // Reset the takeoff error, to avoid carrying over an error + // from a previous execution of the batched db transaction. + takeoffErr = nil + + switch paymentStatus { + + case channeldb.StatusGrounded: + // It is safe to reattempt a payment if we know that we + // haven't left one in flight. Since this one is + // grounded, Transition the payment status to InFlight + // to prevent others. + return channeldb.UpdatePaymentStatusTx( + tx, htlc.PaymentHash, channeldb.StatusInFlight, + ) + + case channeldb.StatusInFlight: + // We already have an InFlight payment on the network. We will + // disallow any more payment until a response is received. + takeoffErr = ErrPaymentInFlight + + case channeldb.StatusCompleted: + // We've already completed a payment to this payment hash, + // forbid the switch from sending another. + takeoffErr = ErrAlreadyPaid + + default: + takeoffErr = ErrUnknownPaymentStatus + } + + return nil + }) + if err != nil { + return err + } + + return takeoffErr +} + +// Success transitions an InFlight payment to Completed, otherwise it returns an +// error. 
After calling Success, ClearForTakeoff should prevent any further +// attempts for the same payment hash. +func (p *paymentControl) Success(paymentHash [32]byte) error { + var updateErr error + err := p.db.Batch(func(tx *bolt.Tx) error { + paymentStatus, err := channeldb.FetchPaymentStatusTx( + tx, paymentHash, + ) + if err != nil { + return err + } + + // Reset the update error, to avoid carrying over an error + // from a previous execution of the batched db transaction. + updateErr = nil + + switch { + + case paymentStatus == channeldb.StatusGrounded && p.strict: + // Our records show the payment as still being grounded, + // meaning it never should have left the switch. + updateErr = ErrPaymentNotInitiated + + case paymentStatus == channeldb.StatusGrounded && !p.strict: + // Though our records show the payment as still being + // grounded, meaning it never should have left the + // switch, we permit this transition in non-strict mode + // to handle inconsistent db states. + fallthrough + + case paymentStatus == channeldb.StatusInFlight: + // A successful response was received for an InFlight + // payment, mark it as completed to prevent sending to + // this payment hash again. + return channeldb.UpdatePaymentStatusTx( + tx, paymentHash, channeldb.StatusCompleted, + ) + + case paymentStatus == channeldb.StatusCompleted: + // The payment was completed previously, alert the + // caller that this may be a duplicate call. + updateErr = ErrPaymentAlreadyCompleted + + default: + updateErr = ErrUnknownPaymentStatus + } + + return nil + }) + if err != nil { + return err + } + + return updateErr +} + +// Fail transitions an InFlight payment to Grounded, otherwise it returns an +// error. After calling Fail, ClearForTakeoff should fail any further attempts +// for the same payment hash. 
+func (p *paymentControl) Fail(paymentHash [32]byte) error { + var updateErr error + err := p.db.Batch(func(tx *bolt.Tx) error { + paymentStatus, err := channeldb.FetchPaymentStatusTx( + tx, paymentHash, + ) + if err != nil { + return err + } + + // Reset the update error, to avoid carrying over an error + // from a previous execution of the batched db transaction. + updateErr = nil + + switch { + + case paymentStatus == channeldb.StatusGrounded && p.strict: + // Our records show the payment as still being grounded, + // meaning it never should have left the switch. + updateErr = ErrPaymentNotInitiated + + case paymentStatus == channeldb.StatusGrounded && !p.strict: + // Though our records show the payment as still being + // grounded, meaning it never should have left the + // switch, we permit this transition in non-strict mode + // to handle inconsistent db states. + fallthrough + + case paymentStatus == channeldb.StatusInFlight: + // A failed response was received for an InFlight + // payment, mark it as Grounded again to allow + // subsequent attempts. + return channeldb.UpdatePaymentStatusTx( + tx, paymentHash, channeldb.StatusGrounded, + ) + + case paymentStatus == channeldb.StatusCompleted: + // The payment was completed previously, and we are now + // reporting that it has failed. Leave the status as + // completed, but alert the user that something is + // wrong. 
+ updateErr = ErrPaymentAlreadyCompleted + + default: + updateErr = ErrUnknownPaymentStatus + } + + return nil + }) + if err != nil { + return err + } + + return updateErr +} diff --git a/htlcswitch/control_tower_test.go b/htlcswitch/control_tower_test.go new file mode 100644 index 000000000..2728e3622 --- /dev/null +++ b/htlcswitch/control_tower_test.go @@ -0,0 +1,351 @@ +package htlcswitch + +import ( + "fmt" + "testing" + + "github.com/btcsuite/fastsha256" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" +) + +func genHtlc() (*lnwire.UpdateAddHTLC, error) { + preimage, err := genPreimage() + if err != nil { + return nil, fmt.Errorf("unable to generate preimage: %v", err) + } + + rhash := fastsha256.Sum256(preimage[:]) + htlc := &lnwire.UpdateAddHTLC{ + PaymentHash: rhash, + Amount: 1, + } + + return htlc, nil +} + +type paymentControlTestCase func(*testing.T, bool) + +var paymentControlTests = []struct { + name string + strict bool + testcase paymentControlTestCase +}{ + { + name: "fail-strict", + strict: true, + testcase: testPaymentControlSwitchFail, + }, + { + name: "double-send-strict", + strict: true, + testcase: testPaymentControlSwitchDoubleSend, + }, + { + name: "double-pay-strict", + strict: true, + testcase: testPaymentControlSwitchDoublePay, + }, + { + name: "fail-not-strict", + strict: false, + testcase: testPaymentControlSwitchFail, + }, + { + name: "double-send-not-strict", + strict: false, + testcase: testPaymentControlSwitchDoubleSend, + }, + { + name: "double-pay-not-strict", + strict: false, + testcase: testPaymentControlSwitchDoublePay, + }, +} + +// TestPaymentControls runs a set of common tests against both the strict and +// non-strict payment control instances. This ensures that the two both behave +// identically when making the expected state-transitions of the stricter +// implementation. Behavioral differences in the strict and non-strict +// implementations are tested separately. 
+func TestPaymentControls(t *testing.T) { + for _, test := range paymentControlTests { + t.Run(test.name, func(t *testing.T) { + test.testcase(t, test.strict) + }) + } +} + +// testPaymentControlSwitchFail checks that payment status returns to Grounded +// status after failing, and that ClearForTakeoff allows another HTLC for the +// same payment hash. +func testPaymentControlSwitchFail(t *testing.T, strict bool) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(strict, db) + + htlc, err := genHtlc() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Sends base htlc message which initiate StatusInFlight. + if err := pControl.ClearForTakeoff(htlc); err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusInFlight) + + // Fail the payment, which should moved it to Grounded. + if err := pControl.Fail(htlc.PaymentHash); err != nil { + t.Fatalf("unable to fail payment hash: %v", err) + } + + // Verify the status is indeed Grounded. + assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded) + + // Sends the htlc again, which should succeed since the prior payment + // failed. + if err := pControl.ClearForTakeoff(htlc); err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusInFlight) + + // Verifies that status was changed to StatusCompleted. + if err := pControl.Success(htlc.PaymentHash); err != nil { + t.Fatalf("error shouldn't have been received, got: %v", err) + } + + assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted) + + // Attempt a final payment, which should now fail since the prior + // payment succeed. 
+ if err := pControl.ClearForTakeoff(htlc); err != ErrAlreadyPaid { + t.Fatalf("unable to send htlc message: %v", err) + } +} + +// testPaymentControlSwitchDoubleSend checks the ability of payment control to +// prevent double sending of htlc message, when message is in StatusInFlight. +func testPaymentControlSwitchDoubleSend(t *testing.T, strict bool) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(strict, db) + + htlc, err := genHtlc() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Sends base htlc message which initiate base status and move it to + // StatusInFlight and verifies that it was changed. + if err := pControl.ClearForTakeoff(htlc); err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusInFlight) + + // Try to initiate double sending of htlc message with the same + // payment hash, should result in error indicating that payment has + // already been sent. + if err := pControl.ClearForTakeoff(htlc); err != ErrPaymentInFlight { + t.Fatalf("payment control wrong behaviour: " + + "double sending must trigger ErrPaymentInFlight error") + } +} + +// TestPaymentControlSwitchDoublePay checks the ability of payment control to +// prevent double payment. +func testPaymentControlSwitchDoublePay(t *testing.T, strict bool) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(strict, db) + + htlc, err := genHtlc() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Sends base htlc message which initiate StatusInFlight. + if err := pControl.ClearForTakeoff(htlc); err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + // Verify that payment is InFlight. 
+ assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusInFlight) + + // Move payment to completed status, second payment should return error. + if err := pControl.Success(htlc.PaymentHash); err != nil { + t.Fatalf("error shouldn't have been received, got: %v", err) + } + + // Verify that payment is Completed. + assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted) + + if err := pControl.ClearForTakeoff(htlc); err != ErrAlreadyPaid { + t.Fatalf("payment control wrong behaviour:" + + " double payment must trigger ErrAlreadyPaid") + } +} + +// TestPaymentControlNonStrictSuccessesWithoutInFlight checks that a non-strict +// payment control will allow calls to Success when no payment is in flight. This +// is necessary to gracefully handle the case in which the switch already sent +// out a payment for a particular payment hash in a prior db version that didn't +// have payment statuses. +func TestPaymentControlNonStrictSuccessesWithoutInFlight(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(false, db) + + htlc, err := genHtlc() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + if err := pControl.Success(htlc.PaymentHash); err != nil { + t.Fatalf("unable to mark payment hash success: %v", err) + } + + assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted) + + err = pControl.Success(htlc.PaymentHash) + if err != ErrPaymentAlreadyCompleted { + t.Fatalf("unable to remark payment hash failed: %v", err) + } +} + +// TestPaymentControlNonStrictFailsWithoutInFlight checks that a non-strict +// payment control will allow calls to Fail when no payment is in flight. This +// is necessary to gracefully handle the case in which the switch already sent +// out a payment for a particular payment hash in a prior db version that didn't +// have payment statuses. 
+func TestPaymentControlNonStrictFailsWithoutInFlight(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(false, db) + + htlc, err := genHtlc() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + if err := pControl.Fail(htlc.PaymentHash); err != nil { + t.Fatalf("unable to mark payment hash failed: %v", err) + } + + assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded) + + err = pControl.Fail(htlc.PaymentHash) + if err != nil { + t.Fatalf("unable to remark payment hash failed: %v", err) + } + + assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded) + + err = pControl.Success(htlc.PaymentHash) + if err != nil { + t.Fatalf("unable to remark payment hash success: %v", err) + } + + assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted) + + err = pControl.Fail(htlc.PaymentHash) + if err != ErrPaymentAlreadyCompleted { + t.Fatalf("unable to remark payment hash failed: %v", err) + } + + assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted) +} + +// TestPaymentControlStrictSuccessesWithoutInFlight checks that a strict payment +// control will disallow calls to Success when no payment is in flight. +func TestPaymentControlStrictSuccessesWithoutInFlight(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(true, db) + + htlc, err := genHtlc() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + err = pControl.Success(htlc.PaymentHash) + if err != ErrPaymentNotInitiated { + t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) + } + + assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded) +} + +// TestPaymentControlStrictFailsWithoutInFlight checks that a strict payment +// control will disallow calls to Fail when no payment is in flight. 
+func TestPaymentControlStrictFailsWithoutInFlight(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(true, db) + + htlc, err := genHtlc() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + err = pControl.Fail(htlc.PaymentHash) + if err != ErrPaymentNotInitiated { + t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) + } + + assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded) +} + +func assertPaymentStatus(t *testing.T, db *channeldb.DB, + hash [32]byte, expStatus channeldb.PaymentStatus) { + + t.Helper() + + pStatus, err := db.FetchPaymentStatus(hash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if pStatus != expStatus { + t.Fatalf("payment status mismatch: expected %v, got %v", + expStatus, pStatus) + } +} diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index a5f417be4..b48587e6d 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -3755,8 +3755,8 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) { n.firstBobChannelLink.ShortChanID(), htlc, newMockDeobfuscator(), ) - if err != nil { - t.Fatalf("error shouldn't have been received got: %v", err) + if err != ErrAlreadyPaid { + t.Fatalf("ErrAlreadyPaid should have been received got: %v", err) } } diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 9df919933..e2539fac3 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -124,14 +124,25 @@ type mockServer struct { var _ lnpeer.Peer = (*mockServer)(nil) -func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) { - if db == nil { - tempPath, err := ioutil.TempDir("", "switchdb") - if err != nil { - return nil, err - } +func initDB() (*channeldb.DB, error) { + tempPath, err := ioutil.TempDir("", "switchdb") + if err != nil { + return nil, err + } - db, err = channeldb.Open(tempPath) + db, err := 
channeldb.Open(tempPath) + if err != nil { + return nil, err + } + + return db, err +} + +func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) { + var err error + + if db == nil { + db, err = initDB() if err != nil { return nil, err } diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 8f94a9a04..2e5248a41 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -10,11 +10,10 @@ import ( "time" "github.com/btcsuite/btcd/btcec" - "github.com/coreos/bbolt" - "github.com/davecgh/go-spew/spew" - "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" + "github.com/coreos/bbolt" + "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" @@ -47,6 +46,11 @@ var ( // txn. ErrIncompleteForward = errors.New("incomplete forward detected") + // ErrUnknownErrorDecryptor signals that we were unable to locate the + // error decryptor for this payment. This is likely due to restarting + // the daemon. + ErrUnknownErrorDecryptor = errors.New("unknown error decryptor") + // ErrSwitchExiting signaled when the switch has received a shutdown // request. ErrSwitchExiting = errors.New("htlcswitch shutting down") @@ -64,7 +68,6 @@ type pendingPayment struct { amount lnwire.MilliSatoshi preimage chan [sha256.Size]byte - response chan *htlcPacket err chan error // deobfuscator is a serializable entity which is used if we received @@ -204,6 +207,9 @@ type Switch struct { paymentSequencer Sequencer + // control provides verification of sending htlc messages + control ControlTower + // circuits is storage for payment circuits which are used to // forward the settle/fail htlc updates back to the add htlc initiator.
circuits CircuitMap @@ -289,6 +295,7 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { cfg: &cfg, circuits: circuitMap, paymentSequencer: sequencer, + control: NewPaymentControl(false, cfg.DB), linkIndex: make(map[lnwire.ChannelID]ChannelLink), mailOrchestrator: newMailOrchestrator(), forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink), @@ -344,11 +351,17 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, htlc *lnwire.UpdateAddHTLC, deobfuscator ErrorDecrypter) ([sha256.Size]byte, error) { + // Before sending, double check that we don't already have 1) an + // in-flight payment to this payment hash, or 2) a complete payment for + // the same hash. + if err := s.control.ClearForTakeoff(htlc); err != nil { + return zeroPreimage, err + } + // Create payment and add to the map of payment in order later to be // able to retrieve it and return response to the user. payment := &pendingPayment{ err: make(chan error, 1), - response: make(chan *htlcPacket, 1), preimage: make(chan [sha256.Size]byte, 1), paymentHash: htlc.PaymentHash, amount: htlc.Amount, @@ -376,13 +389,16 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, if err := s.forward(packet); err != nil { s.removePendingPayment(paymentID) + if err := s.control.Fail(htlc.PaymentHash); err != nil { + return zeroPreimage, err + } + return zeroPreimage, err } // Returns channels so that other subsystem might wait/skip the // waiting of handling of payment. 
var preimage [sha256.Size]byte - var response *htlcPacket select { case e := <-payment.err: @@ -391,13 +407,6 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, return zeroPreimage, ErrSwitchExiting } - select { - case pkt := <-payment.response: - response = pkt - case <-s.quit: - return zeroPreimage, ErrSwitchExiting - } - select { case p := <-payment.preimage: preimage = p @@ -405,24 +414,6 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, return zeroPreimage, ErrSwitchExiting } - // Remove circuit since we are about to complete an add/fail of this - // HTLC. - if teardownErr := s.teardownCircuit(response); teardownErr != nil { - log.Warnf("unable to teardown circuit %s: %v", - response.inKey(), teardownErr) - return preimage, err - } - - // Finally, if this response is contained in a forwarding package, ack - // the settle/fail so that we don't continue to retransmit the HTLC - // internally. - if response.destRef != nil { - if ackErr := s.ackSettleFail(*response.destRef); ackErr != nil { - log.Warnf("unable to ack settle/fail reference: %s: %v", - *response.destRef, ackErr) - } - } - return preimage, err } @@ -770,20 +761,10 @@ func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error, // Alice Bob Carol // func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { - // Pending payments use a special interpretation of the incomingChanID and - // incomingHTLCID fields on packet where the channel ID is blank and the - // HTLC ID is the payment ID. The switch basically views the users of the - // node as a special channel that also offers a sequence of HTLCs. - payment, err := s.findPayment(pkt.incomingHTLCID) - if err != nil { - return err - } - - switch htlc := pkt.htlc.(type) { - // User have created the htlc update therefore we should find the // appropriate channel link and send the payment over this link. 
- case *lnwire.UpdateAddHTLC: + if htlc, ok := pkt.htlc.(*lnwire.UpdateAddHTLC); ok { + // Try to find links by node destination. s.indexMtx.RLock() link, err := s.getLinkByShortID(pkt.outgoingChanID) s.indexMtx.RUnlock() @@ -827,31 +808,113 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { } return link.HandleSwitchPacket(pkt) - - // We've just received a settle update which means we can finalize the - // user payment and return successful response. - case *lnwire.UpdateFulfillHTLC: - // Notify the user that his payment was successfully proceed. - payment.err <- nil - payment.response <- pkt - payment.preimage <- htlc.PaymentPreimage - s.removePendingPayment(pkt.incomingHTLCID) - - // We've just received a fail update which means we can finalize the - // user payment and return fail response. - case *lnwire.UpdateFailHTLC: - payment.err <- s.parseFailedPayment(payment, pkt, htlc) - payment.response <- pkt - payment.preimage <- zeroPreimage - s.removePendingPayment(pkt.incomingHTLCID) - - default: - return errors.New("wrong update type") } + s.wg.Add(1) + go s.handleLocalResponse(pkt) + return nil } +// handleLocalResponse processes a Settle or Fail responding to a +// locally-initiated payment. This is handled asynchronously to avoid blocking +// the main event loop within the switch, as these operations can require +// multiple db transactions. The guarantees of the circuit map are stringent +// enough such that we are able to tolerate reordering of these operations +// without side effects. The primary operations handled are: +// 1. Ack settle/fail references, to avoid resending this response internally +// 2. Teardown the closing circuit in the circuit map +// 3. Transition the payment status to grounded or completed. +// 4. Respond to an in-mem pending payment, if it is found. +// +// NOTE: This method MUST be spawned as a goroutine. 
+func (s *Switch) handleLocalResponse(pkt *htlcPacket) { + defer s.wg.Done() + + // First, we'll clean up any fwdpkg references, circuit entries, and + // mark in our db that the payment for this payment hash has either + // succeeded or failed. + // + // If this response is contained in a forwarding package, we'll start by + // acking the settle/fail so that we don't continue to retransmit the + // HTLC internally. + if pkt.destRef != nil { + if err := s.ackSettleFail(*pkt.destRef); err != nil { + log.Warnf("Unable to ack settle/fail reference: %s: %v", + *pkt.destRef, err) + return + } + } + + // Next, we'll remove the circuit since we are about to complete an + // fulfill/fail of this HTLC. Since we've already removed the + // settle/fail fwdpkg reference, the response from the peer cannot be + // replayed internally if this step fails. If this happens, this logic + // will be executed when a provided resolution message comes through. + // This can only happen if the circuit is still open, which is why this + // ordering is chosen. + if err := s.teardownCircuit(pkt); err != nil { + log.Warnf("Unable to teardown circuit %s: %v", + pkt.inKey(), err) + return + } + + // Locate the pending payment to notify the application that this + // payment has failed. If one is not found, it likely means the daemon + // has been restarted since sending the payment. + payment := s.findPayment(pkt.incomingHTLCID) + + var ( + preimage [32]byte + paymentErr error + ) + + switch htlc := pkt.htlc.(type) { + + // We've received a settle update which means we can finalize the user + // payment and return successful response. + case *lnwire.UpdateFulfillHTLC: + // Persistently mark that a payment to this payment hash + // succeeded. This will prevent us from ever making another + // payment to this hash. 
+ err := s.control.Success(pkt.circuit.PaymentHash) + if err != nil && err != ErrPaymentAlreadyCompleted { + log.Warnf("Unable to mark completed payment %x: %v", + pkt.circuit.PaymentHash, err) + return + } + + preimage = htlc.PaymentPreimage + + // We've received a fail update which means we can finalize the user + // payment and return fail response. + case *lnwire.UpdateFailHTLC: + // Persistently mark that a payment to this payment hash failed. + // This will permit us to make another attempt at a successful + // payment. + err := s.control.Fail(pkt.circuit.PaymentHash) + if err != nil && err != ErrPaymentAlreadyCompleted { + log.Warnf("Unable to ground payment %x: %v", + pkt.circuit.PaymentHash, err) + return + } + + paymentErr = s.parseFailedPayment(payment, pkt, htlc) + + default: + log.Warnf("Received unknown response type: %T", pkt.htlc) + return + } + + // Deliver the payment error and preimage to the application, if it is + // waiting for a response. + if payment != nil { + payment.err <- paymentErr + payment.preimage <- preimage + s.removePendingPayment(pkt.incomingHTLCID) + } +} + // parseFailedPayment determines the appropriate failure message to return to // a user initiated payment. The three cases handled are: // 1) A local failure, which should already plaintext. @@ -874,7 +937,8 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket, failureMsg, err := lnwire.DecodeFailure(r, 0) if err != nil { userErr = fmt.Sprintf("unable to decode onion failure, "+ - "htlc with hash(%x): %v", payment.paymentHash[:], err) + "htlc with hash(%x): %v", + pkt.circuit.PaymentHash[:], err) log.Error(userErr) // As this didn't even clear the link, we don't need to @@ -901,6 +965,18 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket, FailureMessage: lnwire.FailPermanentChannelFailure{}, } + // If the provided payment is nil, we have discarded the error decryptor + // due to a restart. 
We'll return a fixed error and signal a temporary + // channel failure to the router. + case payment == nil: + userErr := fmt.Sprintf("error decryptor for payment " + + "could not be located, likely due to restart") + failure = &ForwardingError{ + ErrorSource: s.cfg.SelfKey, + ExtraMsg: userErr, + FailureMessage: lnwire.NewTemporaryChannelFailure(nil), + } + // A regular multi-hop payment error that we'll need to // decrypt. default: @@ -909,8 +985,9 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket, // error. If we're unable to then we'll bail early. failure, err = payment.deobfuscator.DecryptError(htlc.Reason) if err != nil { - userErr := fmt.Sprintf("unable to de-obfuscate onion failure, "+ - "htlc with hash(%x): %v", payment.paymentHash[:], err) + userErr := fmt.Sprintf("unable to de-obfuscate onion "+ + "failure, htlc with hash(%x): %v", + pkt.circuit.PaymentHash[:], err) log.Error(userErr) failure = &ForwardingError{ ErrorSource: s.cfg.SelfKey, @@ -2042,30 +2119,26 @@ func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) { // removePendingPayment is the helper function which removes the pending user // payment. -func (s *Switch) removePendingPayment(paymentID uint64) error { +func (s *Switch) removePendingPayment(paymentID uint64) { s.pendingMutex.Lock() defer s.pendingMutex.Unlock() - if _, ok := s.pendingPayments[paymentID]; !ok { - return fmt.Errorf("Cannot find pending payment with ID %d", - paymentID) - } - delete(s.pendingPayments, paymentID) - return nil } // findPayment is the helper function which find the payment. 
-func (s *Switch) findPayment(paymentID uint64) (*pendingPayment, error) { +func (s *Switch) findPayment(paymentID uint64) *pendingPayment { s.pendingMutex.RLock() defer s.pendingMutex.RUnlock() payment, ok := s.pendingPayments[paymentID] if !ok { - return nil, fmt.Errorf("Cannot find pending payment with ID %d", + log.Errorf("Cannot find pending payment with ID %d", paymentID) + return nil } - return payment, nil + + return payment } // CircuitModifier returns a reference to subset of the interfaces provided by @@ -2077,6 +2150,9 @@ func (s *Switch) CircuitModifier() CircuitModifier { // numPendingPayments is helper function which returns the overall number of // pending user payments. func (s *Switch) numPendingPayments() int { + s.pendingMutex.RLock() + defer s.pendingMutex.RUnlock() + return len(s.pendingPayments) } diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 7610a797f..6dd8980fe 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -1678,7 +1678,9 @@ func TestSwitchSendPayment(t *testing.T) { } case err := <-errChan: - t.Fatalf("unable to send payment: %v", err) + if err != ErrPaymentInFlight { + t.Fatalf("unable to send payment: %v", err) + } case <-time.After(time.Second): t.Fatal("request was not propagated to destination") } @@ -1695,11 +1697,11 @@ func TestSwitchSendPayment(t *testing.T) { t.Fatal("request was not propagated to destination") } - if s.numPendingPayments() != 2 { + if s.numPendingPayments() != 1 { t.Fatal("wrong amount of pending payments") } - if s.circuits.NumOpen() != 2 { + if s.circuits.NumOpen() != 1 { t.Fatal("wrong amount of circuits") } @@ -1735,29 +1737,6 @@ func TestSwitchSendPayment(t *testing.T) { t.Fatal("err wasn't received") } - packet = &htlcPacket{ - outgoingChanID: aliceChannelLink.ShortChanID(), - outgoingHTLCID: 1, - htlc: &lnwire.UpdateFailHTLC{ - Reason: reason, - }, - } - - // Send second failure response and check that user were able to - // receive the error. 
- if err := s.forward(packet); err != nil { - t.Fatalf("can't forward htlc packet: %v", err) - } - - select { - case err := <-errChan: - if err.Error() != errors.New(lnwire.CodeIncorrectPaymentAmount).Error() { - t.Fatal("err wasn't received") - } - case <-time.After(time.Second): - t.Fatal("err wasn't received") - } - if s.numPendingPayments() != 0 { t.Fatal("wrong amount of pending payments") } diff --git a/lnd_test.go b/lnd_test.go index 41b637636..7e4484749 100644 --- a/lnd_test.go +++ b/lnd_test.go @@ -450,6 +450,17 @@ func completePaymentRequests(ctx context.Context, client lnrpc.LightningClient, return nil } +// makeFakePayHash creates random pre image hash +func makeFakePayHash(t *harnessTest) []byte { + randBuf := make([]byte, 32) + + if _, err := rand.Read(randBuf); err != nil { + t.Fatalf("internal error, cannot generate random string: %v", err) + } + + return randBuf +} + const ( AddrTypeWitnessPubkeyHash = lnrpc.NewAddressRequest_WITNESS_PUBKEY_HASH AddrTypeNestedPubkeyHash = lnrpc.NewAddressRequest_NESTED_PUBKEY_HASH @@ -1822,13 +1833,13 @@ func testChannelForceClosure(net *lntest.NetworkHarness, t *harnessTest) { if err != nil { t.Fatalf("unable to create payment stream for alice: %v", err) } + carolPubKey := carol.PubKey[:] - payHash := bytes.Repeat([]byte{2}, 32) for i := 0; i < numInvoices; i++ { err = alicePayStream.Send(&lnrpc.SendRequest{ Dest: carolPubKey, Amt: int64(paymentAmt), - PaymentHash: payHash, + PaymentHash: makeFakePayHash(t), FinalCltvDelta: defaultBitcoinTimeLockDelta, }) if err != nil { @@ -3945,9 +3956,16 @@ func testPrivateChannels(net *lntest.NetworkHarness, t *harnessTest) { const paymentAmt = 70000 payReqs := make([]string, numPayments) for i := 0; i < numPayments; i++ { + preimage := make([]byte, 32) + _, err := rand.Read(preimage) + if err != nil { + t.Fatalf("unable to generate preimage: %v", err) + } + invoice := &lnrpc.Invoice{ - Memo: "testing", - Value: paymentAmt, + Memo: "testing", + RPreimage: preimage, + Value: 
paymentAmt, } resp, err := net.Bob.AddInvoice(ctxb, invoice) if err != nil { @@ -4008,9 +4026,16 @@ func testPrivateChannels(net *lntest.NetworkHarness, t *harnessTest) { const paymentAmt60k = 60000 payReqs = make([]string, numPayments) for i := 0; i < numPayments; i++ { + preimage := make([]byte, 32) + _, err := rand.Read(preimage) + if err != nil { + t.Fatalf("unable to generate preimage: %v", err) + } + invoice := &lnrpc.Invoice{ - Memo: "testing", - Value: paymentAmt60k, + Memo: "testing", + RPreimage: preimage, + Value: paymentAmt60k, } resp, err := carol.AddInvoice(ctxb, invoice) if err != nil { @@ -4496,10 +4521,9 @@ func testInvoiceSubscriptions(net *lntest.NetworkHarness, t *harnessTest) { // TODO(roasbeef): make global list of invoices for each node to re-use // and avoid collisions const paymentAmt = 1000 - preimage := bytes.Repeat([]byte{byte(90)}, 32) invoice := &lnrpc.Invoice{ Memo: "testing", - RPreimage: preimage, + RPreimage: makeFakePayHash(t), Value: paymentAmt, } invoiceResp, err := net.Bob.AddInvoice(ctxb, invoice) @@ -6727,7 +6751,7 @@ out: // stream on payment error. ctxt, _ = context.WithTimeout(ctxb, timeout) sendReq := &lnrpc.SendRequest{ - PaymentHashString: hex.EncodeToString(bytes.Repeat([]byte("Z"), 32)), + PaymentHashString: hex.EncodeToString(makeFakePayHash(t)), DestString: hex.EncodeToString(carol.PubKey[:]), Amt: payAmt, } @@ -6856,6 +6880,12 @@ out: "instead: %v", resp.PaymentError) } + // Generate new invoice to not pay same invoice twice. + carolInvoice, err = carol.AddInvoice(ctxb, invoiceReq) + if err != nil { + t.Fatalf("unable to generate carol invoice: %v", err) + } + // For our final test, we'll ensure that if a target link isn't // available for what ever reason then the payment fails accordingly. 
// @@ -7953,8 +7983,8 @@ func testMultiHopHtlcLocalTimeout(net *lntest.NetworkHarness, t *harnessTest) { // We'll create two random payment hashes unknown to carol, then send // each of them by manually specifying the HTLC details. carolPubKey := carol.PubKey[:] - dustPayHash := bytes.Repeat([]byte{1}, 32) - payHash := bytes.Repeat([]byte{2}, 32) + dustPayHash := makeFakePayHash(t) + payHash := makeFakePayHash(t) err = alicePayStream.Send(&lnrpc.SendRequest{ Dest: carolPubKey, Amt: int64(dustHtlcAmt), @@ -8412,7 +8442,7 @@ func testMultiHopLocalForceCloseOnChainHtlcTimeout(net *lntest.NetworkHarness, // We'll now send a single HTLC across our multi-hop network. carolPubKey := carol.PubKey[:] - payHash := bytes.Repeat([]byte{2}, 32) + payHash := makeFakePayHash(t) err = alicePayStream.Send(&lnrpc.SendRequest{ Dest: carolPubKey, Amt: int64(htlcAmt), @@ -8669,7 +8699,7 @@ func testMultiHopRemoteForceCloseOnChainHtlcTimeout(net *lntest.NetworkHarness, // We'll now send a single HTLC across our multi-hop network. carolPubKey := carol.PubKey[:] - payHash := bytes.Repeat([]byte{2}, 32) + payHash := makeFakePayHash(t) err = alicePayStream.Send(&lnrpc.SendRequest{ Dest: carolPubKey, Amt: int64(htlcAmt),