diff --git a/invoices/sql_migration.go b/invoices/sql_migration.go index 2bdb14048..47ee3e329 100644 --- a/invoices/sql_migration.go +++ b/invoices/sql_migration.go @@ -5,9 +5,13 @@ import ( "context" "encoding/binary" "fmt" + "strconv" + "time" + "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/sqldb" "github.com/lightningnetwork/lnd/sqldb/sqlc" ) @@ -126,3 +130,272 @@ func createInvoiceHashIndex(ctx context.Context, db kvdb.Backend, }) }, func() {}) } + +// toInsertMigratedInvoiceParams creates the parameters for inserting a migrated +// invoice into the SQL database. The parameters are derived from the original +// invoice insert parameters. +func toInsertMigratedInvoiceParams( + params sqlc.InsertInvoiceParams) sqlc.InsertMigratedInvoiceParams { + + return sqlc.InsertMigratedInvoiceParams{ + Hash: params.Hash, + Preimage: params.Preimage, + Memo: params.Memo, + AmountMsat: params.AmountMsat, + CltvDelta: params.CltvDelta, + Expiry: params.Expiry, + PaymentAddr: params.PaymentAddr, + PaymentRequest: params.PaymentRequest, + PaymentRequestHash: params.PaymentRequestHash, + State: params.State, + AmountPaidMsat: params.AmountPaidMsat, + IsAmp: params.IsAmp, + IsHodl: params.IsHodl, + IsKeysend: params.IsKeysend, + CreatedAt: params.CreatedAt, + } +} + +// MigrateSingleInvoice migrates a single invoice to the new SQL schema. Note +// that perfect equality between the old and new schemas is not achievable, as +// the invoice's add index cannot be mapped directly to its ID due to SQL’s +// auto-incrementing primary key. The ID returned from the insert will instead +// serve as the add index in the new schema. +func MigrateSingleInvoice(ctx context.Context, tx SQLInvoiceQueries, + invoice *Invoice, paymentHash lntypes.Hash) error { + + insertInvoiceParams, err := makeInsertInvoiceParams( + invoice, paymentHash, + ) + if err != nil { + return err + } + + // Convert the insert invoice parameters to the migrated invoice insert + // parameters. + insertMigratedInvoiceParams := toInsertMigratedInvoiceParams( + insertInvoiceParams, + ) + + // If the invoice is settled, we'll also set the timestamp and the index + // at which it was settled. + if invoice.State == ContractSettled { + if invoice.SettleIndex == 0 { + return fmt.Errorf("settled invoice %s missing settle "+ + "index", paymentHash) + } + + if invoice.SettleDate.IsZero() { + return fmt.Errorf("settled invoice %s missing settle "+ + "date", paymentHash) + } + + insertMigratedInvoiceParams.SettleIndex = sqldb.SQLInt64( + invoice.SettleIndex, + ) + insertMigratedInvoiceParams.SettledAt = sqldb.SQLTime( + invoice.SettleDate.UTC(), + ) + } + + // First we need to insert the invoice itself so we can use the "add + // index" which in this case is the auto incrementing primary key that + // is returned from the insert. + invoiceID, err := tx.InsertMigratedInvoice( + ctx, insertMigratedInvoiceParams, + ) + if err != nil { + return fmt.Errorf("unable to insert invoice: %w", err) + } + + // Insert the invoice's features. + for feature := range invoice.Terms.Features.Features() { + params := sqlc.InsertInvoiceFeatureParams{ + InvoiceID: invoiceID, + Feature: int32(feature), + } + + err := tx.InsertInvoiceFeature(ctx, params) + if err != nil { + return fmt.Errorf("unable to insert invoice "+ + "feature(%v): %w", feature, err) + } + } + + sqlHtlcIDs := make(map[models.CircuitKey]int64) + + // Now insert the HTLCs of the invoice. We'll also keep track of the SQL + // ID of each HTLC so we can use it when inserting the AMP sub invoices. + for circuitKey, htlc := range invoice.Htlcs { + htlcParams := sqlc.InsertInvoiceHTLCParams{ + HtlcID: int64(circuitKey.HtlcID), + ChanID: strconv.FormatUint( + circuitKey.ChanID.ToUint64(), 10, + ), + AmountMsat: int64(htlc.Amt), + AcceptHeight: int32(htlc.AcceptHeight), + AcceptTime: htlc.AcceptTime.UTC(), + ExpiryHeight: int32(htlc.Expiry), + State: int16(htlc.State), + InvoiceID: invoiceID, + } + + // Leave the MPP amount as NULL if the MPP total amount is zero. + if htlc.MppTotalAmt != 0 { + htlcParams.TotalMppMsat = sqldb.SQLInt64( + int64(htlc.MppTotalAmt), + ) + } + + // Leave the resolve time as NULL if the HTLC is not resolved. + if !htlc.ResolveTime.IsZero() { + htlcParams.ResolveTime = sqldb.SQLTime( + htlc.ResolveTime.UTC(), + ) + } + + sqlID, err := tx.InsertInvoiceHTLC(ctx, htlcParams) + if err != nil { + return fmt.Errorf("unable to insert invoice htlc: %w", + err) + } + + sqlHtlcIDs[circuitKey] = sqlID + + // Store custom records. + for key, value := range htlc.CustomRecords { + err = tx.InsertInvoiceHTLCCustomRecord( + ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{ + Key: int64(key), + Value: value, + HtlcID: sqlID, + }, + ) + if err != nil { + return err + } + } + } + + if !invoice.IsAMP() { + return nil + } + + for setID, ampState := range invoice.AMPState { + // Find the earliest HTLC of the AMP invoice, which will + // be used as the creation date of this sub invoice. + var createdAt time.Time + for circuitKey := range ampState.InvoiceKeys { + htlc := invoice.Htlcs[circuitKey] + if createdAt.IsZero() { + createdAt = htlc.AcceptTime.UTC() + continue + } + + if createdAt.After(htlc.AcceptTime) { + createdAt = htlc.AcceptTime.UTC() + } + } + + params := sqlc.InsertAMPSubInvoiceParams{ + SetID: setID[:], + State: int16(ampState.State), + CreatedAt: createdAt, + InvoiceID: invoiceID, + } + + if ampState.SettleIndex != 0 { + if ampState.SettleDate.IsZero() { + return fmt.Errorf("settled AMP sub invoice %x "+ + "missing settle date", setID) + } + + params.SettledAt = sqldb.SQLTime( + ampState.SettleDate.UTC(), + ) + + params.SettleIndex = sqldb.SQLInt64( + ampState.SettleIndex, + ) + } + + err := tx.InsertAMPSubInvoice(ctx, params) + if err != nil { + return fmt.Errorf("unable to insert AMP sub invoice: "+ + "%w", err) + } + + // Now we can add the AMP HTLCs to the database. + for circuitKey := range ampState.InvoiceKeys { + htlc := invoice.Htlcs[circuitKey] + rootShare := htlc.AMP.Record.RootShare() + + sqlHtlcID, ok := sqlHtlcIDs[circuitKey] + if !ok { + return fmt.Errorf("missing htlc for AMP htlc: "+ + "%v", circuitKey) + } + + params := sqlc.InsertAMPSubInvoiceHTLCParams{ + InvoiceID: invoiceID, + SetID: setID[:], + HtlcID: sqlHtlcID, + RootShare: rootShare[:], + ChildIndex: int64(htlc.AMP.Record.ChildIndex()), + Hash: htlc.AMP.Hash[:], + } + + if htlc.AMP.Preimage != nil { + params.Preimage = htlc.AMP.Preimage[:] + } + + err = tx.InsertAMPSubInvoiceHTLC(ctx, params) + if err != nil { + return fmt.Errorf("unable to insert AMP sub "+ + "invoice: %w", err) + } + } + } + + return nil +} + +// OverrideInvoiceTimeZone overrides the time zone of the invoice to the local +// time zone and chops off the nanosecond part for comparison. This is needed +// because KV database stores times as-is which as an unwanted side effect would +// fail migration due to time comparison expecting both the original and +// migrated invoices to be in the same local time zone and in microsecond +// precision. Note that PostgreSQL stores times in microsecond precision while +// SQLite can store times in nanosecond precision if using TEXT storage class. +func OverrideInvoiceTimeZone(invoice *Invoice) { + fixTime := func(t time.Time) time.Time { + return t.In(time.Local).Truncate(time.Microsecond) + } + + invoice.CreationDate = fixTime(invoice.CreationDate) + + if !invoice.SettleDate.IsZero() { + invoice.SettleDate = fixTime(invoice.SettleDate) + } + + if invoice.IsAMP() { + for setID, ampState := range invoice.AMPState { + if ampState.SettleDate.IsZero() { + continue + } + + ampState.SettleDate = fixTime(ampState.SettleDate) + invoice.AMPState[setID] = ampState + } + } + + for _, htlc := range invoice.Htlcs { + if !htlc.AcceptTime.IsZero() { + htlc.AcceptTime = fixTime(htlc.AcceptTime) + } + + if !htlc.ResolveTime.IsZero() { + htlc.ResolveTime = fixTime(htlc.ResolveTime) + } + } +} diff --git a/invoices/sql_migration_test.go b/invoices/sql_migration_test.go new file mode 100644 index 000000000..179097f48 --- /dev/null +++ b/invoices/sql_migration_test.go @@ -0,0 +1,421 @@ +package invoices + +import ( + "context" + crand "crypto/rand" + "database/sql" + "math/rand" + "sync/atomic" + "testing" + "time" + + "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/graph/db/models" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" + "github.com/lightningnetwork/lnd/sqldb" + "github.com/stretchr/testify/require" + "pgregory.net/rapid" +) + +var ( + // testHtlcIDSequence is a global counter for generating unique HTLC + // IDs. + testHtlcIDSequence uint64 +) + +// randomString generates a random string of a given length using rapid. +func randomStringRapid(t *rapid.T, length int) string { + // Define the character set for the string. + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" //nolint:ll + + // Generate a string by selecting random characters from the charset. + runes := make([]rune, length) + for i := range runes { + // Draw a random index and use it to select a character from the + // charset. + index := rapid.IntRange(0, len(charset)-1).Draw(t, "charIndex") + runes[i] = rune(charset[index]) + } + + return string(runes) +} + +// randTimeBetween generates a random time between min and max. +func randTimeBetween(min, max time.Time) time.Time { + var timeZones = []*time.Location{ + time.UTC, + time.FixedZone("EST", -5*3600), + time.FixedZone("MST", -7*3600), + time.FixedZone("PST", -8*3600), + time.FixedZone("CEST", 2*3600), + } + + // Ensure max is after min + if max.Before(min) { + min, max = max, min + } + + // Calculate the range in nanoseconds + duration := max.Sub(min) + randDuration := time.Duration(rand.Int63n(duration.Nanoseconds())) + + // Generate the random time + randomTime := min.Add(randDuration) + + // Assign a random time zone + randomTimeZone := timeZones[rand.Intn(len(timeZones))] + + // Return the time in the random time zone + return randomTime.In(randomTimeZone) +} + +// randTime generates a random time between 2009 and 2140. +func randTime() time.Time { + min := time.Date(2009, 1, 3, 0, 0, 0, 0, time.UTC) + max := time.Date(2140, 1, 1, 0, 0, 0, 1000, time.UTC) + + return randTimeBetween(min, max) +} + +func randInvoiceTime(invoice *Invoice) time.Time { + return randTimeBetween( + invoice.CreationDate, + invoice.CreationDate.Add(invoice.Terms.Expiry), + ) +} + +// randHTLCRapid generates a random HTLC for an invoice using rapid to randomize +// its parameters. +func randHTLCRapid(t *rapid.T, invoice *Invoice, amt lnwire.MilliSatoshi) ( + models.CircuitKey, *InvoiceHTLC) { + + htlc := &InvoiceHTLC{ + Amt: amt, + AcceptHeight: rapid.Uint32Range(1, 999).Draw(t, "AcceptHeight"), + AcceptTime: randInvoiceTime(invoice), + Expiry: rapid.Uint32Range(1, 999).Draw(t, "Expiry"), + } + + // Set MPP total amount if MPP feature is enabled in the invoice. + if invoice.Terms.Features.HasFeature(lnwire.MPPRequired) { + htlc.MppTotalAmt = invoice.Terms.Value + } + + // Set the HTLC state and resolve time based on the invoice state. + switch invoice.State { + case ContractSettled: + htlc.State = HtlcStateSettled + htlc.ResolveTime = randInvoiceTime(invoice) + + case ContractCanceled: + htlc.State = HtlcStateCanceled + htlc.ResolveTime = randInvoiceTime(invoice) + + case ContractAccepted: + htlc.State = HtlcStateAccepted + } + + // Add randomized custom records to the HTLC. + htlc.CustomRecords = make(record.CustomSet) + numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords") + for i := 0; i < numRecords; i++ { + key := rapid.Uint64Range( + record.CustomTypeStart, 1000+record.CustomTypeStart, + ).Draw(t, "customRecordKey") + value := []byte(randomStringRapid(t, 10)) + htlc.CustomRecords[key] = value + } + + // Generate a unique HTLC ID and assign it to a channel ID. + htlcID := atomic.AddUint64(&testHtlcIDSequence, 1) + randChanID := lnwire.NewShortChanIDFromInt(htlcID % 5) + + circuitKey := models.CircuitKey{ + ChanID: randChanID, + HtlcID: htlcID, + } + + return circuitKey, htlc +} + +// generateInvoiceHTLCsRapid generates all HTLCs for an invoice, including AMP +// HTLCs if applicable, using rapid for randomization of HTLC count and +// distribution. +func generateInvoiceHTLCsRapid(t *rapid.T, invoice *Invoice) { + mpp := invoice.Terms.Features.HasFeature(lnwire.MPPRequired) + + // Use rapid to determine the number of HTLCs based on invoice state and + // MPP feature. + numHTLCs := 1 + if invoice.State == ContractOpen { + numHTLCs = 0 + } else if mpp { + numHTLCs = rapid.IntRange(1, 10).Draw(t, "numHTLCs") + } + + total := invoice.Terms.Value + + // Distribute the total amount across the HTLCs, adding any remainder to + // the last HTLC. + if numHTLCs > 0 { + amt := total / lnwire.MilliSatoshi(numHTLCs) + remainder := total - amt*lnwire.MilliSatoshi(numHTLCs) + + for i := 0; i < numHTLCs; i++ { + if i == numHTLCs-1 { + // Add remainder to the last HTLC. + amt += remainder + } + + // Generate an HTLC with a random circuit key and add it + // to the invoice. + circuitKey, htlc := randHTLCRapid(t, invoice, amt) + invoice.Htlcs[circuitKey] = htlc + } + } +} + +// generateAMPHtlcsRapid generates AMP HTLCs for an invoice using rapid to +// randomize various parameters of the HTLCs in the AMP set. +func generateAMPHtlcsRapid(t *rapid.T, invoice *Invoice) { + // Randomly determine the number of AMP sets (1 to 5). + numSetIDs := rapid.IntRange(1, 5).Draw(t, "numSetIDs") + settledIdx := uint64(1) + + for i := 0; i < numSetIDs; i++ { + var setID SetID + _, err := crand.Read(setID[:]) + require.NoError(t, err) + + // Determine the number of HTLCs in this set (1 to 5). + numHTLCs := rapid.IntRange(1, 5).Draw(t, "numHTLCs") + total := invoice.Terms.Value + invoiceKeys := make(map[CircuitKey]struct{}) + + // Calculate the amount per HTLC and account for remainder in + // the final HTLC. + amt := total / lnwire.MilliSatoshi(numHTLCs) + remainder := total - amt*lnwire.MilliSatoshi(numHTLCs) + + var htlcState HtlcState + for j := 0; j < numHTLCs; j++ { + if j == numHTLCs-1 { + amt += remainder + } + + // Generate HTLC with randomized parameters. + circuitKey, htlc := randHTLCRapid(t, invoice, amt) + htlcState = htlc.State + + var ( + rootShare, hash [32]byte + preimage lntypes.Preimage + ) + + // Randomize AMP data fields. + _, err := crand.Read(rootShare[:]) + require.NoError(t, err) + _, err = crand.Read(hash[:]) + require.NoError(t, err) + _, err = crand.Read(preimage[:]) + require.NoError(t, err) + + record := record.NewAMP(rootShare, setID, uint32(j)) + + htlc.AMP = &InvoiceHtlcAMPData{ + Record: *record, + Hash: hash, + Preimage: &preimage, + } + + invoice.Htlcs[circuitKey] = htlc + invoiceKeys[circuitKey] = struct{}{} + } + + ampState := InvoiceStateAMP{ + State: htlcState, + InvoiceKeys: invoiceKeys, + } + if htlcState == HtlcStateSettled { + ampState.SettleIndex = settledIdx + ampState.SettleDate = randInvoiceTime(invoice) + settledIdx++ + } + + // Set the total amount paid if the AMP set is not canceled. + if htlcState != HtlcStateCanceled { + ampState.AmtPaid = invoice.Terms.Value + } + + invoice.AMPState[setID] = ampState + } +} + +// TestMigrateSingleInvoiceRapid tests the migration of single invoices with +// random data variations using rapid. This test generates a random invoice +// configuration and ensures successful migration. +// +// NOTE: This test may need to be changed if the Invoice or any of the related +// types are modified. +func TestMigrateSingleInvoiceRapid(t *testing.T) { + // Create a shared Postgres instance for efficient testing. + pgFixture := sqldb.NewTestPgFixture( + t, sqldb.DefaultPostgresFixtureLifetime, + ) + t.Cleanup(func() { + pgFixture.TearDown(t) + }) + + makeSQLDB := func(t *testing.T, sqlite bool) *SQLStore { + var db *sqldb.BaseDB + if sqlite { + db = sqldb.NewTestSqliteDB(t).BaseDB + } else { + db = sqldb.NewTestPostgresDB(t, pgFixture).BaseDB + } + + executor := sqldb.NewTransactionExecutor( + db, func(tx *sql.Tx) SQLInvoiceQueries { + return db.WithTx(tx) + }, + ) + + testClock := clock.NewTestClock(time.Unix(1, 0)) + + return NewSQLStore(executor, testClock) + } + + // Define property-based test using rapid. + rapid.Check(t, func(rt *rapid.T) { + // Randomized feature flags for MPP and AMP. + mpp := rapid.Bool().Draw(rt, "mpp") + amp := rapid.Bool().Draw(rt, "amp") + + for _, sqlite := range []bool{true, false} { + store := makeSQLDB(t, sqlite) + testMigrateSingleInvoiceRapid(rt, store, mpp, amp) + } + }) +} + +// testMigrateSingleInvoiceRapid is the primary function for the migration of a +// single invoice with random data in a rapid-based test setup. +func testMigrateSingleInvoiceRapid(t *rapid.T, store *SQLStore, mpp bool, + amp bool) { + + ctxb := context.Background() + invoices := make(map[lntypes.Hash]*Invoice) + + for i := 0; i < 100; i++ { + invoice := generateTestInvoiceRapid(t, mpp, amp) + var hash lntypes.Hash + _, err := crand.Read(hash[:]) + require.NoError(t, err) + + invoices[hash] = invoice + } + + var ops SQLInvoiceQueriesTxOptions + err := store.db.ExecTx(ctxb, &ops, func(tx SQLInvoiceQueries) error { + for hash, invoice := range invoices { + err := MigrateSingleInvoice(ctxb, tx, invoice, hash) + require.NoError(t, err) + } + + return nil + }, func() {}) + require.NoError(t, err) + + // Fetch and compare each migrated invoice from the store with the + // original. + for hash, invoice := range invoices { + sqlInvoice, err := store.LookupInvoice( + ctxb, InvoiceRefByHash(hash), + ) + require.NoError(t, err) + + invoice.AddIndex = sqlInvoice.AddIndex + + OverrideInvoiceTimeZone(invoice) + OverrideInvoiceTimeZone(&sqlInvoice) + + require.Equal(t, *invoice, sqlInvoice) + } +} + +// generateTestInvoiceRapid generates a random invoice with variations based on +// mpp and amp flags. +func generateTestInvoiceRapid(t *rapid.T, mpp bool, amp bool) *Invoice { + var preimage lntypes.Preimage + _, err := crand.Read(preimage[:]) + require.NoError(t, err) + + terms := ContractTerm{ + FinalCltvDelta: rapid.Int32Range(1, 1000).Draw( + t, "FinalCltvDelta", + ), + Expiry: time.Duration( + rapid.IntRange(1, 4444).Draw(t, "Expiry"), + ) * time.Minute, + PaymentPreimage: &preimage, + Value: lnwire.MilliSatoshi( + rapid.Int64Range(1, 9999999).Draw(t, "Value"), + ), + PaymentAddr: [32]byte{}, + Features: lnwire.EmptyFeatureVector(), + } + + if amp { + terms.Features.Set(lnwire.AMPRequired) + } else if mpp { + terms.Features.Set(lnwire.MPPRequired) + } + + created := randTime() + + const maxContractState = 3 + state := ContractState( + rapid.IntRange(0, maxContractState).Draw(t, "ContractState"), + ) + var ( + settled time.Time + settleIndex uint64 + ) + if state == ContractSettled { + settled = randTimeBetween(created, created.Add(terms.Expiry)) + settleIndex = rapid.Uint64Range(1, 999).Draw(t, "SettleIndex") + } + + invoice := &Invoice{ + Memo: []byte(randomStringRapid(t, 10)), + PaymentRequest: []byte( + randomStringRapid(t, MaxPaymentRequestSize), + ), + CreationDate: created, + SettleDate: settled, + Terms: terms, + AddIndex: 0, + SettleIndex: settleIndex, + State: state, + AMPState: make(map[SetID]InvoiceStateAMP), + HodlInvoice: rapid.Bool().Draw(t, "HodlInvoice"), + } + + invoice.Htlcs = make(map[models.CircuitKey]*InvoiceHTLC) + + if invoice.IsAMP() { + generateAMPHtlcsRapid(t, invoice) + } else { + generateInvoiceHTLCsRapid(t, invoice) + } + + for _, htlc := range invoice.Htlcs { + if htlc.State == HtlcStateSettled { + invoice.AmtPaid += htlc.Amt + } + } + + return invoice +} diff --git a/invoices/sql_store.go b/invoices/sql_store.go index 839b19a54..c9ffcc44c 100644 --- a/invoices/sql_store.go +++ b/invoices/sql_store.go @@ -32,6 +32,10 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64, error) + // TODO(bhandras): remove this once migrations have been separated out. + InsertMigratedInvoice(ctx context.Context, + arg sqlc.InsertMigratedInvoiceParams) (int64, error) + InsertInvoiceFeature(ctx context.Context, arg sqlc.InsertInvoiceFeatureParams) error