diff --git a/docs/release-notes/release-notes-0.18.3.md b/docs/release-notes/release-notes-0.18.3.md index 1739c2001..3017b1ffd 100644 --- a/docs/release-notes/release-notes-0.18.3.md +++ b/docs/release-notes/release-notes-0.18.3.md @@ -91,6 +91,12 @@ ## Testing ## Database + +* [Migrate](https://github.com/lightningnetwork/lnd/pull/8855) incorrectly + stored invoice expiry values. This migration only affects users of native SQL + invoice database. Invoices with incorrect expiry values will be updated to + 24-hour expiry, which is the default behavior in LND. + ## Code Health ## Tooling and Documentation diff --git a/go.mod b/go.mod index 2b38c7cc0..a1f55dcc0 100644 --- a/go.mod +++ b/go.mod @@ -112,6 +112,7 @@ require ( github.com/jackc/pgproto3/v2 v2.3.3 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgtype v1.14.0 // indirect + github.com/jackc/pgx/v5 v5.3.1 // indirect github.com/jackc/puddle v1.3.0 // indirect github.com/jonboulle/clockwork v0.2.2 // indirect github.com/json-iterator/go v1.1.11 // indirect diff --git a/go.sum b/go.sum index 97826101b..ba15e2c75 100644 --- a/go.sum +++ b/go.sum @@ -365,6 +365,8 @@ github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQ github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= github.com/jackc/pgx/v4 v4.18.2 h1:xVpYkNR5pk5bMCZGfClbO962UIqVABcAGt7ha1s/FeU= github.com/jackc/pgx/v4 v4.18.2/go.mod h1:Ey4Oru5tH5sB6tV7hDmfWFahwF15Eb7DNXlRKx2CkVw= +github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= +github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= @@ -554,8 +556,8 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rogpeppe/fastuuid v1.2.0 h1:Ppwyp6VYCF1nvBTXL3trRso7mXMlRrw9ooo375wvi2s= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= diff --git a/invoices/invoices_test.go b/invoices/invoices_test.go index 4e30b0ac9..800261355 100644 --- a/invoices/invoices_test.go +++ b/invoices/invoices_test.go @@ -69,7 +69,7 @@ func randInvoice(value lnwire.MilliSatoshi) (*invpkg.Invoice, error) { i := &invpkg.Invoice{ CreationDate: testNow, Terms: invpkg.ContractTerm{ - Expiry: 4000, + Expiry: time.Duration(4000) * time.Second, PaymentPreimage: &pre, PaymentAddr: payAddr, Value: value, diff --git a/invoices/sql_store.go b/invoices/sql_store.go index 21254ddef..4b488715b 100644 --- a/invoices/sql_store.go +++ b/invoices/sql_store.go @@ -233,7 +233,7 @@ func (i *SQLStore) AddInvoice(ctx context.Context, CltvDelta: sqldb.SQLInt32( newInvoice.Terms.FinalCltvDelta, ), - Expiry: int32(newInvoice.Terms.Expiry), + Expiry: int32(newInvoice.Terms.Expiry.Seconds()), // Note: keysend invoices don't have a payment request. PaymentRequest: sqldb.SQLStr(string( newInvoice.PaymentRequest), @@ -1598,6 +1598,8 @@ func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice, cltvDelta = row.CltvDelta.Int32 } + expiry := time.Duration(row.Expiry) * time.Second + invoice := &Invoice{ SettleIndex: uint64(settleIndex), SettleDate: settledAt, @@ -1606,7 +1608,7 @@ func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice, CreationDate: row.CreatedAt.Local(), Terms: ContractTerm{ FinalCltvDelta: cltvDelta, - Expiry: time.Duration(row.Expiry), + Expiry: expiry, PaymentPreimage: preimage, Value: lnwire.MilliSatoshi(row.AmountMsat), PaymentAddr: paymentAddr, diff --git a/sqldb/go.mod b/sqldb/go.mod index 51792d2e0..8fbfeeeea 100644 --- a/sqldb/go.mod +++ b/sqldb/go.mod @@ -7,7 +7,7 @@ require ( github.com/golang-migrate/migrate/v4 v4.17.0 github.com/jackc/pgconn v1.14.3 github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 - github.com/lib/pq v1.10.9 + github.com/jackc/pgx/v5 v5.3.1 github.com/ory/dockertest/v3 v3.10.0 github.com/stretchr/testify v1.9.0 golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 @@ -48,6 +48,7 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/sirupsen/logrus v1.9.2 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect diff --git a/sqldb/go.sum b/sqldb/go.sum index 140eb9996..ffcf7ac56 100644 --- a/sqldb/go.sum +++ b/sqldb/go.sum @@ -83,10 +83,15 @@ github.com/jackc/pgproto3/v2 v2.3.3 h1:1HLSx5H+tXR9pW3in3zaztoEwQYRC9SQaYUHjTSUO github.com/jackc/pgproto3/v2 v2.3.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= +github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= @@ -118,6 +123,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/seccomp/libseccomp-golang v0.9.2-0.20220502022130-f33da4d89646/go.mod h1:JA8cRccbGaA1s33RQf7Y1+q9gHmZX1yB/z9WDN1C6fg= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= @@ -201,8 +208,9 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/sqldb/migrations.go b/sqldb/migrations.go index 6b104f19b..31319cec8 100644 --- a/sqldb/migrations.go +++ b/sqldb/migrations.go @@ -8,16 +8,71 @@ import ( "net/http" "strings" + "github.com/btcsuite/btclog" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" "github.com/golang-migrate/migrate/v4/source/httpfs" ) +// MigrationTarget is a functional option that can be passed to applyMigrations +// to specify a target version to migrate to. +type MigrationTarget func(mig *migrate.Migrate) error + +var ( + // TargetLatest is a MigrationTarget that migrates to the latest + // version available. + TargetLatest = func(mig *migrate.Migrate) error { + return mig.Up() + } + + // TargetVersion is a MigrationTarget that migrates to the given + // version. + TargetVersion = func(version uint) MigrationTarget { + return func(mig *migrate.Migrate) error { + return mig.Migrate(version) + } + } +) + +// migrationLogger is a logger that wraps the passed btclog.Logger so it can be +// used to log migrations. +type migrationLogger struct { + log btclog.Logger +} + +// Printf is like fmt.Printf. We map this to the target logger based on the +// current log level. +func (m *migrationLogger) Printf(format string, v ...interface{}) { + // Trim trailing newlines from the format. + format = strings.TrimRight(format, "\n") + + switch m.log.Level() { + case btclog.LevelTrace: + m.log.Tracef(format, v...) + case btclog.LevelDebug: + m.log.Debugf(format, v...) + case btclog.LevelInfo: + m.log.Infof(format, v...) + case btclog.LevelWarn: + m.log.Warnf(format, v...) + case btclog.LevelError: + m.log.Errorf(format, v...) + case btclog.LevelCritical: + m.log.Criticalf(format, v...) + case btclog.LevelOff: + } +} + +// Verbose should return true when verbose logging output is wanted +func (m *migrationLogger) Verbose() bool { + return m.log.Level() <= btclog.LevelDebug +} + // applyMigrations executes all database migration files found in the given file // system under the given path, using the passed database driver and database // name. func applyMigrations(fs fs.FS, driver database.Driver, path, - dbName string) error { + dbName string, targetVersion MigrationTarget) error { // With the migrate instance open, we'll create a new migration source // using the embedded file system stored in sqlSchemas. The library @@ -37,7 +92,22 @@ func applyMigrations(fs fs.FS, driver database.Driver, path, if err != nil { return err } - err = sqlMigrate.Up() + + migrationVersion, _, err := sqlMigrate.Version() + if err != nil && !errors.Is(err, migrate.ErrNilVersion) { + log.Errorf("Unable to determine current migration version: %v", + err) + + return err + } + + log.Infof("Applying migrations from version=%v", migrationVersion) + + // Apply our local logger to the migration instance. + sqlMigrate.Log = &migrationLogger{log} + + // Execute the migration based on the target given. + err = targetVersion(sqlMigrate) if err != nil && !errors.Is(err, migrate.ErrNoChange) { return err } diff --git a/sqldb/migrations_test.go b/sqldb/migrations_test.go new file mode 100644 index 000000000..cd55e92cb --- /dev/null +++ b/sqldb/migrations_test.go @@ -0,0 +1,154 @@ +package sqldb + +import ( + "context" + "testing" + + "github.com/lightningnetwork/lnd/sqldb/sqlc" + "github.com/stretchr/testify/require" +) + +// makeMigrationTestDB is a type alias for a function that creates a new test +// database and returns the base database and a function that executes selected +// migrations. +type makeMigrationTestDB = func(*testing.T, uint) (*BaseDB, + func(MigrationTarget) error) + +// TestMigrations is a meta test runner that runs all migration tests. +func TestMigrations(t *testing.T) { + sqliteTestDB := func(t *testing.T, version uint) (*BaseDB, + func(MigrationTarget) error) { + + db := NewTestSqliteDBWithVersion(t, version) + + return db.BaseDB, db.ExecuteMigrations + } + + postgresTestDB := func(t *testing.T, version uint) (*BaseDB, + func(MigrationTarget) error) { + + pgFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime) + t.Cleanup(func() { + pgFixture.TearDown(t) + }) + + db := NewTestPostgresDBWithVersion( + t, pgFixture, version, + ) + + return db.BaseDB, db.ExecuteMigrations + } + + tests := []struct { + name string + test func(*testing.T, makeMigrationTestDB) + }{ + { + name: "TestInvoiceExpiryMigration", + test: testInvoiceExpiryMigration, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name+"_SQLite", func(t *testing.T) { + test.test(t, sqliteTestDB) + }) + + t.Run(test.name+"_Postgres", func(t *testing.T) { + test.test(t, postgresTestDB) + }) + + } +} + +// TestInvoiceExpiryMigration tests that the migration from version 3 to 4 +// correctly sets the expiry value of normal invoices to 86400 seconds and +// 2592000 seconds for AMP invoices. +func testInvoiceExpiryMigration(t *testing.T, makeDB makeMigrationTestDB) { + t.Parallel() + ctxb := context.Background() + + // Create a new database that already has the first version of the + // native invoice schema. + db, migrate := makeDB(t, 3) + + // Add a few invoices. For simplicity we reuse the payment hash as the + // payment address and payment request hash instead of setting them to + // NULL (to not run into uniqueness constraints). Note that SQLC + // currently doesn't support nullable blob fields porperly. A workaround + // is in progress: https://github.com/sqlc-dev/sqlc/issues/3149 + + // Add an invoice where is_amp will be set to false. + hash1 := []byte{1, 2, 3} + _, err := db.InsertInvoice(ctxb, sqlc.InsertInvoiceParams{ + Hash: hash1, + PaymentAddr: hash1, + PaymentRequestHash: hash1, + Expiry: -123, + IsAmp: false, + }) + require.NoError(t, err) + + // Add an invoice where is_amp will be set to false. + hash2 := []byte{4, 5, 6} + _, err = db.InsertInvoice(ctxb, sqlc.InsertInvoiceParams{ + Hash: hash2, + PaymentAddr: hash2, + PaymentRequestHash: hash2, + Expiry: -456, + IsAmp: true, + }) + require.NoError(t, err) + + // Now, we'll attempt to execute the migration that will fix the expiry + // values by inserting 86400 seconds for non AMP and 2592000 seconds for + // AMP invoices. + err = migrate(TargetVersion(4)) + + invoices, err := db.FilterInvoices(ctxb, sqlc.FilterInvoicesParams{ + AddIndexGet: SQLInt64(1), + NumLimit: 100, + }) + + const ( + // 1 day in seconds. + expiry = int32(86400) + // 30 days in seconds. + expiryAMP = int32(2592000) + ) + + expected := []sqlc.Invoice{ + { + ID: 1, + Hash: hash1, + PaymentAddr: hash1, + PaymentRequestHash: hash1, + Expiry: expiry, + }, + { + ID: 2, + Hash: hash2, + PaymentAddr: hash2, + PaymentRequestHash: hash2, + Expiry: expiryAMP, + IsAmp: true, + }, + } + + for i := range invoices { + // Override the timestamp location as the sql driver will scan + // the timestamp with no location and we can't create such + // timestamps in Golang using the standard time package. + invoices[i].CreatedAt = invoices[i].CreatedAt.UTC() + + // Override the preimage as depending on the driver it is either + // scanned as nil or an empty byte slice. + require.Len(t, invoices[i].Preimage, 0) + invoices[i].Preimage = nil + } + + require.NoError(t, err) + require.Equal(t, expected, invoices) +} diff --git a/sqldb/postgres.go b/sqldb/postgres.go index e6e88c93b..6a4123ed3 100644 --- a/sqldb/postgres.go +++ b/sqldb/postgres.go @@ -2,13 +2,15 @@ package sqldb import ( "database/sql" + "fmt" "net/url" "path" "strings" "time" - postgres_migrate "github.com/golang-migrate/migrate/v4/database/postgres" + pgx_migrate "github.com/golang-migrate/migrate/v4/database/pgx/v5" _ "github.com/golang-migrate/migrate/v4/source/file" // Read migrations from files. // nolint:lll + _ "github.com/jackc/pgx/v5" "github.com/lightningnetwork/lnd/sqldb/sqlc" ) @@ -19,6 +21,17 @@ var ( // fully executed yet. So this time needs to be chosen correctly to be // longer than the longest expected individual test run time. DefaultPostgresFixtureLifetime = 10 * time.Minute + + // postgresSchemaReplacements is a map of schema strings that need to be + // replaced for postgres. This is needed because we write the schemas to + // work with sqlite primarily but in sqlc's own dialect, and postgres + // has some differences. + postgresSchemaReplacements = map[string]string{ + "BLOB": "BYTEA", + "INTEGER PRIMARY KEY": "SERIAL PRIMARY KEY", + "BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY", + "TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE", + } ) // replacePasswordInDSN takes a DSN string and returns it with the password @@ -79,11 +92,6 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) { } log.Infof("Using SQL database '%s'", sanitizedDSN) - dbName, err := getDatabaseNameFromDSN(cfg.Dsn) - if err != nil { - return nil, err - } - rawDB, err := sql.Open("pgx", cfg.Dsn) if err != nil { return nil, err @@ -98,42 +106,45 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) { rawDB.SetMaxIdleConns(maxConns) rawDB.SetConnMaxLifetime(connIdleLifetime) - if !cfg.SkipMigrations { - // Now that the database is open, populate the database with - // our set of schemas based on our embedded in-memory file - // system. - // - // First, we'll need to open up a new migration instance for - // our current target database: Postgres. - driver, err := postgres_migrate.WithInstance( - rawDB, &postgres_migrate.Config{}, - ) - if err != nil { - return nil, err - } - - postgresFS := newReplacerFS(sqlSchemas, map[string]string{ - "BLOB": "BYTEA", - "INTEGER PRIMARY KEY": "SERIAL PRIMARY KEY", - "BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY", - "TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE", - }) - - err = applyMigrations( - postgresFS, driver, "sqlc/migrations", dbName, - ) - if err != nil { - return nil, err - } - } - queries := sqlc.New(rawDB) - return &PostgresStore{ + s := &PostgresStore{ cfg: cfg, BaseDB: &BaseDB{ DB: rawDB, Queries: queries, }, - }, nil + } + + // Execute migrations unless configured to skip them. + if !cfg.SkipMigrations { + err := s.ExecuteMigrations(TargetLatest) + if err != nil { + return nil, fmt.Errorf("error executing migrations: %w", + err) + } + } + + return s, nil +} + +// ExecuteMigrations runs migrations for the Postgres database, depending on the +// target given, either all migrations or up to a given version. +func (s *PostgresStore) ExecuteMigrations(target MigrationTarget) error { + dbName, err := getDatabaseNameFromDSN(s.cfg.Dsn) + if err != nil { + return err + } + + driver, err := pgx_migrate.WithInstance(s.DB, &pgx_migrate.Config{}) + if err != nil { + return fmt.Errorf("error creating postgres migration: %w", err) + } + + // Populate the database with our set of schemas based on our embedded + // in-memory file system. + postgresFS := newReplacerFS(sqlSchemas, postgresSchemaReplacements) + return applyMigrations( + postgresFS, driver, "sqlc/migrations", dbName, target, + ) } diff --git a/sqldb/postgres_fixture.go b/sqldb/postgres_fixture.go index a0f0893ce..da5769c42 100644 --- a/sqldb/postgres_fixture.go +++ b/sqldb/postgres_fixture.go @@ -13,7 +13,7 @@ import ( "testing" "time" - _ "github.com/lib/pq" // Import the postgres driver. + _ "github.com/jackc/pgx/v5" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" "github.com/stretchr/testify/require" @@ -91,7 +91,7 @@ func NewTestPgFixture(t *testing.T, expiry time.Duration) *TestPgFixture { var testDB *sql.DB err = pool.Retry(func() error { - testDB, err = sql.Open("postgres", databaseURL) + testDB, err = sql.Open("pgx", databaseURL) if err != nil { return err } @@ -124,28 +124,28 @@ func (f *TestPgFixture) TearDown(t *testing.T) { require.NoError(t, err, "Could not purge resource") } +// randomDBName generates a random database name. +func randomDBName(t *testing.T) string { + randBytes := make([]byte, 8) + _, err := rand.Read(randBytes) + require.NoError(t, err) + + return "test_" + hex.EncodeToString(randBytes) +} + // NewTestPostgresDB is a helper function that creates a Postgres database for // testing using the given fixture. func NewTestPostgresDB(t *testing.T, fixture *TestPgFixture) *PostgresStore { t.Helper() - // Create random database name. - randBytes := make([]byte, 8) - _, err := rand.Read(randBytes) - if err != nil { - t.Fatal(err) - } - - dbName := "test_" + hex.EncodeToString(randBytes) + dbName := randomDBName(t) t.Logf("Creating new Postgres DB '%s' for testing", dbName) - _, err = fixture.db.ExecContext( + _, err := fixture.db.ExecContext( context.Background(), "CREATE DATABASE "+dbName, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) cfg := fixture.GetConfig(dbName) store, err := NewPostgresStore(cfg) @@ -153,3 +153,30 @@ func NewTestPostgresDB(t *testing.T, fixture *TestPgFixture) *PostgresStore { return store } + +// NewTestPostgresDBWithVersion is a helper function that creates a Postgres +// database for testing and migrates it to the given version. +func NewTestPostgresDBWithVersion(t *testing.T, fixture *TestPgFixture, + version uint) *PostgresStore { + + t.Helper() + + t.Logf("Creating new Postgres DB for testing, migrating to version %d", + version) + + dbName := randomDBName(t) + _, err := fixture.db.ExecContext( + context.Background(), "CREATE DATABASE "+dbName, + ) + require.NoError(t, err) + + storeCfg := fixture.GetConfig(dbName) + storeCfg.SkipMigrations = true + store, err := NewPostgresStore(storeCfg) + require.NoError(t, err) + + err = store.ExecuteMigrations(TargetVersion(version)) + require.NoError(t, err) + + return store +} diff --git a/sqldb/postgres_test.go b/sqldb/postgres_test.go index c31daf04b..22a29f885 100644 --- a/sqldb/postgres_test.go +++ b/sqldb/postgres_test.go @@ -9,5 +9,21 @@ import ( // NewTestDB is a helper function that creates a Postgres database for testing. func NewTestDB(t *testing.T) *PostgresStore { - return NewTestPostgresDB(t) + pgFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime) + t.Cleanup(func() { + pgFixture.TearDown(t) + }) + + return NewTestPostgresDB(t, pgFixture) +} + +// NewTestDBWithVersion is a helper function that creates a Postgres database +// for testing and migrates it to the given version. +func NewTestDBWithVersion(t *testing.T, version uint) *PostgresStore { + pgFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime) + t.Cleanup(func() { + pgFixture.TearDown(t) + }) + + return NewTestPostgresDBWithVersion(t, pgFixture, version) } diff --git a/sqldb/sqlc/migrations/000004_invoice_expiry_fix.down.sql b/sqldb/sqlc/migrations/000004_invoice_expiry_fix.down.sql new file mode 100644 index 000000000..5a4ff3ab2 --- /dev/null +++ b/sqldb/sqlc/migrations/000004_invoice_expiry_fix.down.sql @@ -0,0 +1,2 @@ +-- Given that all expiries are changed in this migration we won't be able to +-- roll back to the previous values. diff --git a/sqldb/sqlc/migrations/000004_invoice_expiry_fix.up.sql b/sqldb/sqlc/migrations/000004_invoice_expiry_fix.up.sql new file mode 100644 index 000000000..02311f830 --- /dev/null +++ b/sqldb/sqlc/migrations/000004_invoice_expiry_fix.up.sql @@ -0,0 +1,14 @@ +-- Update the expiry for all records in the invoices table. This is needed as +-- previously we stored raw time.Duration values which are 64 bit integers and +-- are used to express duration in nanoseconds however the intent is to store +-- invoice expiry in seconds. + +-- Update the expiry to 86400 seconds (24 hours) for non-AMP invoices. +UPDATE invoices +SET expiry = 86400 +WHERE is_amp = FALSE; + +-- Update the expiry to 2592000 seconds (30 days) for AMP invoices +UPDATE invoices +SET expiry = 2592000 +WHERE is_amp = TRUE; diff --git a/sqldb/sqlite.go b/sqldb/sqlite.go index 705d5cc47..99e55d6ea 100644 --- a/sqldb/sqlite.go +++ b/sqldb/sqlite.go @@ -26,6 +26,16 @@ const ( sqliteTxLockImmediate = "_txlock=immediate" ) +var ( + // sqliteSchemaReplacements is a map of schema strings that need to be + // replaced for sqlite. This is needed because sqlite doesn't directly + // support the BIGINT type for primary keys, so we need to replace it + // with INTEGER. + sqliteSchemaReplacements = map[string]string{ + "BIGINT PRIMARY KEY": "INTEGER PRIMARY KEY", + } +) + // SqliteStore is a database store implementation that uses a sqlite backend. type SqliteStore struct { cfg *SqliteConfig @@ -95,46 +105,44 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) { db.SetMaxOpenConns(defaultMaxConns) db.SetMaxIdleConns(defaultMaxConns) db.SetConnMaxLifetime(connIdleLifetime) - - if !cfg.SkipMigrations { - // Now that the database is open, populate the database with - // our set of schemas based on our embedded in-memory file - // system. - // - // First, we'll need to open up a new migration instance for - // our current target database: sqlite. - driver, err := sqlite_migrate.WithInstance( - db, &sqlite_migrate.Config{}, - ) - if err != nil { - return nil, err - } - - // We use INTEGER PRIMARY KEY for sqlite, because it acts as a - // ROWID alias which is 8 bytes big and also autoincrements. - // It's important to use the ROWID as a primary key because the - // key look ups are almost twice as fast - sqliteFS := newReplacerFS(sqlSchemas, map[string]string{ - "BIGINT PRIMARY KEY": "INTEGER PRIMARY KEY", - }) - - err = applyMigrations( - sqliteFS, driver, "sqlc/migrations", "sqlc", - ) - if err != nil { - return nil, err - } - } - queries := sqlc.New(db) - return &SqliteStore{ + s := &SqliteStore{ cfg: cfg, BaseDB: &BaseDB{ DB: db, Queries: queries, }, - }, nil + } + + // Execute migrations unless configured to skip them. + if !cfg.SkipMigrations { + if err := s.ExecuteMigrations(TargetLatest); err != nil { + return nil, fmt.Errorf("error executing migrations: "+ + "%w", err) + + } + } + + return s, nil +} + +// ExecuteMigrations runs migrations for the sqlite database, depending on the +// target given, either all migrations or up to a given version. +func (s *SqliteStore) ExecuteMigrations(target MigrationTarget) error { + driver, err := sqlite_migrate.WithInstance( + s.DB, &sqlite_migrate.Config{}, + ) + if err != nil { + return fmt.Errorf("error creating sqlite migration: %w", err) + } + + // Populate the database with our set of schemas based on our embedded + // in-memory file system. + sqliteFS := newReplacerFS(sqlSchemas, sqliteSchemaReplacements) + return applyMigrations( + sqliteFS, driver, "sqlc/migrations", "sqlite", target, + ) } // NewTestSqliteDB is a helper function that creates an SQLite database for @@ -158,3 +166,29 @@ func NewTestSqliteDB(t *testing.T) *SqliteStore { return sqlDB } + +// NewTestSqliteDBWithVersion is a helper function that creates an SQLite +// database for testing and migrates it to the given version. +func NewTestSqliteDBWithVersion(t *testing.T, version uint) *SqliteStore { + t.Helper() + + t.Logf("Creating new SQLite DB for testing, migrating to version %d", + version) + + // TODO(roasbeef): if we pass :memory: for the file name, then we get + // an in mem version to speed up tests + dbFileName := filepath.Join(t.TempDir(), "tmp.db") + sqlDB, err := NewSqliteStore(&SqliteConfig{ + SkipMigrations: true, + }, dbFileName) + require.NoError(t, err) + + err = sqlDB.ExecuteMigrations(TargetVersion(version)) + require.NoError(t, err) + + t.Cleanup(func() { + require.NoError(t, sqlDB.DB.Close()) + }) + + return sqlDB +} diff --git a/sqldb/sqlite_test.go b/sqldb/sqlite_test.go index 58c94b351..9dfb875ea 100644 --- a/sqldb/sqlite_test.go +++ b/sqldb/sqlite_test.go @@ -11,3 +11,9 @@ import ( func NewTestDB(t *testing.T) *SqliteStore { return NewTestSqliteDB(t) } + +// NewTestDBWithVersion is a helper function that creates an SQLite database +// for testing and migrates it to the given version. +func NewTestDBWithVersion(t *testing.T, version uint) *SqliteStore { + return NewTestSqliteDBWithVersion(t, version) +}