Merge pull request #8855 from bhandras/invoice-expiry-migration

sqldb+invoices: fix incorrectly stored invoice expiries when using native SQL
This commit is contained in:
Oliver Gugger 2024-07-09 02:06:28 -06:00 committed by GitHub
commit a9655357ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 448 additions and 94 deletions

View File

@ -91,6 +91,12 @@
## Testing
## Database
* [Migrate](https://github.com/lightningnetwork/lnd/pull/8855) incorrectly
stored invoice expiry values. This migration only affects users of native SQL
invoice database. Invoices with incorrect expiry values will be updated to
24-hour expiry, which is the default behavior in LND.
## Code Health
## Tooling and Documentation

1
go.mod
View File

@ -112,6 +112,7 @@ require (
github.com/jackc/pgproto3/v2 v2.3.3 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgtype v1.14.0 // indirect
github.com/jackc/pgx/v5 v5.3.1 // indirect
github.com/jackc/puddle v1.3.0 // indirect
github.com/jonboulle/clockwork v0.2.2 // indirect
github.com/json-iterator/go v1.1.11 // indirect

6
go.sum
View File

@ -365,6 +365,8 @@ github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQ
github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs=
github.com/jackc/pgx/v4 v4.18.2 h1:xVpYkNR5pk5bMCZGfClbO962UIqVABcAGt7ha1s/FeU=
github.com/jackc/pgx/v4 v4.18.2/go.mod h1:Ey4Oru5tH5sB6tV7hDmfWFahwF15Eb7DNXlRKx2CkVw=
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
@ -554,8 +556,8 @@ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ
github.com/rogpeppe/fastuuid v1.2.0 h1:Ppwyp6VYCF1nvBTXL3trRso7mXMlRrw9ooo375wvi2s=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=

View File

@ -69,7 +69,7 @@ func randInvoice(value lnwire.MilliSatoshi) (*invpkg.Invoice, error) {
i := &invpkg.Invoice{
CreationDate: testNow,
Terms: invpkg.ContractTerm{
Expiry: 4000,
Expiry: time.Duration(4000) * time.Second,
PaymentPreimage: &pre,
PaymentAddr: payAddr,
Value: value,

View File

@ -233,7 +233,7 @@ func (i *SQLStore) AddInvoice(ctx context.Context,
CltvDelta: sqldb.SQLInt32(
newInvoice.Terms.FinalCltvDelta,
),
Expiry: int32(newInvoice.Terms.Expiry),
Expiry: int32(newInvoice.Terms.Expiry.Seconds()),
// Note: keysend invoices don't have a payment request.
PaymentRequest: sqldb.SQLStr(string(
newInvoice.PaymentRequest),
@ -1598,6 +1598,8 @@ func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice,
cltvDelta = row.CltvDelta.Int32
}
expiry := time.Duration(row.Expiry) * time.Second
invoice := &Invoice{
SettleIndex: uint64(settleIndex),
SettleDate: settledAt,
@ -1606,7 +1608,7 @@ func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice,
CreationDate: row.CreatedAt.Local(),
Terms: ContractTerm{
FinalCltvDelta: cltvDelta,
Expiry: time.Duration(row.Expiry),
Expiry: expiry,
PaymentPreimage: preimage,
Value: lnwire.MilliSatoshi(row.AmountMsat),
PaymentAddr: paymentAddr,

View File

@ -7,7 +7,7 @@ require (
github.com/golang-migrate/migrate/v4 v4.17.0
github.com/jackc/pgconn v1.14.3
github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438
github.com/lib/pq v1.10.9
github.com/jackc/pgx/v5 v5.3.1
github.com/ory/dockertest/v3 v3.10.0
github.com/stretchr/testify v1.9.0
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8
@ -48,6 +48,7 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
github.com/sirupsen/logrus v1.9.2 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect

View File

@ -83,10 +83,15 @@ github.com/jackc/pgproto3/v2 v2.3.3 h1:1HLSx5H+tXR9pW3in3zaztoEwQYRC9SQaYUHjTSUO
github.com/jackc/pgproto3/v2 v2.3.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
@ -118,6 +123,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/seccomp/libseccomp-golang v0.9.2-0.20220502022130-f33da4d89646/go.mod h1:JA8cRccbGaA1s33RQf7Y1+q9gHmZX1yB/z9WDN1C6fg=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
@ -201,8 +208,9 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU=
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View File

@ -8,16 +8,71 @@ import (
"net/http"
"strings"
"github.com/btcsuite/btclog"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/source/httpfs"
)
// MigrationTarget is a functional option that can be passed to applyMigrations
// to specify a target version to migrate to.
type MigrationTarget func(mig *migrate.Migrate) error
var (
// TargetLatest is a MigrationTarget that migrates to the latest
// version available.
TargetLatest = func(mig *migrate.Migrate) error {
return mig.Up()
}
// TargetVersion is a MigrationTarget that migrates to the given
// version.
TargetVersion = func(version uint) MigrationTarget {
return func(mig *migrate.Migrate) error {
return mig.Migrate(version)
}
}
)
// migrationLogger is a logger that wraps the passed btclog.Logger so it can be
// used to log migrations.
type migrationLogger struct {
log btclog.Logger
}
// Printf is like fmt.Printf. We map this to the target logger based on the
// current log level.
func (m *migrationLogger) Printf(format string, v ...interface{}) {
// Trim trailing newlines from the format.
format = strings.TrimRight(format, "\n")
switch m.log.Level() {
case btclog.LevelTrace:
m.log.Tracef(format, v...)
case btclog.LevelDebug:
m.log.Debugf(format, v...)
case btclog.LevelInfo:
m.log.Infof(format, v...)
case btclog.LevelWarn:
m.log.Warnf(format, v...)
case btclog.LevelError:
m.log.Errorf(format, v...)
case btclog.LevelCritical:
m.log.Criticalf(format, v...)
case btclog.LevelOff:
}
}
// Verbose should return true when verbose logging output is wanted
func (m *migrationLogger) Verbose() bool {
return m.log.Level() <= btclog.LevelDebug
}
// applyMigrations executes all database migration files found in the given file
// system under the given path, using the passed database driver and database
// name.
func applyMigrations(fs fs.FS, driver database.Driver, path,
dbName string) error {
dbName string, targetVersion MigrationTarget) error {
// With the migrate instance open, we'll create a new migration source
// using the embedded file system stored in sqlSchemas. The library
@ -37,7 +92,22 @@ func applyMigrations(fs fs.FS, driver database.Driver, path,
if err != nil {
return err
}
err = sqlMigrate.Up()
migrationVersion, _, err := sqlMigrate.Version()
if err != nil && !errors.Is(err, migrate.ErrNilVersion) {
log.Errorf("Unable to determine current migration version: %v",
err)
return err
}
log.Infof("Applying migrations from version=%v", migrationVersion)
// Apply our local logger to the migration instance.
sqlMigrate.Log = &migrationLogger{log}
// Execute the migration based on the target given.
err = targetVersion(sqlMigrate)
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
return err
}

154
sqldb/migrations_test.go Normal file
View File

@ -0,0 +1,154 @@
package sqldb
import (
"context"
"testing"
"github.com/lightningnetwork/lnd/sqldb/sqlc"
"github.com/stretchr/testify/require"
)
// makeMigrationTestDB is a type alias for a function that creates a new test
// database and returns the base database and a function that executes selected
// migrations.
type makeMigrationTestDB = func(*testing.T, uint) (*BaseDB,
func(MigrationTarget) error)
// TestMigrations is a meta test runner that runs all migration tests.
func TestMigrations(t *testing.T) {
sqliteTestDB := func(t *testing.T, version uint) (*BaseDB,
func(MigrationTarget) error) {
db := NewTestSqliteDBWithVersion(t, version)
return db.BaseDB, db.ExecuteMigrations
}
postgresTestDB := func(t *testing.T, version uint) (*BaseDB,
func(MigrationTarget) error) {
pgFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime)
t.Cleanup(func() {
pgFixture.TearDown(t)
})
db := NewTestPostgresDBWithVersion(
t, pgFixture, version,
)
return db.BaseDB, db.ExecuteMigrations
}
tests := []struct {
name string
test func(*testing.T, makeMigrationTestDB)
}{
{
name: "TestInvoiceExpiryMigration",
test: testInvoiceExpiryMigration,
},
}
for _, test := range tests {
test := test
t.Run(test.name+"_SQLite", func(t *testing.T) {
test.test(t, sqliteTestDB)
})
t.Run(test.name+"_Postgres", func(t *testing.T) {
test.test(t, postgresTestDB)
})
}
}
// TestInvoiceExpiryMigration tests that the migration from version 3 to 4
// correctly sets the expiry value of normal invoices to 86400 seconds and
// 2592000 seconds for AMP invoices.
func testInvoiceExpiryMigration(t *testing.T, makeDB makeMigrationTestDB) {
t.Parallel()
ctxb := context.Background()
// Create a new database that already has the first version of the
// native invoice schema.
db, migrate := makeDB(t, 3)
// Add a few invoices. For simplicity we reuse the payment hash as the
// payment address and payment request hash instead of setting them to
// NULL (to not run into uniqueness constraints). Note that SQLC
// currently doesn't support nullable blob fields porperly. A workaround
// is in progress: https://github.com/sqlc-dev/sqlc/issues/3149
// Add an invoice where is_amp will be set to false.
hash1 := []byte{1, 2, 3}
_, err := db.InsertInvoice(ctxb, sqlc.InsertInvoiceParams{
Hash: hash1,
PaymentAddr: hash1,
PaymentRequestHash: hash1,
Expiry: -123,
IsAmp: false,
})
require.NoError(t, err)
// Add an invoice where is_amp will be set to false.
hash2 := []byte{4, 5, 6}
_, err = db.InsertInvoice(ctxb, sqlc.InsertInvoiceParams{
Hash: hash2,
PaymentAddr: hash2,
PaymentRequestHash: hash2,
Expiry: -456,
IsAmp: true,
})
require.NoError(t, err)
// Now, we'll attempt to execute the migration that will fix the expiry
// values by inserting 86400 seconds for non AMP and 2592000 seconds for
// AMP invoices.
err = migrate(TargetVersion(4))
invoices, err := db.FilterInvoices(ctxb, sqlc.FilterInvoicesParams{
AddIndexGet: SQLInt64(1),
NumLimit: 100,
})
const (
// 1 day in seconds.
expiry = int32(86400)
// 30 days in seconds.
expiryAMP = int32(2592000)
)
expected := []sqlc.Invoice{
{
ID: 1,
Hash: hash1,
PaymentAddr: hash1,
PaymentRequestHash: hash1,
Expiry: expiry,
},
{
ID: 2,
Hash: hash2,
PaymentAddr: hash2,
PaymentRequestHash: hash2,
Expiry: expiryAMP,
IsAmp: true,
},
}
for i := range invoices {
// Override the timestamp location as the sql driver will scan
// the timestamp with no location and we can't create such
// timestamps in Golang using the standard time package.
invoices[i].CreatedAt = invoices[i].CreatedAt.UTC()
// Override the preimage as depending on the driver it is either
// scanned as nil or an empty byte slice.
require.Len(t, invoices[i].Preimage, 0)
invoices[i].Preimage = nil
}
require.NoError(t, err)
require.Equal(t, expected, invoices)
}

View File

@ -2,13 +2,15 @@ package sqldb
import (
"database/sql"
"fmt"
"net/url"
"path"
"strings"
"time"
postgres_migrate "github.com/golang-migrate/migrate/v4/database/postgres"
pgx_migrate "github.com/golang-migrate/migrate/v4/database/pgx/v5"
_ "github.com/golang-migrate/migrate/v4/source/file" // Read migrations from files. // nolint:lll
_ "github.com/jackc/pgx/v5"
"github.com/lightningnetwork/lnd/sqldb/sqlc"
)
@ -19,6 +21,17 @@ var (
// fully executed yet. So this time needs to be chosen correctly to be
// longer than the longest expected individual test run time.
DefaultPostgresFixtureLifetime = 10 * time.Minute
// postgresSchemaReplacements is a map of schema strings that need to be
// replaced for postgres. This is needed because we write the schemas to
// work with sqlite primarily but in sqlc's own dialect, and postgres
// has some differences.
postgresSchemaReplacements = map[string]string{
"BLOB": "BYTEA",
"INTEGER PRIMARY KEY": "SERIAL PRIMARY KEY",
"BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
}
)
// replacePasswordInDSN takes a DSN string and returns it with the password
@ -79,11 +92,6 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) {
}
log.Infof("Using SQL database '%s'", sanitizedDSN)
dbName, err := getDatabaseNameFromDSN(cfg.Dsn)
if err != nil {
return nil, err
}
rawDB, err := sql.Open("pgx", cfg.Dsn)
if err != nil {
return nil, err
@ -98,42 +106,45 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) {
rawDB.SetMaxIdleConns(maxConns)
rawDB.SetConnMaxLifetime(connIdleLifetime)
if !cfg.SkipMigrations {
// Now that the database is open, populate the database with
// our set of schemas based on our embedded in-memory file
// system.
//
// First, we'll need to open up a new migration instance for
// our current target database: Postgres.
driver, err := postgres_migrate.WithInstance(
rawDB, &postgres_migrate.Config{},
)
if err != nil {
return nil, err
}
postgresFS := newReplacerFS(sqlSchemas, map[string]string{
"BLOB": "BYTEA",
"INTEGER PRIMARY KEY": "SERIAL PRIMARY KEY",
"BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
})
err = applyMigrations(
postgresFS, driver, "sqlc/migrations", dbName,
)
if err != nil {
return nil, err
}
}
queries := sqlc.New(rawDB)
return &PostgresStore{
s := &PostgresStore{
cfg: cfg,
BaseDB: &BaseDB{
DB: rawDB,
Queries: queries,
},
}, nil
}
// Execute migrations unless configured to skip them.
if !cfg.SkipMigrations {
err := s.ExecuteMigrations(TargetLatest)
if err != nil {
return nil, fmt.Errorf("error executing migrations: %w",
err)
}
}
return s, nil
}
// ExecuteMigrations runs migrations for the Postgres database, depending on the
// target given, either all migrations or up to a given version.
func (s *PostgresStore) ExecuteMigrations(target MigrationTarget) error {
dbName, err := getDatabaseNameFromDSN(s.cfg.Dsn)
if err != nil {
return err
}
driver, err := pgx_migrate.WithInstance(s.DB, &pgx_migrate.Config{})
if err != nil {
return fmt.Errorf("error creating postgres migration: %w", err)
}
// Populate the database with our set of schemas based on our embedded
// in-memory file system.
postgresFS := newReplacerFS(sqlSchemas, postgresSchemaReplacements)
return applyMigrations(
postgresFS, driver, "sqlc/migrations", dbName, target,
)
}

View File

@ -13,7 +13,7 @@ import (
"testing"
"time"
_ "github.com/lib/pq" // Import the postgres driver.
_ "github.com/jackc/pgx/v5"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"github.com/stretchr/testify/require"
@ -91,7 +91,7 @@ func NewTestPgFixture(t *testing.T, expiry time.Duration) *TestPgFixture {
var testDB *sql.DB
err = pool.Retry(func() error {
testDB, err = sql.Open("postgres", databaseURL)
testDB, err = sql.Open("pgx", databaseURL)
if err != nil {
return err
}
@ -124,28 +124,28 @@ func (f *TestPgFixture) TearDown(t *testing.T) {
require.NoError(t, err, "Could not purge resource")
}
// randomDBName generates a random database name.
func randomDBName(t *testing.T) string {
randBytes := make([]byte, 8)
_, err := rand.Read(randBytes)
require.NoError(t, err)
return "test_" + hex.EncodeToString(randBytes)
}
// NewTestPostgresDB is a helper function that creates a Postgres database for
// testing using the given fixture.
func NewTestPostgresDB(t *testing.T, fixture *TestPgFixture) *PostgresStore {
t.Helper()
// Create random database name.
randBytes := make([]byte, 8)
_, err := rand.Read(randBytes)
if err != nil {
t.Fatal(err)
}
dbName := "test_" + hex.EncodeToString(randBytes)
dbName := randomDBName(t)
t.Logf("Creating new Postgres DB '%s' for testing", dbName)
_, err = fixture.db.ExecContext(
_, err := fixture.db.ExecContext(
context.Background(), "CREATE DATABASE "+dbName,
)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
cfg := fixture.GetConfig(dbName)
store, err := NewPostgresStore(cfg)
@ -153,3 +153,30 @@ func NewTestPostgresDB(t *testing.T, fixture *TestPgFixture) *PostgresStore {
return store
}
// NewTestPostgresDBWithVersion is a helper function that creates a Postgres
// database for testing and migrates it to the given version.
func NewTestPostgresDBWithVersion(t *testing.T, fixture *TestPgFixture,
version uint) *PostgresStore {
t.Helper()
t.Logf("Creating new Postgres DB for testing, migrating to version %d",
version)
dbName := randomDBName(t)
_, err := fixture.db.ExecContext(
context.Background(), "CREATE DATABASE "+dbName,
)
require.NoError(t, err)
storeCfg := fixture.GetConfig(dbName)
storeCfg.SkipMigrations = true
store, err := NewPostgresStore(storeCfg)
require.NoError(t, err)
err = store.ExecuteMigrations(TargetVersion(version))
require.NoError(t, err)
return store
}

View File

@ -9,5 +9,21 @@ import (
// NewTestDB is a helper function that creates a Postgres database for testing.
func NewTestDB(t *testing.T) *PostgresStore {
return NewTestPostgresDB(t)
pgFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime)
t.Cleanup(func() {
pgFixture.TearDown(t)
})
return NewTestPostgresDB(t, pgFixture)
}
// NewTestDBWithVersion is a helper function that creates a Postgres database
// for testing and migrates it to the given version.
func NewTestDBWithVersion(t *testing.T, version uint) *PostgresStore {
pgFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime)
t.Cleanup(func() {
pgFixture.TearDown(t)
})
return NewTestPostgresDBWithVersion(t, pgFixture, version)
}

View File

@ -0,0 +1,2 @@
-- Given that all expiries are changed in this migration we won't be able to
-- roll back to the previous values.

View File

@ -0,0 +1,14 @@
-- Update the expiry for all records in the invoices table. This is needed as
-- previously we stored raw time.Duration values which are 64 bit integers and
-- are used to express duration in nanoseconds however the intent is to store
-- invoice expiry in seconds.
-- Update the expiry to 86400 seconds (24 hours) for non-AMP invoices.
UPDATE invoices
SET expiry = 86400
WHERE is_amp = FALSE;
-- Update the expiry to 2592000 seconds (30 days) for AMP invoices
UPDATE invoices
SET expiry = 2592000
WHERE is_amp = TRUE;

View File

@ -26,6 +26,16 @@ const (
sqliteTxLockImmediate = "_txlock=immediate"
)
var (
// sqliteSchemaReplacements is a map of schema strings that need to be
// replaced for sqlite. This is needed because sqlite doesn't directly
// support the BIGINT type for primary keys, so we need to replace it
// with INTEGER.
sqliteSchemaReplacements = map[string]string{
"BIGINT PRIMARY KEY": "INTEGER PRIMARY KEY",
}
)
// SqliteStore is a database store implementation that uses a sqlite backend.
type SqliteStore struct {
cfg *SqliteConfig
@ -95,46 +105,44 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
db.SetMaxOpenConns(defaultMaxConns)
db.SetMaxIdleConns(defaultMaxConns)
db.SetConnMaxLifetime(connIdleLifetime)
if !cfg.SkipMigrations {
// Now that the database is open, populate the database with
// our set of schemas based on our embedded in-memory file
// system.
//
// First, we'll need to open up a new migration instance for
// our current target database: sqlite.
driver, err := sqlite_migrate.WithInstance(
db, &sqlite_migrate.Config{},
)
if err != nil {
return nil, err
}
// We use INTEGER PRIMARY KEY for sqlite, because it acts as a
// ROWID alias which is 8 bytes big and also autoincrements.
// It's important to use the ROWID as a primary key because the
// key look ups are almost twice as fast
sqliteFS := newReplacerFS(sqlSchemas, map[string]string{
"BIGINT PRIMARY KEY": "INTEGER PRIMARY KEY",
})
err = applyMigrations(
sqliteFS, driver, "sqlc/migrations", "sqlc",
)
if err != nil {
return nil, err
}
}
queries := sqlc.New(db)
return &SqliteStore{
s := &SqliteStore{
cfg: cfg,
BaseDB: &BaseDB{
DB: db,
Queries: queries,
},
}, nil
}
// Execute migrations unless configured to skip them.
if !cfg.SkipMigrations {
if err := s.ExecuteMigrations(TargetLatest); err != nil {
return nil, fmt.Errorf("error executing migrations: "+
"%w", err)
}
}
return s, nil
}
// ExecuteMigrations runs migrations for the sqlite database, depending on the
// target given, either all migrations or up to a given version.
func (s *SqliteStore) ExecuteMigrations(target MigrationTarget) error {
driver, err := sqlite_migrate.WithInstance(
s.DB, &sqlite_migrate.Config{},
)
if err != nil {
return fmt.Errorf("error creating sqlite migration: %w", err)
}
// Populate the database with our set of schemas based on our embedded
// in-memory file system.
sqliteFS := newReplacerFS(sqlSchemas, sqliteSchemaReplacements)
return applyMigrations(
sqliteFS, driver, "sqlc/migrations", "sqlite", target,
)
}
// NewTestSqliteDB is a helper function that creates an SQLite database for
@ -158,3 +166,29 @@ func NewTestSqliteDB(t *testing.T) *SqliteStore {
return sqlDB
}
// NewTestSqliteDBWithVersion is a helper function that creates an SQLite
// database for testing and migrates it to the given version.
func NewTestSqliteDBWithVersion(t *testing.T, version uint) *SqliteStore {
t.Helper()
t.Logf("Creating new SQLite DB for testing, migrating to version %d",
version)
// TODO(roasbeef): if we pass :memory: for the file name, then we get
// an in mem version to speed up tests
dbFileName := filepath.Join(t.TempDir(), "tmp.db")
sqlDB, err := NewSqliteStore(&SqliteConfig{
SkipMigrations: true,
}, dbFileName)
require.NoError(t, err)
err = sqlDB.ExecuteMigrations(TargetVersion(version))
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, sqlDB.DB.Close())
})
return sqlDB
}

View File

@ -11,3 +11,9 @@ import (
func NewTestDB(t *testing.T) *SqliteStore {
return NewTestSqliteDB(t)
}
// NewTestDBWithVersion is a helper function that creates an SQLite database
// for testing and migrates it to the given version.
func NewTestDBWithVersion(t *testing.T, version uint) *SqliteStore {
return NewTestSqliteDBWithVersion(t, version)
}