mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-05-04 17:00:20 +02:00
Merge pull request #8831 from bhandras/sql-invoice-migration
invoices: migrate KV invoices to native SQL for users of KV SQL backends
This commit is contained in:
commit
6cabc74c20
@ -51,6 +51,7 @@ import (
|
|||||||
"github.com/lightningnetwork/lnd/rpcperms"
|
"github.com/lightningnetwork/lnd/rpcperms"
|
||||||
"github.com/lightningnetwork/lnd/signal"
|
"github.com/lightningnetwork/lnd/signal"
|
||||||
"github.com/lightningnetwork/lnd/sqldb"
|
"github.com/lightningnetwork/lnd/sqldb"
|
||||||
|
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||||
"github.com/lightningnetwork/lnd/sweep"
|
"github.com/lightningnetwork/lnd/sweep"
|
||||||
"github.com/lightningnetwork/lnd/walletunlocker"
|
"github.com/lightningnetwork/lnd/walletunlocker"
|
||||||
"github.com/lightningnetwork/lnd/watchtower"
|
"github.com/lightningnetwork/lnd/watchtower"
|
||||||
@ -60,6 +61,16 @@ import (
|
|||||||
"gopkg.in/macaroon-bakery.v2/bakery"
|
"gopkg.in/macaroon-bakery.v2/bakery"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// invoiceMigrationBatchSize is the number of invoices that will be
|
||||||
|
// migrated in a single batch.
|
||||||
|
invoiceMigrationBatchSize = 1000
|
||||||
|
|
||||||
|
// invoiceMigration is the version of the migration that will be used to
|
||||||
|
// migrate invoices from the kvdb to the sql database.
|
||||||
|
invoiceMigration = 7
|
||||||
|
)
|
||||||
|
|
||||||
// GrpcRegistrar is an interface that must be satisfied by an external subserver
|
// GrpcRegistrar is an interface that must be satisfied by an external subserver
|
||||||
// that wants to be able to register its own gRPC server onto lnd's main
|
// that wants to be able to register its own gRPC server onto lnd's main
|
||||||
// grpc.Server instance.
|
// grpc.Server instance.
|
||||||
@ -932,10 +943,10 @@ type DatabaseInstances struct {
|
|||||||
// the btcwallet's loader.
|
// the btcwallet's loader.
|
||||||
WalletDB btcwallet.LoaderOption
|
WalletDB btcwallet.LoaderOption
|
||||||
|
|
||||||
// NativeSQLStore is a pointer to a native SQL store that can be used
|
// NativeSQLStore holds a reference to the native SQL store that can
|
||||||
// for native SQL queries for tables that already support it. This may
|
// be used for native SQL queries for tables that already support it.
|
||||||
// be nil if the use-native-sql flag was not set.
|
// This may be nil if the use-native-sql flag was not set.
|
||||||
NativeSQLStore *sqldb.BaseDB
|
NativeSQLStore sqldb.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultDatabaseBuilder is a type that builds the default database backends
|
// DefaultDatabaseBuilder is a type that builds the default database backends
|
||||||
@ -1038,7 +1049,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
cleanUp()
|
cleanUp()
|
||||||
|
|
||||||
err := fmt.Errorf("unable to open graph DB: %w", err)
|
err = fmt.Errorf("unable to open graph DB: %w", err)
|
||||||
d.logger.Error(err)
|
d.logger.Error(err)
|
||||||
|
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
@ -1072,51 +1083,69 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
|
|||||||
case err != nil:
|
case err != nil:
|
||||||
cleanUp()
|
cleanUp()
|
||||||
|
|
||||||
err := fmt.Errorf("unable to open graph DB: %w", err)
|
err = fmt.Errorf("unable to open graph DB: %w", err)
|
||||||
d.logger.Error(err)
|
d.logger.Error(err)
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Instantiate a native SQL invoice store if the flag is set.
|
// Instantiate a native SQL store if the flag is set.
|
||||||
if d.cfg.DB.UseNativeSQL {
|
if d.cfg.DB.UseNativeSQL {
|
||||||
// KV invoice db resides in the same database as the channel
|
migrations := sqldb.GetMigrations()
|
||||||
// state DB. Let's query the database to see if we have any
|
|
||||||
// invoices there. If we do, we won't allow the user to start
|
// If the user has not explicitly disabled the SQL invoice
|
||||||
// lnd with native SQL enabled, as we don't currently migrate
|
// migration, attach the custom migration function to invoice
|
||||||
// the invoices to the new database schema.
|
// migration (version 7). Even if this custom migration is
|
||||||
invoiceSlice, err := dbs.ChanStateDB.QueryInvoices(
|
// disabled, the regular native SQL store migrations will still
|
||||||
ctx, invoices.InvoiceQuery{
|
// run. If the database version is already above this custom
|
||||||
NumMaxInvoices: 1,
|
// migration's version (7), it will be skipped permanently,
|
||||||
},
|
// regardless of the flag.
|
||||||
|
if !d.cfg.DB.SkipSQLInvoiceMigration {
|
||||||
|
migrationFn := func(tx *sqlc.Queries) error {
|
||||||
|
return invoices.MigrateInvoicesToSQL(
|
||||||
|
ctx, dbs.ChanStateDB.Backend,
|
||||||
|
dbs.ChanStateDB, tx,
|
||||||
|
invoiceMigrationBatchSize,
|
||||||
)
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure we attach the custom migration function to
|
||||||
|
// the correct migration version.
|
||||||
|
for i := 0; i < len(migrations); i++ {
|
||||||
|
if migrations[i].Version != invoiceMigration {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
migrations[i].MigrationFn = migrationFn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We need to apply all migrations to the native SQL store
|
||||||
|
// before we can use it.
|
||||||
|
err = dbs.NativeSQLStore.ApplyAllMigrations(ctx, migrations)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cleanUp()
|
cleanUp()
|
||||||
d.logger.Errorf("Unable to query KV invoice DB: %v",
|
err = fmt.Errorf("faild to run migrations for the "+
|
||||||
err)
|
"native SQL store: %w", err)
|
||||||
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(invoiceSlice.Invoices) > 0 {
|
|
||||||
cleanUp()
|
|
||||||
err := fmt.Errorf("found invoices in the KV invoice " +
|
|
||||||
"DB, migration to native SQL is not yet " +
|
|
||||||
"supported")
|
|
||||||
d.logger.Error(err)
|
d.logger.Error(err)
|
||||||
|
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// With the DB ready and migrations applied, we can now create
|
||||||
|
// the base DB and transaction executor for the native SQL
|
||||||
|
// invoice store.
|
||||||
|
baseDB := dbs.NativeSQLStore.GetBaseDB()
|
||||||
executor := sqldb.NewTransactionExecutor(
|
executor := sqldb.NewTransactionExecutor(
|
||||||
dbs.NativeSQLStore,
|
baseDB, func(tx *sql.Tx) invoices.SQLInvoiceQueries {
|
||||||
func(tx *sql.Tx) invoices.SQLInvoiceQueries {
|
return baseDB.WithTx(tx)
|
||||||
return dbs.NativeSQLStore.WithTx(tx)
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
dbs.InvoiceDB = invoices.NewSQLStore(
|
sqlInvoiceDB := invoices.NewSQLStore(
|
||||||
executor, clock.NewDefaultClock(),
|
executor, clock.NewDefaultClock(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
dbs.InvoiceDB = sqlInvoiceDB
|
||||||
} else {
|
} else {
|
||||||
dbs.InvoiceDB = dbs.ChanStateDB
|
dbs.InvoiceDB = dbs.ChanStateDB
|
||||||
}
|
}
|
||||||
@ -1129,7 +1158,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
cleanUp()
|
cleanUp()
|
||||||
|
|
||||||
err := fmt.Errorf("unable to open %s database: %w",
|
err = fmt.Errorf("unable to open %s database: %w",
|
||||||
lncfg.NSTowerClientDB, err)
|
lncfg.NSTowerClientDB, err)
|
||||||
d.logger.Error(err)
|
d.logger.Error(err)
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
@ -1144,7 +1173,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
cleanUp()
|
cleanUp()
|
||||||
|
|
||||||
err := fmt.Errorf("unable to open %s database: %w",
|
err = fmt.Errorf("unable to open %s database: %w",
|
||||||
lncfg.NSTowerServerDB, err)
|
lncfg.NSTowerServerDB, err)
|
||||||
d.logger.Error(err)
|
d.logger.Error(err)
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
@ -264,6 +264,11 @@ The underlying functionality between those two options remain the same.
|
|||||||
transactions can run at once, increasing efficiency. Includes several bugfixes
|
transactions can run at once, increasing efficiency. Includes several bugfixes
|
||||||
to allow this to work properly.
|
to allow this to work properly.
|
||||||
|
|
||||||
|
* [Migrate KV invoices to
|
||||||
|
SQL](https://github.com/lightningnetwork/lnd/pull/8831) as part of a larger
|
||||||
|
effort to support SQL databases natively in LND.
|
||||||
|
|
||||||
|
|
||||||
## Code Health
|
## Code Health
|
||||||
|
|
||||||
* A code refactor that [moves all the graph related DB code out of the
|
* A code refactor that [moves all the graph related DB code out of the
|
||||||
@ -292,6 +297,7 @@ The underlying functionality between those two options remain the same.
|
|||||||
|
|
||||||
* Abdullahi Yunus
|
* Abdullahi Yunus
|
||||||
* Alex Akselrod
|
* Alex Akselrod
|
||||||
|
* Andras Banki-Horvath
|
||||||
* Animesh Bilthare
|
* Animesh Bilthare
|
||||||
* Boris Nagaev
|
* Boris Nagaev
|
||||||
* Carla Kirk-Cohen
|
* Carla Kirk-Cohen
|
||||||
|
6
go.mod
6
go.mod
@ -138,7 +138,7 @@ require (
|
|||||||
github.com/opencontainers/image-spec v1.0.2 // indirect
|
github.com/opencontainers/image-spec v1.0.2 // indirect
|
||||||
github.com/opencontainers/runc v1.1.12 // indirect
|
github.com/opencontainers/runc v1.1.12 // indirect
|
||||||
github.com/ory/dockertest/v3 v3.10.0 // indirect
|
github.com/ory/dockertest/v3 v3.10.0 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0
|
||||||
github.com/prometheus/client_model v0.2.0 // indirect
|
github.com/prometheus/client_model v0.2.0 // indirect
|
||||||
github.com/prometheus/common v0.26.0 // indirect
|
github.com/prometheus/common v0.26.0 // indirect
|
||||||
github.com/prometheus/procfs v0.6.0 // indirect
|
github.com/prometheus/procfs v0.6.0 // indirect
|
||||||
@ -207,6 +207,10 @@ replace github.com/gogo/protobuf => github.com/gogo/protobuf v1.3.2
|
|||||||
// allows us to specify that as an option.
|
// allows us to specify that as an option.
|
||||||
replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-display v1.30.0-hex-display
|
replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-display v1.30.0-hex-display
|
||||||
|
|
||||||
|
// Temporary replace until https://github.com/lightningnetwork/lnd/pull/8831 is
|
||||||
|
// merged.
|
||||||
|
replace github.com/lightningnetwork/lnd/sqldb => ./sqldb
|
||||||
|
|
||||||
// If you change this please also update docs/INSTALL.md and GO_VERSION in
|
// If you change this please also update docs/INSTALL.md and GO_VERSION in
|
||||||
// Makefile (then run `make lint` to see where else it needs to be updated as
|
// Makefile (then run `make lint` to see where else it needs to be updated as
|
||||||
// well).
|
// well).
|
||||||
|
2
go.sum
2
go.sum
@ -464,8 +464,6 @@ github.com/lightningnetwork/lnd/kvdb v1.4.12 h1:Y0WY5Tbjyjn6eCYh068qkWur5oFtioJl
|
|||||||
github.com/lightningnetwork/lnd/kvdb v1.4.12/go.mod h1:hx9buNcxsZpZwh8m1sjTQwy2SOeBoWWOZ3RnOQkMsxI=
|
github.com/lightningnetwork/lnd/kvdb v1.4.12/go.mod h1:hx9buNcxsZpZwh8m1sjTQwy2SOeBoWWOZ3RnOQkMsxI=
|
||||||
github.com/lightningnetwork/lnd/queue v1.1.1 h1:99ovBlpM9B0FRCGYJo6RSFDlt8/vOkQQZznVb18iNMI=
|
github.com/lightningnetwork/lnd/queue v1.1.1 h1:99ovBlpM9B0FRCGYJo6RSFDlt8/vOkQQZznVb18iNMI=
|
||||||
github.com/lightningnetwork/lnd/queue v1.1.1/go.mod h1:7A6nC1Qrm32FHuhx/mi1cieAiBZo5O6l8IBIoQxvkz4=
|
github.com/lightningnetwork/lnd/queue v1.1.1/go.mod h1:7A6nC1Qrm32FHuhx/mi1cieAiBZo5O6l8IBIoQxvkz4=
|
||||||
github.com/lightningnetwork/lnd/sqldb v1.0.6 h1:LJdDSVdN33bVBIefsaJlPW9PDAm6GrXlyFucmzSJ3Ts=
|
|
||||||
github.com/lightningnetwork/lnd/sqldb v1.0.6/go.mod h1:OG09zL/PHPaBJefp4HsPz2YLUJ+zIQHbpgCtLnOx8I4=
|
|
||||||
github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM=
|
github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM=
|
||||||
github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA=
|
github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA=
|
||||||
github.com/lightningnetwork/lnd/tlv v1.3.0 h1:exS/KCPEgpOgviIttfiXAPaUqw2rHQrnUOpP7HPBPiY=
|
github.com/lightningnetwork/lnd/tlv v1.3.0 h1:exS/KCPEgpOgviIttfiXAPaUqw2rHQrnUOpP7HPBPiY=
|
||||||
|
@ -187,6 +187,11 @@ func (r InvoiceRef) Modifier() RefModifier {
|
|||||||
return r.refModifier
|
return r.refModifier
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsHashOnly returns true if the invoice ref only contains a payment hash.
|
||||||
|
func (r InvoiceRef) IsHashOnly() bool {
|
||||||
|
return r.payHash != nil && r.payAddr == nil && r.setID == nil
|
||||||
|
}
|
||||||
|
|
||||||
// String returns a human-readable representation of an InvoiceRef.
|
// String returns a human-readable representation of an InvoiceRef.
|
||||||
func (r InvoiceRef) String() string {
|
func (r InvoiceRef) String() string {
|
||||||
var ids []string
|
var ids []string
|
||||||
|
203
invoices/kv_sql_migration_test.go
Normal file
203
invoices/kv_sql_migration_test.go
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
package invoices_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
"github.com/lightningnetwork/lnd/clock"
|
||||||
|
invpkg "github.com/lightningnetwork/lnd/invoices"
|
||||||
|
"github.com/lightningnetwork/lnd/kvdb"
|
||||||
|
"github.com/lightningnetwork/lnd/kvdb/sqlbase"
|
||||||
|
"github.com/lightningnetwork/lnd/kvdb/sqlite"
|
||||||
|
"github.com/lightningnetwork/lnd/lncfg"
|
||||||
|
"github.com/lightningnetwork/lnd/sqldb"
|
||||||
|
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestMigrationWithChannelDB tests the migration of invoices from a bolt backed
|
||||||
|
// channel.db to a SQL database. Note that this test does not attempt to be a
|
||||||
|
// complete migration test for all invoice types but rather is added as a tool
|
||||||
|
// for developers and users to debug invoice migration issues with an actual
|
||||||
|
// channel.db file.
|
||||||
|
func TestMigrationWithChannelDB(t *testing.T) {
|
||||||
|
// First create a shared Postgres instance so we don't spawn a new
|
||||||
|
// docker container for each test.
|
||||||
|
pgFixture := sqldb.NewTestPgFixture(
|
||||||
|
t, sqldb.DefaultPostgresFixtureLifetime,
|
||||||
|
)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
pgFixture.TearDown(t)
|
||||||
|
})
|
||||||
|
|
||||||
|
makeSQLDB := func(t *testing.T, sqlite bool) (*invpkg.SQLStore,
|
||||||
|
*sqldb.TransactionExecutor[*sqlc.Queries]) {
|
||||||
|
|
||||||
|
var db *sqldb.BaseDB
|
||||||
|
if sqlite {
|
||||||
|
db = sqldb.NewTestSqliteDB(t).BaseDB
|
||||||
|
} else {
|
||||||
|
db = sqldb.NewTestPostgresDB(t, pgFixture).BaseDB
|
||||||
|
}
|
||||||
|
|
||||||
|
invoiceExecutor := sqldb.NewTransactionExecutor(
|
||||||
|
db, func(tx *sql.Tx) invpkg.SQLInvoiceQueries {
|
||||||
|
return db.WithTx(tx)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
genericExecutor := sqldb.NewTransactionExecutor(
|
||||||
|
db, func(tx *sql.Tx) *sqlc.Queries {
|
||||||
|
return db.WithTx(tx)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
testClock := clock.NewTestClock(time.Unix(1, 0))
|
||||||
|
|
||||||
|
return invpkg.NewSQLStore(invoiceExecutor, testClock),
|
||||||
|
genericExecutor
|
||||||
|
}
|
||||||
|
|
||||||
|
migrationTest := func(t *testing.T, kvStore *channeldb.DB,
|
||||||
|
sqlite bool) {
|
||||||
|
|
||||||
|
sqlInvoiceStore, sqlStore := makeSQLDB(t, sqlite)
|
||||||
|
ctxb := context.Background()
|
||||||
|
|
||||||
|
const batchSize = 11
|
||||||
|
var opts sqldb.MigrationTxOptions
|
||||||
|
err := sqlStore.ExecTx(
|
||||||
|
ctxb, &opts, func(tx *sqlc.Queries) error {
|
||||||
|
return invpkg.MigrateInvoicesToSQL(
|
||||||
|
ctxb, kvStore.Backend, kvStore, tx,
|
||||||
|
batchSize,
|
||||||
|
)
|
||||||
|
}, func() {},
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// MigrateInvoices will check if the inserted invoice equals to
|
||||||
|
// the migrated one, but as a sanity check, we'll also fetch the
|
||||||
|
// invoices from the store and compare them to the original
|
||||||
|
// invoices.
|
||||||
|
query := invpkg.InvoiceQuery{
|
||||||
|
IndexOffset: 0,
|
||||||
|
// As a sanity check, fetch more invoices than we have
|
||||||
|
// to ensure that we did not add any extra invoices.
|
||||||
|
// Note that we don't really have a way to know the
|
||||||
|
// exact number of invoices in the bolt db without first
|
||||||
|
// iterating over all of them, but for test purposes
|
||||||
|
// constant should be enough.
|
||||||
|
NumMaxInvoices: 9999,
|
||||||
|
}
|
||||||
|
result1, err := kvStore.QueryInvoices(ctxb, query)
|
||||||
|
require.NoError(t, err)
|
||||||
|
numInvoices := len(result1.Invoices)
|
||||||
|
|
||||||
|
result2, err := sqlInvoiceStore.QueryInvoices(ctxb, query)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, numInvoices, len(result2.Invoices))
|
||||||
|
|
||||||
|
// Simply zero out the add index so we don't fail on that when
|
||||||
|
// comparing.
|
||||||
|
for i := 0; i < numInvoices; i++ {
|
||||||
|
result1.Invoices[i].AddIndex = 0
|
||||||
|
result2.Invoices[i].AddIndex = 0
|
||||||
|
|
||||||
|
// We need to override the timezone of the invoices as
|
||||||
|
// the provided DB vs the test runners local time zone
|
||||||
|
// might be different.
|
||||||
|
invpkg.OverrideInvoiceTimeZone(&result1.Invoices[i])
|
||||||
|
invpkg.OverrideInvoiceTimeZone(&result2.Invoices[i])
|
||||||
|
|
||||||
|
require.Equal(
|
||||||
|
t, result1.Invoices[i], result2.Invoices[i],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
dbPath string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"empty",
|
||||||
|
t.TempDir(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"testdata",
|
||||||
|
"testdata",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
test := test
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
var kvStore *channeldb.DB
|
||||||
|
|
||||||
|
// First check if we have a channel.sqlite file in the
|
||||||
|
// testdata directory. If we do, we'll use that as the
|
||||||
|
// channel db for the migration test.
|
||||||
|
chanDBPath := path.Join(
|
||||||
|
test.dbPath, lncfg.SqliteChannelDBName,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Just some sane defaults for the sqlite config.
|
||||||
|
const (
|
||||||
|
timeout = 5 * time.Second
|
||||||
|
maxConns = 50
|
||||||
|
)
|
||||||
|
|
||||||
|
sqliteConfig := &sqlite.Config{
|
||||||
|
Timeout: timeout,
|
||||||
|
BusyTimeout: timeout,
|
||||||
|
MaxConnections: maxConns,
|
||||||
|
}
|
||||||
|
|
||||||
|
if fileExists(chanDBPath) {
|
||||||
|
sqlbase.Init(maxConns)
|
||||||
|
|
||||||
|
sqliteBackend, err := kvdb.Open(
|
||||||
|
kvdb.SqliteBackendName,
|
||||||
|
context.Background(),
|
||||||
|
sqliteConfig, test.dbPath,
|
||||||
|
lncfg.SqliteChannelDBName,
|
||||||
|
lncfg.NSChannelDB,
|
||||||
|
)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
kvStore, err = channeldb.CreateWithBackend(
|
||||||
|
sqliteBackend,
|
||||||
|
)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
kvStore = channeldb.OpenForTesting(
|
||||||
|
t, test.dbPath,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Postgres", func(t *testing.T) {
|
||||||
|
migrationTest(t, kvStore, false)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SQLite", func(t *testing.T) {
|
||||||
|
migrationTest(t, kvStore, true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileExists(filename string) bool {
|
||||||
|
info, err := os.Stat(filename)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return !info.IsDir()
|
||||||
|
}
|
558
invoices/sql_migration.go
Normal file
558
invoices/sql_migration.go
Normal file
@ -0,0 +1,558 @@
|
|||||||
|
package invoices
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/davecgh/go-spew/spew"
|
||||||
|
"github.com/lightningnetwork/lnd/graph/db/models"
|
||||||
|
"github.com/lightningnetwork/lnd/kvdb"
|
||||||
|
"github.com/lightningnetwork/lnd/lntypes"
|
||||||
|
"github.com/lightningnetwork/lnd/sqldb"
|
||||||
|
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||||
|
"github.com/pmezard/go-difflib/difflib"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// invoiceBucket is the name of the bucket within the database that
|
||||||
|
// stores all data related to invoices no matter their final state.
|
||||||
|
// Within the invoice bucket, each invoice is keyed by its invoice ID
|
||||||
|
// which is a monotonically increasing uint32.
|
||||||
|
invoiceBucket = []byte("invoices")
|
||||||
|
|
||||||
|
// invoiceIndexBucket is the name of the sub-bucket within the
|
||||||
|
// invoiceBucket which indexes all invoices by their payment hash. The
|
||||||
|
// payment hash is the sha256 of the invoice's payment preimage. This
|
||||||
|
// index is used to detect duplicates, and also to provide a fast path
|
||||||
|
// for looking up incoming HTLCs to determine if we're able to settle
|
||||||
|
// them fully.
|
||||||
|
//
|
||||||
|
// maps: payHash => invoiceKey
|
||||||
|
invoiceIndexBucket = []byte("paymenthashes")
|
||||||
|
|
||||||
|
// numInvoicesKey is the name of key which houses the auto-incrementing
|
||||||
|
// invoice ID which is essentially used as a primary key. With each
|
||||||
|
// invoice inserted, the primary key is incremented by one. This key is
|
||||||
|
// stored within the invoiceIndexBucket. Within the invoiceBucket
|
||||||
|
// invoices are uniquely identified by the invoice ID.
|
||||||
|
numInvoicesKey = []byte("nik")
|
||||||
|
|
||||||
|
// addIndexBucket is an index bucket that we'll use to create a
|
||||||
|
// monotonically increasing set of add indexes. Each time we add a new
|
||||||
|
// invoice, this sequence number will be incremented and then populated
|
||||||
|
// within the new invoice.
|
||||||
|
//
|
||||||
|
// In addition to this sequence number, we map:
|
||||||
|
//
|
||||||
|
// addIndexNo => invoiceKey
|
||||||
|
addIndexBucket = []byte("invoice-add-index")
|
||||||
|
|
||||||
|
// ErrMigrationMismatch is returned when the migrated invoice does not
|
||||||
|
// match the original invoice.
|
||||||
|
ErrMigrationMismatch = fmt.Errorf("migrated invoice does not match " +
|
||||||
|
"original invoice")
|
||||||
|
)
|
||||||
|
|
||||||
|
// createInvoiceHashIndex generates a hash index that contains payment hashes
|
||||||
|
// for each invoice in the database. Retrieving the payment hash for certain
|
||||||
|
// invoices, such as those created for spontaneous AMP payments, can be
|
||||||
|
// challenging because the hash is not directly derivable from the invoice's
|
||||||
|
// parameters and is stored separately in the `paymenthashes` bucket. This
|
||||||
|
// bucket maps payment hashes to invoice keys, but for migration purposes, we
|
||||||
|
// need the ability to query in the reverse direction. This function establishes
|
||||||
|
// a new index in the SQL database that maps each invoice key to its
|
||||||
|
// corresponding payment hash.
|
||||||
|
func createInvoiceHashIndex(ctx context.Context, db kvdb.Backend,
|
||||||
|
tx *sqlc.Queries) error {
|
||||||
|
|
||||||
|
return db.View(func(kvTx kvdb.RTx) error {
|
||||||
|
invoices := kvTx.ReadBucket(invoiceBucket)
|
||||||
|
if invoices == nil {
|
||||||
|
return ErrNoInvoicesCreated
|
||||||
|
}
|
||||||
|
|
||||||
|
invoiceIndex := invoices.NestedReadBucket(
|
||||||
|
invoiceIndexBucket,
|
||||||
|
)
|
||||||
|
if invoiceIndex == nil {
|
||||||
|
return ErrNoInvoicesCreated
|
||||||
|
}
|
||||||
|
|
||||||
|
addIndex := invoices.NestedReadBucket(addIndexBucket)
|
||||||
|
if addIndex == nil {
|
||||||
|
return ErrNoInvoicesCreated
|
||||||
|
}
|
||||||
|
|
||||||
|
// First, iterate over all elements in the add index bucket and
|
||||||
|
// insert the add index value for the corresponding invoice key
|
||||||
|
// in the payment_hashes table.
|
||||||
|
err := addIndex.ForEach(func(k, v []byte) error {
|
||||||
|
// The key is the add index, and the value is
|
||||||
|
// the invoice key.
|
||||||
|
addIndexNo := binary.BigEndian.Uint64(k)
|
||||||
|
invoiceKey := binary.BigEndian.Uint32(v)
|
||||||
|
|
||||||
|
return tx.InsertKVInvoiceKeyAndAddIndex(ctx,
|
||||||
|
sqlc.InsertKVInvoiceKeyAndAddIndexParams{
|
||||||
|
ID: int64(invoiceKey),
|
||||||
|
AddIndex: int64(addIndexNo),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next, iterate over all hashes in the invoice index bucket and
|
||||||
|
// set the hash to the corresponding the invoice key in the
|
||||||
|
// payment_hashes table.
|
||||||
|
return invoiceIndex.ForEach(func(k, v []byte) error {
|
||||||
|
// Skip the special numInvoicesKey as that does
|
||||||
|
// not point to a valid invoice.
|
||||||
|
if bytes.Equal(k, numInvoicesKey) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// The key is the payment hash, and the value
|
||||||
|
// is the invoice key.
|
||||||
|
if len(k) != lntypes.HashSize {
|
||||||
|
return fmt.Errorf("invalid payment "+
|
||||||
|
"hash length: expected %v, "+
|
||||||
|
"got %v", lntypes.HashSize,
|
||||||
|
len(k))
|
||||||
|
}
|
||||||
|
|
||||||
|
invoiceKey := binary.BigEndian.Uint32(v)
|
||||||
|
|
||||||
|
return tx.SetKVInvoicePaymentHash(ctx,
|
||||||
|
sqlc.SetKVInvoicePaymentHashParams{
|
||||||
|
ID: int64(invoiceKey),
|
||||||
|
Hash: k,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}, func() {})
|
||||||
|
}
|
||||||
|
|
||||||
|
// toInsertMigratedInvoiceParams creates the parameters for inserting a migrated
|
||||||
|
// invoice into the SQL database. The parameters are derived from the original
|
||||||
|
// invoice insert parameters.
|
||||||
|
func toInsertMigratedInvoiceParams(
|
||||||
|
params sqlc.InsertInvoiceParams) sqlc.InsertMigratedInvoiceParams {
|
||||||
|
|
||||||
|
return sqlc.InsertMigratedInvoiceParams{
|
||||||
|
Hash: params.Hash,
|
||||||
|
Preimage: params.Preimage,
|
||||||
|
Memo: params.Memo,
|
||||||
|
AmountMsat: params.AmountMsat,
|
||||||
|
CltvDelta: params.CltvDelta,
|
||||||
|
Expiry: params.Expiry,
|
||||||
|
PaymentAddr: params.PaymentAddr,
|
||||||
|
PaymentRequest: params.PaymentRequest,
|
||||||
|
PaymentRequestHash: params.PaymentRequestHash,
|
||||||
|
State: params.State,
|
||||||
|
AmountPaidMsat: params.AmountPaidMsat,
|
||||||
|
IsAmp: params.IsAmp,
|
||||||
|
IsHodl: params.IsHodl,
|
||||||
|
IsKeysend: params.IsKeysend,
|
||||||
|
CreatedAt: params.CreatedAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MigrateSingleInvoice migrates a single invoice to the new SQL schema. Note
|
||||||
|
// that perfect equality between the old and new schemas is not achievable, as
|
||||||
|
// the invoice's add index cannot be mapped directly to its ID due to SQL’s
|
||||||
|
// auto-incrementing primary key. The ID returned from the insert will instead
|
||||||
|
// serve as the add index in the new schema.
|
||||||
|
func MigrateSingleInvoice(ctx context.Context, tx SQLInvoiceQueries,
|
||||||
|
invoice *Invoice, paymentHash lntypes.Hash) error {
|
||||||
|
|
||||||
|
insertInvoiceParams, err := makeInsertInvoiceParams(
|
||||||
|
invoice, paymentHash,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert the insert invoice parameters to the migrated invoice insert
|
||||||
|
// parameters.
|
||||||
|
insertMigratedInvoiceParams := toInsertMigratedInvoiceParams(
|
||||||
|
insertInvoiceParams,
|
||||||
|
)
|
||||||
|
|
||||||
|
// If the invoice is settled, we'll also set the timestamp and the index
|
||||||
|
// at which it was settled.
|
||||||
|
if invoice.State == ContractSettled {
|
||||||
|
if invoice.SettleIndex == 0 {
|
||||||
|
return fmt.Errorf("settled invoice %s missing settle "+
|
||||||
|
"index", paymentHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
if invoice.SettleDate.IsZero() {
|
||||||
|
return fmt.Errorf("settled invoice %s missing settle "+
|
||||||
|
"date", paymentHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
insertMigratedInvoiceParams.SettleIndex = sqldb.SQLInt64(
|
||||||
|
invoice.SettleIndex,
|
||||||
|
)
|
||||||
|
insertMigratedInvoiceParams.SettledAt = sqldb.SQLTime(
|
||||||
|
invoice.SettleDate.UTC(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// First we need to insert the invoice itself so we can use the "add
|
||||||
|
// index" which in this case is the auto incrementing primary key that
|
||||||
|
// is returned from the insert.
|
||||||
|
invoiceID, err := tx.InsertMigratedInvoice(
|
||||||
|
ctx, insertMigratedInvoiceParams,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to insert invoice: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert the invoice's features.
|
||||||
|
for feature := range invoice.Terms.Features.Features() {
|
||||||
|
params := sqlc.InsertInvoiceFeatureParams{
|
||||||
|
InvoiceID: invoiceID,
|
||||||
|
Feature: int32(feature),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := tx.InsertInvoiceFeature(ctx, params)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to insert invoice "+
|
||||||
|
"feature(%v): %w", feature, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlHtlcIDs := make(map[models.CircuitKey]int64)
|
||||||
|
|
||||||
|
// Now insert the HTLCs of the invoice. We'll also keep track of the SQL
|
||||||
|
// ID of each HTLC so we can use it when inserting the AMP sub invoices.
|
||||||
|
for circuitKey, htlc := range invoice.Htlcs {
|
||||||
|
htlcParams := sqlc.InsertInvoiceHTLCParams{
|
||||||
|
HtlcID: int64(circuitKey.HtlcID),
|
||||||
|
ChanID: strconv.FormatUint(
|
||||||
|
circuitKey.ChanID.ToUint64(), 10,
|
||||||
|
),
|
||||||
|
AmountMsat: int64(htlc.Amt),
|
||||||
|
AcceptHeight: int32(htlc.AcceptHeight),
|
||||||
|
AcceptTime: htlc.AcceptTime.UTC(),
|
||||||
|
ExpiryHeight: int32(htlc.Expiry),
|
||||||
|
State: int16(htlc.State),
|
||||||
|
InvoiceID: invoiceID,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Leave the MPP amount as NULL if the MPP total amount is zero.
|
||||||
|
if htlc.MppTotalAmt != 0 {
|
||||||
|
htlcParams.TotalMppMsat = sqldb.SQLInt64(
|
||||||
|
int64(htlc.MppTotalAmt),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Leave the resolve time as NULL if the HTLC is not resolved.
|
||||||
|
if !htlc.ResolveTime.IsZero() {
|
||||||
|
htlcParams.ResolveTime = sqldb.SQLTime(
|
||||||
|
htlc.ResolveTime.UTC(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlID, err := tx.InsertInvoiceHTLC(ctx, htlcParams)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to insert invoice htlc: %w",
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlHtlcIDs[circuitKey] = sqlID
|
||||||
|
|
||||||
|
// Store custom records.
|
||||||
|
for key, value := range htlc.CustomRecords {
|
||||||
|
err = tx.InsertInvoiceHTLCCustomRecord(
|
||||||
|
ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
|
||||||
|
Key: int64(key),
|
||||||
|
Value: value,
|
||||||
|
HtlcID: sqlID,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !invoice.IsAMP() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for setID, ampState := range invoice.AMPState {
|
||||||
|
// Find the earliest HTLC of the AMP invoice, which will
|
||||||
|
// be used as the creation date of this sub invoice.
|
||||||
|
var createdAt time.Time
|
||||||
|
for circuitKey := range ampState.InvoiceKeys {
|
||||||
|
htlc := invoice.Htlcs[circuitKey]
|
||||||
|
if createdAt.IsZero() {
|
||||||
|
createdAt = htlc.AcceptTime.UTC()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if createdAt.After(htlc.AcceptTime) {
|
||||||
|
createdAt = htlc.AcceptTime.UTC()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
params := sqlc.InsertAMPSubInvoiceParams{
|
||||||
|
SetID: setID[:],
|
||||||
|
State: int16(ampState.State),
|
||||||
|
CreatedAt: createdAt,
|
||||||
|
InvoiceID: invoiceID,
|
||||||
|
}
|
||||||
|
|
||||||
|
if ampState.SettleIndex != 0 {
|
||||||
|
if ampState.SettleDate.IsZero() {
|
||||||
|
return fmt.Errorf("settled AMP sub invoice %x "+
|
||||||
|
"missing settle date", setID)
|
||||||
|
}
|
||||||
|
|
||||||
|
params.SettledAt = sqldb.SQLTime(
|
||||||
|
ampState.SettleDate.UTC(),
|
||||||
|
)
|
||||||
|
|
||||||
|
params.SettleIndex = sqldb.SQLInt64(
|
||||||
|
ampState.SettleIndex,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := tx.InsertAMPSubInvoice(ctx, params)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to insert AMP sub invoice: "+
|
||||||
|
"%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now we can add the AMP HTLCs to the database.
|
||||||
|
for circuitKey := range ampState.InvoiceKeys {
|
||||||
|
htlc := invoice.Htlcs[circuitKey]
|
||||||
|
rootShare := htlc.AMP.Record.RootShare()
|
||||||
|
|
||||||
|
sqlHtlcID, ok := sqlHtlcIDs[circuitKey]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("missing htlc for AMP htlc: "+
|
||||||
|
"%v", circuitKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
params := sqlc.InsertAMPSubInvoiceHTLCParams{
|
||||||
|
InvoiceID: invoiceID,
|
||||||
|
SetID: setID[:],
|
||||||
|
HtlcID: sqlHtlcID,
|
||||||
|
RootShare: rootShare[:],
|
||||||
|
ChildIndex: int64(htlc.AMP.Record.ChildIndex()),
|
||||||
|
Hash: htlc.AMP.Hash[:],
|
||||||
|
}
|
||||||
|
|
||||||
|
if htlc.AMP.Preimage != nil {
|
||||||
|
params.Preimage = htlc.AMP.Preimage[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.InsertAMPSubInvoiceHTLC(ctx, params)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to insert AMP sub "+
|
||||||
|
"invoice: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OverrideInvoiceTimeZone overrides the time zone of the invoice to the local
|
||||||
|
// time zone and chops off the nanosecond part for comparison. This is needed
|
||||||
|
// because KV database stores times as-is which as an unwanted side effect would
|
||||||
|
// fail migration due to time comparison expecting both the original and
|
||||||
|
// migrated invoices to be in the same local time zone and in microsecond
|
||||||
|
// precision. Note that PostgreSQL stores times in microsecond precision while
|
||||||
|
// SQLite can store times in nanosecond precision if using TEXT storage class.
|
||||||
|
func OverrideInvoiceTimeZone(invoice *Invoice) {
|
||||||
|
fixTime := func(t time.Time) time.Time {
|
||||||
|
return t.In(time.Local).Truncate(time.Microsecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
invoice.CreationDate = fixTime(invoice.CreationDate)
|
||||||
|
|
||||||
|
if !invoice.SettleDate.IsZero() {
|
||||||
|
invoice.SettleDate = fixTime(invoice.SettleDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
if invoice.IsAMP() {
|
||||||
|
for setID, ampState := range invoice.AMPState {
|
||||||
|
if ampState.SettleDate.IsZero() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ampState.SettleDate = fixTime(ampState.SettleDate)
|
||||||
|
invoice.AMPState[setID] = ampState
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, htlc := range invoice.Htlcs {
|
||||||
|
if !htlc.AcceptTime.IsZero() {
|
||||||
|
htlc.AcceptTime = fixTime(htlc.AcceptTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !htlc.ResolveTime.IsZero() {
|
||||||
|
htlc.ResolveTime = fixTime(htlc.ResolveTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MigrateInvoicesToSQL runs the migration of all invoices from the KV database
|
||||||
|
// to the SQL database. The migration is done in a single transaction to ensure
|
||||||
|
// that all invoices are migrated or none at all. This function can be run
|
||||||
|
// multiple times without causing any issues as it will check if the migration
|
||||||
|
// has already been performed.
|
||||||
|
func MigrateInvoicesToSQL(ctx context.Context, db kvdb.Backend,
|
||||||
|
kvStore InvoiceDB, tx *sqlc.Queries, batchSize int) error {
|
||||||
|
|
||||||
|
log.Infof("Starting migration of invoices from KV to SQL")
|
||||||
|
|
||||||
|
offset := uint64(0)
|
||||||
|
t0 := time.Now()
|
||||||
|
|
||||||
|
// Create the hash index which we will use to look up invoice
|
||||||
|
// payment hashes by their add index during migration.
|
||||||
|
err := createInvoiceHashIndex(ctx, db, tx)
|
||||||
|
if err != nil && !errors.Is(err, ErrNoInvoicesCreated) {
|
||||||
|
log.Errorf("Unable to create invoice hash index: %v",
|
||||||
|
err)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Debugf("Created SQL invoice hash index in %v", time.Since(t0))
|
||||||
|
|
||||||
|
total := 0
|
||||||
|
// Now we can start migrating the invoices. We'll do this in
|
||||||
|
// batches to reduce memory usage.
|
||||||
|
for {
|
||||||
|
t0 = time.Now()
|
||||||
|
query := InvoiceQuery{
|
||||||
|
IndexOffset: offset,
|
||||||
|
NumMaxInvoices: uint64(batchSize),
|
||||||
|
}
|
||||||
|
|
||||||
|
queryResult, err := kvStore.QueryInvoices(ctx, query)
|
||||||
|
if err != nil && !errors.Is(err, ErrNoInvoicesCreated) {
|
||||||
|
return fmt.Errorf("unable to query invoices: "+
|
||||||
|
"%w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(queryResult.Invoices) == 0 {
|
||||||
|
log.Infof("All invoices migrated")
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
err = migrateInvoices(ctx, tx, queryResult.Invoices)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
offset = queryResult.LastIndexOffset
|
||||||
|
total += len(queryResult.Invoices)
|
||||||
|
log.Debugf("Migrated %d KV invoices to SQL in %v\n", total,
|
||||||
|
time.Since(t0))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up the hash index as it's no longer needed.
|
||||||
|
err = tx.ClearKVInvoiceHashIndex(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to clear invoice hash "+
|
||||||
|
"index: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Migration of %d invoices from KV to SQL completed", total)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func migrateInvoices(ctx context.Context, tx *sqlc.Queries,
|
||||||
|
invoices []Invoice) error {
|
||||||
|
|
||||||
|
for i, invoice := range invoices {
|
||||||
|
var paymentHash lntypes.Hash
|
||||||
|
if invoice.Terms.PaymentPreimage != nil {
|
||||||
|
paymentHash = invoice.Terms.PaymentPreimage.Hash()
|
||||||
|
} else {
|
||||||
|
paymentHashBytes, err :=
|
||||||
|
tx.GetKVInvoicePaymentHashByAddIndex(
|
||||||
|
ctx, int64(invoice.AddIndex),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
// This would be an unexpected inconsistency
|
||||||
|
// in the kv database. We can't do much here
|
||||||
|
// so we'll notify the user and continue.
|
||||||
|
log.Warnf("Cannot migrate invoice, unable to "+
|
||||||
|
"fetch payment hash (add_index=%v): %v",
|
||||||
|
invoice.AddIndex, err)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(paymentHash[:], paymentHashBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := MigrateSingleInvoice(ctx, tx, &invoices[i], paymentHash)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to migrate invoice(%v): %w",
|
||||||
|
paymentHash, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
migratedInvoice, err := fetchInvoice(
|
||||||
|
ctx, tx, InvoiceRefByHash(paymentHash),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to fetch migrated "+
|
||||||
|
"invoice(%v): %w", paymentHash, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override the time zone for comparison. Note that we need to
|
||||||
|
// override both invoices as the original invoice is coming from
|
||||||
|
// KV database, it was stored as a binary serialized Go
|
||||||
|
// time.Time value which has nanosecond precision but might have
|
||||||
|
// been created in a different time zone. The migrated invoice
|
||||||
|
// is stored in SQL in UTC and selected in the local time zone,
|
||||||
|
// however in PostgreSQL it has microsecond precision while in
|
||||||
|
// SQLite it has nanosecond precision if using TEXT storage
|
||||||
|
// class.
|
||||||
|
OverrideInvoiceTimeZone(&invoice)
|
||||||
|
OverrideInvoiceTimeZone(migratedInvoice)
|
||||||
|
|
||||||
|
// Override the add index before checking for equality.
|
||||||
|
migratedInvoice.AddIndex = invoice.AddIndex
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(invoice, *migratedInvoice) {
|
||||||
|
diff := difflib.UnifiedDiff{
|
||||||
|
A: difflib.SplitLines(
|
||||||
|
spew.Sdump(invoice),
|
||||||
|
),
|
||||||
|
B: difflib.SplitLines(
|
||||||
|
spew.Sdump(migratedInvoice),
|
||||||
|
),
|
||||||
|
FromFile: "Expected",
|
||||||
|
FromDate: "",
|
||||||
|
ToFile: "Actual",
|
||||||
|
ToDate: "",
|
||||||
|
Context: 3,
|
||||||
|
}
|
||||||
|
diffText, _ := difflib.GetUnifiedDiffString(diff)
|
||||||
|
|
||||||
|
return fmt.Errorf("%w: %v.\n%v", ErrMigrationMismatch,
|
||||||
|
paymentHash, diffText)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
421
invoices/sql_migration_test.go
Normal file
421
invoices/sql_migration_test.go
Normal file
@ -0,0 +1,421 @@
|
|||||||
|
package invoices
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
crand "crypto/rand"
|
||||||
|
"database/sql"
|
||||||
|
"math/rand"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/clock"
|
||||||
|
"github.com/lightningnetwork/lnd/graph/db/models"
|
||||||
|
"github.com/lightningnetwork/lnd/lntypes"
|
||||||
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/lightningnetwork/lnd/record"
|
||||||
|
"github.com/lightningnetwork/lnd/sqldb"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"pgregory.net/rapid"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// testHtlcIDSequence is a global counter for generating unique HTLC
|
||||||
|
// IDs.
|
||||||
|
testHtlcIDSequence uint64
|
||||||
|
)
|
||||||
|
|
||||||
|
// randomString generates a random string of a given length using rapid.
|
||||||
|
func randomStringRapid(t *rapid.T, length int) string {
|
||||||
|
// Define the character set for the string.
|
||||||
|
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" //nolint:ll
|
||||||
|
|
||||||
|
// Generate a string by selecting random characters from the charset.
|
||||||
|
runes := make([]rune, length)
|
||||||
|
for i := range runes {
|
||||||
|
// Draw a random index and use it to select a character from the
|
||||||
|
// charset.
|
||||||
|
index := rapid.IntRange(0, len(charset)-1).Draw(t, "charIndex")
|
||||||
|
runes[i] = rune(charset[index])
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(runes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// randTimeBetween generates a random time between min and max.
|
||||||
|
func randTimeBetween(min, max time.Time) time.Time {
|
||||||
|
var timeZones = []*time.Location{
|
||||||
|
time.UTC,
|
||||||
|
time.FixedZone("EST", -5*3600),
|
||||||
|
time.FixedZone("MST", -7*3600),
|
||||||
|
time.FixedZone("PST", -8*3600),
|
||||||
|
time.FixedZone("CEST", 2*3600),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure max is after min
|
||||||
|
if max.Before(min) {
|
||||||
|
min, max = max, min
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the range in nanoseconds
|
||||||
|
duration := max.Sub(min)
|
||||||
|
randDuration := time.Duration(rand.Int63n(duration.Nanoseconds()))
|
||||||
|
|
||||||
|
// Generate the random time
|
||||||
|
randomTime := min.Add(randDuration)
|
||||||
|
|
||||||
|
// Assign a random time zone
|
||||||
|
randomTimeZone := timeZones[rand.Intn(len(timeZones))]
|
||||||
|
|
||||||
|
// Return the time in the random time zone
|
||||||
|
return randomTime.In(randomTimeZone)
|
||||||
|
}
|
||||||
|
|
||||||
|
// randTime generates a random time between 2009 and 2140.
|
||||||
|
func randTime() time.Time {
|
||||||
|
min := time.Date(2009, 1, 3, 0, 0, 0, 0, time.UTC)
|
||||||
|
max := time.Date(2140, 1, 1, 0, 0, 0, 1000, time.UTC)
|
||||||
|
|
||||||
|
return randTimeBetween(min, max)
|
||||||
|
}
|
||||||
|
|
||||||
|
func randInvoiceTime(invoice *Invoice) time.Time {
|
||||||
|
return randTimeBetween(
|
||||||
|
invoice.CreationDate,
|
||||||
|
invoice.CreationDate.Add(invoice.Terms.Expiry),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// randHTLCRapid generates a random HTLC for an invoice using rapid to randomize
|
||||||
|
// its parameters.
|
||||||
|
func randHTLCRapid(t *rapid.T, invoice *Invoice, amt lnwire.MilliSatoshi) (
|
||||||
|
models.CircuitKey, *InvoiceHTLC) {
|
||||||
|
|
||||||
|
htlc := &InvoiceHTLC{
|
||||||
|
Amt: amt,
|
||||||
|
AcceptHeight: rapid.Uint32Range(1, 999).Draw(t, "AcceptHeight"),
|
||||||
|
AcceptTime: randInvoiceTime(invoice),
|
||||||
|
Expiry: rapid.Uint32Range(1, 999).Draw(t, "Expiry"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set MPP total amount if MPP feature is enabled in the invoice.
|
||||||
|
if invoice.Terms.Features.HasFeature(lnwire.MPPRequired) {
|
||||||
|
htlc.MppTotalAmt = invoice.Terms.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the HTLC state and resolve time based on the invoice state.
|
||||||
|
switch invoice.State {
|
||||||
|
case ContractSettled:
|
||||||
|
htlc.State = HtlcStateSettled
|
||||||
|
htlc.ResolveTime = randInvoiceTime(invoice)
|
||||||
|
|
||||||
|
case ContractCanceled:
|
||||||
|
htlc.State = HtlcStateCanceled
|
||||||
|
htlc.ResolveTime = randInvoiceTime(invoice)
|
||||||
|
|
||||||
|
case ContractAccepted:
|
||||||
|
htlc.State = HtlcStateAccepted
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add randomized custom records to the HTLC.
|
||||||
|
htlc.CustomRecords = make(record.CustomSet)
|
||||||
|
numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords")
|
||||||
|
for i := 0; i < numRecords; i++ {
|
||||||
|
key := rapid.Uint64Range(
|
||||||
|
record.CustomTypeStart, 1000+record.CustomTypeStart,
|
||||||
|
).Draw(t, "customRecordKey")
|
||||||
|
value := []byte(randomStringRapid(t, 10))
|
||||||
|
htlc.CustomRecords[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a unique HTLC ID and assign it to a channel ID.
|
||||||
|
htlcID := atomic.AddUint64(&testHtlcIDSequence, 1)
|
||||||
|
randChanID := lnwire.NewShortChanIDFromInt(htlcID % 5)
|
||||||
|
|
||||||
|
circuitKey := models.CircuitKey{
|
||||||
|
ChanID: randChanID,
|
||||||
|
HtlcID: htlcID,
|
||||||
|
}
|
||||||
|
|
||||||
|
return circuitKey, htlc
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateInvoiceHTLCsRapid generates all HTLCs for an invoice, including AMP
|
||||||
|
// HTLCs if applicable, using rapid for randomization of HTLC count and
|
||||||
|
// distribution.
|
||||||
|
func generateInvoiceHTLCsRapid(t *rapid.T, invoice *Invoice) {
|
||||||
|
mpp := invoice.Terms.Features.HasFeature(lnwire.MPPRequired)
|
||||||
|
|
||||||
|
// Use rapid to determine the number of HTLCs based on invoice state and
|
||||||
|
// MPP feature.
|
||||||
|
numHTLCs := 1
|
||||||
|
if invoice.State == ContractOpen {
|
||||||
|
numHTLCs = 0
|
||||||
|
} else if mpp {
|
||||||
|
numHTLCs = rapid.IntRange(1, 10).Draw(t, "numHTLCs")
|
||||||
|
}
|
||||||
|
|
||||||
|
total := invoice.Terms.Value
|
||||||
|
|
||||||
|
// Distribute the total amount across the HTLCs, adding any remainder to
|
||||||
|
// the last HTLC.
|
||||||
|
if numHTLCs > 0 {
|
||||||
|
amt := total / lnwire.MilliSatoshi(numHTLCs)
|
||||||
|
remainder := total - amt*lnwire.MilliSatoshi(numHTLCs)
|
||||||
|
|
||||||
|
for i := 0; i < numHTLCs; i++ {
|
||||||
|
if i == numHTLCs-1 {
|
||||||
|
// Add remainder to the last HTLC.
|
||||||
|
amt += remainder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate an HTLC with a random circuit key and add it
|
||||||
|
// to the invoice.
|
||||||
|
circuitKey, htlc := randHTLCRapid(t, invoice, amt)
|
||||||
|
invoice.Htlcs[circuitKey] = htlc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateAMPHtlcsRapid generates AMP HTLCs for an invoice using rapid to
|
||||||
|
// randomize various parameters of the HTLCs in the AMP set.
|
||||||
|
func generateAMPHtlcsRapid(t *rapid.T, invoice *Invoice) {
|
||||||
|
// Randomly determine the number of AMP sets (1 to 5).
|
||||||
|
numSetIDs := rapid.IntRange(1, 5).Draw(t, "numSetIDs")
|
||||||
|
settledIdx := uint64(1)
|
||||||
|
|
||||||
|
for i := 0; i < numSetIDs; i++ {
|
||||||
|
var setID SetID
|
||||||
|
_, err := crand.Read(setID[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Determine the number of HTLCs in this set (1 to 5).
|
||||||
|
numHTLCs := rapid.IntRange(1, 5).Draw(t, "numHTLCs")
|
||||||
|
total := invoice.Terms.Value
|
||||||
|
invoiceKeys := make(map[CircuitKey]struct{})
|
||||||
|
|
||||||
|
// Calculate the amount per HTLC and account for remainder in
|
||||||
|
// the final HTLC.
|
||||||
|
amt := total / lnwire.MilliSatoshi(numHTLCs)
|
||||||
|
remainder := total - amt*lnwire.MilliSatoshi(numHTLCs)
|
||||||
|
|
||||||
|
var htlcState HtlcState
|
||||||
|
for j := 0; j < numHTLCs; j++ {
|
||||||
|
if j == numHTLCs-1 {
|
||||||
|
amt += remainder
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate HTLC with randomized parameters.
|
||||||
|
circuitKey, htlc := randHTLCRapid(t, invoice, amt)
|
||||||
|
htlcState = htlc.State
|
||||||
|
|
||||||
|
var (
|
||||||
|
rootShare, hash [32]byte
|
||||||
|
preimage lntypes.Preimage
|
||||||
|
)
|
||||||
|
|
||||||
|
// Randomize AMP data fields.
|
||||||
|
_, err := crand.Read(rootShare[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = crand.Read(hash[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = crand.Read(preimage[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
record := record.NewAMP(rootShare, setID, uint32(j))
|
||||||
|
|
||||||
|
htlc.AMP = &InvoiceHtlcAMPData{
|
||||||
|
Record: *record,
|
||||||
|
Hash: hash,
|
||||||
|
Preimage: &preimage,
|
||||||
|
}
|
||||||
|
|
||||||
|
invoice.Htlcs[circuitKey] = htlc
|
||||||
|
invoiceKeys[circuitKey] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
ampState := InvoiceStateAMP{
|
||||||
|
State: htlcState,
|
||||||
|
InvoiceKeys: invoiceKeys,
|
||||||
|
}
|
||||||
|
if htlcState == HtlcStateSettled {
|
||||||
|
ampState.SettleIndex = settledIdx
|
||||||
|
ampState.SettleDate = randInvoiceTime(invoice)
|
||||||
|
settledIdx++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the total amount paid if the AMP set is not canceled.
|
||||||
|
if htlcState != HtlcStateCanceled {
|
||||||
|
ampState.AmtPaid = invoice.Terms.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
invoice.AMPState[setID] = ampState
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMigrateSingleInvoiceRapid tests the migration of single invoices with
|
||||||
|
// random data variations using rapid. This test generates a random invoice
|
||||||
|
// configuration and ensures successful migration.
|
||||||
|
//
|
||||||
|
// NOTE: This test may need to be changed if the Invoice or any of the related
|
||||||
|
// types are modified.
|
||||||
|
func TestMigrateSingleInvoiceRapid(t *testing.T) {
|
||||||
|
// Create a shared Postgres instance for efficient testing.
|
||||||
|
pgFixture := sqldb.NewTestPgFixture(
|
||||||
|
t, sqldb.DefaultPostgresFixtureLifetime,
|
||||||
|
)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
pgFixture.TearDown(t)
|
||||||
|
})
|
||||||
|
|
||||||
|
makeSQLDB := func(t *testing.T, sqlite bool) *SQLStore {
|
||||||
|
var db *sqldb.BaseDB
|
||||||
|
if sqlite {
|
||||||
|
db = sqldb.NewTestSqliteDB(t).BaseDB
|
||||||
|
} else {
|
||||||
|
db = sqldb.NewTestPostgresDB(t, pgFixture).BaseDB
|
||||||
|
}
|
||||||
|
|
||||||
|
executor := sqldb.NewTransactionExecutor(
|
||||||
|
db, func(tx *sql.Tx) SQLInvoiceQueries {
|
||||||
|
return db.WithTx(tx)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
testClock := clock.NewTestClock(time.Unix(1, 0))
|
||||||
|
|
||||||
|
return NewSQLStore(executor, testClock)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define property-based test using rapid.
|
||||||
|
rapid.Check(t, func(rt *rapid.T) {
|
||||||
|
// Randomized feature flags for MPP and AMP.
|
||||||
|
mpp := rapid.Bool().Draw(rt, "mpp")
|
||||||
|
amp := rapid.Bool().Draw(rt, "amp")
|
||||||
|
|
||||||
|
for _, sqlite := range []bool{true, false} {
|
||||||
|
store := makeSQLDB(t, sqlite)
|
||||||
|
testMigrateSingleInvoiceRapid(rt, store, mpp, amp)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// testMigrateSingleInvoiceRapid is the primary function for the migration of a
|
||||||
|
// single invoice with random data in a rapid-based test setup.
|
||||||
|
func testMigrateSingleInvoiceRapid(t *rapid.T, store *SQLStore, mpp bool,
|
||||||
|
amp bool) {
|
||||||
|
|
||||||
|
ctxb := context.Background()
|
||||||
|
invoices := make(map[lntypes.Hash]*Invoice)
|
||||||
|
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
invoice := generateTestInvoiceRapid(t, mpp, amp)
|
||||||
|
var hash lntypes.Hash
|
||||||
|
_, err := crand.Read(hash[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
invoices[hash] = invoice
|
||||||
|
}
|
||||||
|
|
||||||
|
var ops SQLInvoiceQueriesTxOptions
|
||||||
|
err := store.db.ExecTx(ctxb, &ops, func(tx SQLInvoiceQueries) error {
|
||||||
|
for hash, invoice := range invoices {
|
||||||
|
err := MigrateSingleInvoice(ctxb, tx, invoice, hash)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}, func() {})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Fetch and compare each migrated invoice from the store with the
|
||||||
|
// original.
|
||||||
|
for hash, invoice := range invoices {
|
||||||
|
sqlInvoice, err := store.LookupInvoice(
|
||||||
|
ctxb, InvoiceRefByHash(hash),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
invoice.AddIndex = sqlInvoice.AddIndex
|
||||||
|
|
||||||
|
OverrideInvoiceTimeZone(invoice)
|
||||||
|
OverrideInvoiceTimeZone(&sqlInvoice)
|
||||||
|
|
||||||
|
require.Equal(t, *invoice, sqlInvoice)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateTestInvoiceRapid generates a random invoice with variations based on
|
||||||
|
// mpp and amp flags.
|
||||||
|
func generateTestInvoiceRapid(t *rapid.T, mpp bool, amp bool) *Invoice {
|
||||||
|
var preimage lntypes.Preimage
|
||||||
|
_, err := crand.Read(preimage[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
terms := ContractTerm{
|
||||||
|
FinalCltvDelta: rapid.Int32Range(1, 1000).Draw(
|
||||||
|
t, "FinalCltvDelta",
|
||||||
|
),
|
||||||
|
Expiry: time.Duration(
|
||||||
|
rapid.IntRange(1, 4444).Draw(t, "Expiry"),
|
||||||
|
) * time.Minute,
|
||||||
|
PaymentPreimage: &preimage,
|
||||||
|
Value: lnwire.MilliSatoshi(
|
||||||
|
rapid.Int64Range(1, 9999999).Draw(t, "Value"),
|
||||||
|
),
|
||||||
|
PaymentAddr: [32]byte{},
|
||||||
|
Features: lnwire.EmptyFeatureVector(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if amp {
|
||||||
|
terms.Features.Set(lnwire.AMPRequired)
|
||||||
|
} else if mpp {
|
||||||
|
terms.Features.Set(lnwire.MPPRequired)
|
||||||
|
}
|
||||||
|
|
||||||
|
created := randTime()
|
||||||
|
|
||||||
|
const maxContractState = 3
|
||||||
|
state := ContractState(
|
||||||
|
rapid.IntRange(0, maxContractState).Draw(t, "ContractState"),
|
||||||
|
)
|
||||||
|
var (
|
||||||
|
settled time.Time
|
||||||
|
settleIndex uint64
|
||||||
|
)
|
||||||
|
if state == ContractSettled {
|
||||||
|
settled = randTimeBetween(created, created.Add(terms.Expiry))
|
||||||
|
settleIndex = rapid.Uint64Range(1, 999).Draw(t, "SettleIndex")
|
||||||
|
}
|
||||||
|
|
||||||
|
invoice := &Invoice{
|
||||||
|
Memo: []byte(randomStringRapid(t, 10)),
|
||||||
|
PaymentRequest: []byte(
|
||||||
|
randomStringRapid(t, MaxPaymentRequestSize),
|
||||||
|
),
|
||||||
|
CreationDate: created,
|
||||||
|
SettleDate: settled,
|
||||||
|
Terms: terms,
|
||||||
|
AddIndex: 0,
|
||||||
|
SettleIndex: settleIndex,
|
||||||
|
State: state,
|
||||||
|
AMPState: make(map[SetID]InvoiceStateAMP),
|
||||||
|
HodlInvoice: rapid.Bool().Draw(t, "HodlInvoice"),
|
||||||
|
}
|
||||||
|
|
||||||
|
invoice.Htlcs = make(map[models.CircuitKey]*InvoiceHTLC)
|
||||||
|
|
||||||
|
if invoice.IsAMP() {
|
||||||
|
generateAMPHtlcsRapid(t, invoice)
|
||||||
|
} else {
|
||||||
|
generateInvoiceHTLCsRapid(t, invoice)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, htlc := range invoice.Htlcs {
|
||||||
|
if htlc.State == HtlcStateSettled {
|
||||||
|
invoice.AmtPaid += htlc.Amt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return invoice
|
||||||
|
}
|
@ -32,6 +32,10 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
|
|||||||
InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
|
InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
|
||||||
error)
|
error)
|
||||||
|
|
||||||
|
// TODO(bhandras): remove this once migrations have been separated out.
|
||||||
|
InsertMigratedInvoice(ctx context.Context,
|
||||||
|
arg sqlc.InsertMigratedInvoiceParams) (int64, error)
|
||||||
|
|
||||||
InsertInvoiceFeature(ctx context.Context,
|
InsertInvoiceFeature(ctx context.Context,
|
||||||
arg sqlc.InsertInvoiceFeatureParams) error
|
arg sqlc.InsertInvoiceFeatureParams) error
|
||||||
|
|
||||||
@ -47,6 +51,9 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
|
|||||||
GetInvoice(ctx context.Context,
|
GetInvoice(ctx context.Context,
|
||||||
arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
|
arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
|
||||||
|
|
||||||
|
GetInvoiceByHash(ctx context.Context, hash []byte) (sqlc.Invoice,
|
||||||
|
error)
|
||||||
|
|
||||||
GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
|
GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
|
||||||
error)
|
error)
|
||||||
|
|
||||||
@ -79,6 +86,10 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
|
|||||||
UpsertAMPSubInvoice(ctx context.Context,
|
UpsertAMPSubInvoice(ctx context.Context,
|
||||||
arg sqlc.UpsertAMPSubInvoiceParams) (sql.Result, error)
|
arg sqlc.UpsertAMPSubInvoiceParams) (sql.Result, error)
|
||||||
|
|
||||||
|
// TODO(bhandras): remove this once migrations have been separated out.
|
||||||
|
InsertAMPSubInvoice(ctx context.Context,
|
||||||
|
arg sqlc.InsertAMPSubInvoiceParams) error
|
||||||
|
|
||||||
UpdateAMPSubInvoiceState(ctx context.Context,
|
UpdateAMPSubInvoiceState(ctx context.Context,
|
||||||
arg sqlc.UpdateAMPSubInvoiceStateParams) error
|
arg sqlc.UpdateAMPSubInvoiceStateParams) error
|
||||||
|
|
||||||
@ -119,6 +130,19 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
|
|||||||
|
|
||||||
OnAMPSubInvoiceSettled(ctx context.Context,
|
OnAMPSubInvoiceSettled(ctx context.Context,
|
||||||
arg sqlc.OnAMPSubInvoiceSettledParams) error
|
arg sqlc.OnAMPSubInvoiceSettledParams) error
|
||||||
|
|
||||||
|
// Migration specific methods.
|
||||||
|
// TODO(bhandras): remove this once migrations have been separated out.
|
||||||
|
InsertKVInvoiceKeyAndAddIndex(ctx context.Context,
|
||||||
|
arg sqlc.InsertKVInvoiceKeyAndAddIndexParams) error
|
||||||
|
|
||||||
|
SetKVInvoicePaymentHash(ctx context.Context,
|
||||||
|
arg sqlc.SetKVInvoicePaymentHashParams) error
|
||||||
|
|
||||||
|
GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) (
|
||||||
|
[]byte, error)
|
||||||
|
|
||||||
|
ClearKVInvoiceHashIndex(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ InvoiceDB = (*SQLStore)(nil)
|
var _ InvoiceDB = (*SQLStore)(nil)
|
||||||
@ -200,6 +224,66 @@ func NewSQLStore(db BatchedSQLInvoiceQueries,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func makeInsertInvoiceParams(invoice *Invoice, paymentHash lntypes.Hash) (
|
||||||
|
sqlc.InsertInvoiceParams, error) {
|
||||||
|
|
||||||
|
// Precompute the payment request hash so we can use it in the query.
|
||||||
|
var paymentRequestHash []byte
|
||||||
|
if len(invoice.PaymentRequest) > 0 {
|
||||||
|
h := sha256.New()
|
||||||
|
h.Write(invoice.PaymentRequest)
|
||||||
|
paymentRequestHash = h.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
params := sqlc.InsertInvoiceParams{
|
||||||
|
Hash: paymentHash[:],
|
||||||
|
AmountMsat: int64(invoice.Terms.Value),
|
||||||
|
CltvDelta: sqldb.SQLInt32(
|
||||||
|
invoice.Terms.FinalCltvDelta,
|
||||||
|
),
|
||||||
|
Expiry: int32(invoice.Terms.Expiry.Seconds()),
|
||||||
|
// Note: keysend invoices don't have a payment request.
|
||||||
|
PaymentRequest: sqldb.SQLStr(string(
|
||||||
|
invoice.PaymentRequest),
|
||||||
|
),
|
||||||
|
PaymentRequestHash: paymentRequestHash,
|
||||||
|
State: int16(invoice.State),
|
||||||
|
AmountPaidMsat: int64(invoice.AmtPaid),
|
||||||
|
IsAmp: invoice.IsAMP(),
|
||||||
|
IsHodl: invoice.HodlInvoice,
|
||||||
|
IsKeysend: invoice.IsKeysend(),
|
||||||
|
CreatedAt: invoice.CreationDate.UTC(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if invoice.Memo != nil {
|
||||||
|
// Store the memo as a nullable string in the database. Note
|
||||||
|
// that for compatibility reasons, we store the value as a valid
|
||||||
|
// string even if it's empty.
|
||||||
|
params.Memo = sql.NullString{
|
||||||
|
String: string(invoice.Memo),
|
||||||
|
Valid: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some invoices may not have a preimage, like in the case of HODL
|
||||||
|
// invoices.
|
||||||
|
if invoice.Terms.PaymentPreimage != nil {
|
||||||
|
preimage := *invoice.Terms.PaymentPreimage
|
||||||
|
if preimage == UnknownPreimage {
|
||||||
|
return sqlc.InsertInvoiceParams{},
|
||||||
|
errors.New("cannot use all-zeroes preimage")
|
||||||
|
}
|
||||||
|
params.Preimage = preimage[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some non MPP payments may have the default (invalid) value.
|
||||||
|
if invoice.Terms.PaymentAddr != BlankPayAddr {
|
||||||
|
params.PaymentAddr = invoice.Terms.PaymentAddr[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return params, nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddInvoice inserts the targeted invoice into the database. If the invoice has
|
// AddInvoice inserts the targeted invoice into the database. If the invoice has
|
||||||
// *any* payment hashes which already exists within the database, then the
|
// *any* payment hashes which already exists within the database, then the
|
||||||
// insertion will be aborted and rejected due to the strict policy banning any
|
// insertion will be aborted and rejected due to the strict policy banning any
|
||||||
@ -220,55 +304,16 @@ func (i *SQLStore) AddInvoice(ctx context.Context,
|
|||||||
invoiceID int64
|
invoiceID int64
|
||||||
)
|
)
|
||||||
|
|
||||||
// Precompute the payment request hash so we can use it in the query.
|
insertInvoiceParams, err := makeInsertInvoiceParams(
|
||||||
var paymentRequestHash []byte
|
newInvoice, paymentHash,
|
||||||
if len(newInvoice.PaymentRequest) > 0 {
|
)
|
||||||
h := sha256.New()
|
if err != nil {
|
||||||
h.Write(newInvoice.PaymentRequest)
|
return 0, err
|
||||||
paymentRequestHash = h.Sum(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
|
|
||||||
params := sqlc.InsertInvoiceParams{
|
|
||||||
Hash: paymentHash[:],
|
|
||||||
Memo: sqldb.SQLStr(string(newInvoice.Memo)),
|
|
||||||
AmountMsat: int64(newInvoice.Terms.Value),
|
|
||||||
// Note: BOLT12 invoices don't have a final cltv delta.
|
|
||||||
CltvDelta: sqldb.SQLInt32(
|
|
||||||
newInvoice.Terms.FinalCltvDelta,
|
|
||||||
),
|
|
||||||
Expiry: int32(newInvoice.Terms.Expiry.Seconds()),
|
|
||||||
// Note: keysend invoices don't have a payment request.
|
|
||||||
PaymentRequest: sqldb.SQLStr(string(
|
|
||||||
newInvoice.PaymentRequest),
|
|
||||||
),
|
|
||||||
PaymentRequestHash: paymentRequestHash,
|
|
||||||
State: int16(newInvoice.State),
|
|
||||||
AmountPaidMsat: int64(newInvoice.AmtPaid),
|
|
||||||
IsAmp: newInvoice.IsAMP(),
|
|
||||||
IsHodl: newInvoice.HodlInvoice,
|
|
||||||
IsKeysend: newInvoice.IsKeysend(),
|
|
||||||
CreatedAt: newInvoice.CreationDate.UTC(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Some invoices may not have a preimage, like in the case of
|
|
||||||
// HODL invoices.
|
|
||||||
if newInvoice.Terms.PaymentPreimage != nil {
|
|
||||||
preimage := *newInvoice.Terms.PaymentPreimage
|
|
||||||
if preimage == UnknownPreimage {
|
|
||||||
return errors.New("cannot use all-zeroes " +
|
|
||||||
"preimage")
|
|
||||||
}
|
|
||||||
params.Preimage = preimage[:]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Some non MPP payments may have the default (invalid) value.
|
|
||||||
if newInvoice.Terms.PaymentAddr != BlankPayAddr {
|
|
||||||
params.PaymentAddr = newInvoice.Terms.PaymentAddr[:]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
|
||||||
var err error
|
var err error
|
||||||
invoiceID, err = db.InsertInvoice(ctx, params)
|
invoiceID, err = db.InsertInvoice(ctx, insertInvoiceParams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to insert invoice: %w", err)
|
return fmt.Errorf("unable to insert invoice: %w", err)
|
||||||
}
|
}
|
||||||
@ -312,22 +357,31 @@ func (i *SQLStore) AddInvoice(ctx context.Context,
|
|||||||
return newInvoice.AddIndex, nil
|
return newInvoice.AddIndex, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// fetchInvoice fetches the common invoice data and the AMP state for the
|
// getInvoiceByRef fetches the invoice with the given reference. The reference
|
||||||
// invoice with the given reference.
|
// may be a payment hash, a payment address, or a set ID for an AMP sub invoice.
|
||||||
func (i *SQLStore) fetchInvoice(ctx context.Context,
|
func getInvoiceByRef(ctx context.Context,
|
||||||
db SQLInvoiceQueries, ref InvoiceRef) (*Invoice, error) {
|
db SQLInvoiceQueries, ref InvoiceRef) (sqlc.Invoice, error) {
|
||||||
|
|
||||||
|
// If the reference is empty, we can't look up the invoice.
|
||||||
if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
|
if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
|
||||||
return nil, ErrInvoiceNotFound
|
return sqlc.Invoice{}, ErrInvoiceNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
// If the reference is a hash only, we can look up the invoice directly
|
||||||
invoice *Invoice
|
// by the payment hash which is faster.
|
||||||
params sqlc.GetInvoiceParams
|
if ref.IsHashOnly() {
|
||||||
)
|
invoice, err := db.GetInvoiceByHash(ctx, ref.PayHash()[:])
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return sqlc.Invoice{}, ErrInvoiceNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return invoice, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise the reference may include more fields, so we'll need to
|
||||||
|
// assemble the query parameters based on the fields that are set.
|
||||||
|
var params sqlc.GetInvoiceParams
|
||||||
|
|
||||||
// Given all invoices are uniquely identified by their payment hash,
|
|
||||||
// we can use it to query a specific invoice.
|
|
||||||
if ref.PayHash() != nil {
|
if ref.PayHash() != nil {
|
||||||
params.Hash = ref.PayHash()[:]
|
params.Hash = ref.PayHash()[:]
|
||||||
}
|
}
|
||||||
@ -363,18 +417,34 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
|
|||||||
} else {
|
} else {
|
||||||
rows, err = db.GetInvoice(ctx, params)
|
rows, err = db.GetInvoice(ctx, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case len(rows) == 0:
|
case len(rows) == 0:
|
||||||
return nil, ErrInvoiceNotFound
|
return sqlc.Invoice{}, ErrInvoiceNotFound
|
||||||
|
|
||||||
case len(rows) > 1:
|
case len(rows) > 1:
|
||||||
// In case the reference is ambiguous, meaning it matches more
|
// In case the reference is ambiguous, meaning it matches more
|
||||||
// than one invoice, we'll return an error.
|
// than one invoice, we'll return an error.
|
||||||
return nil, fmt.Errorf("ambiguous invoice ref: %s: %s",
|
return sqlc.Invoice{}, fmt.Errorf("ambiguous invoice ref: "+
|
||||||
ref.String(), spew.Sdump(rows))
|
"%s: %s", ref.String(), spew.Sdump(rows))
|
||||||
|
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return nil, fmt.Errorf("unable to fetch invoice: %w", err)
|
return sqlc.Invoice{}, fmt.Errorf("unable to fetch invoice: %w",
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rows[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchInvoice fetches the common invoice data and the AMP state for the
|
||||||
|
// invoice with the given reference.
|
||||||
|
func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, ref InvoiceRef) (
|
||||||
|
*Invoice, error) {
|
||||||
|
|
||||||
|
// Fetch the invoice from the database.
|
||||||
|
sqlInvoice, err := getInvoiceByRef(ctx, db, ref)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -391,8 +461,8 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
|
|||||||
fetchAmpHtlcs = true
|
fetchAmpHtlcs = true
|
||||||
|
|
||||||
case HtlcSetOnlyModifier:
|
case HtlcSetOnlyModifier:
|
||||||
// In this case we'll fetch all AMP HTLCs for the
|
// In this case we'll fetch all AMP HTLCs for the specified set
|
||||||
// specified set id.
|
// id.
|
||||||
if ref.SetID() == nil {
|
if ref.SetID() == nil {
|
||||||
return nil, fmt.Errorf("set ID is required to use " +
|
return nil, fmt.Errorf("set ID is required to use " +
|
||||||
"the HTLC set only modifier")
|
"the HTLC set only modifier")
|
||||||
@ -412,8 +482,8 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fetch the rest of the invoice data and fill the invoice struct.
|
// Fetch the rest of the invoice data and fill the invoice struct.
|
||||||
_, invoice, err = fetchInvoiceData(
|
_, invoice, err := fetchInvoiceData(
|
||||||
ctx, db, rows[0], setID, fetchAmpHtlcs,
|
ctx, db, sqlInvoice, setID, fetchAmpHtlcs,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -616,7 +686,7 @@ func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
|
|||||||
|
|
||||||
invoiceKeys[key] = struct{}{}
|
invoiceKeys[key] = struct{}{}
|
||||||
|
|
||||||
if htlc.State != HtlcStateCanceled { //nolint: ll
|
if htlc.State != HtlcStateCanceled {
|
||||||
amtPaid += htlc.Amt
|
amtPaid += htlc.Amt
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -646,7 +716,7 @@ func (i *SQLStore) LookupInvoice(ctx context.Context,
|
|||||||
|
|
||||||
readTxOpt := NewSQLInvoiceQueryReadTx()
|
readTxOpt := NewSQLInvoiceQueryReadTx()
|
||||||
txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
|
txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
|
||||||
invoice, err = i.fetchInvoice(ctx, db, ref)
|
invoice, err = fetchInvoice(ctx, db, ref)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}, func() {})
|
}, func() {})
|
||||||
@ -1347,7 +1417,7 @@ func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
|
|||||||
ref.refModifier = HtlcSetOnlyModifier
|
ref.refModifier = HtlcSetOnlyModifier
|
||||||
}
|
}
|
||||||
|
|
||||||
invoice, err := i.fetchInvoice(ctx, db, ref)
|
invoice, err := fetchInvoice(ctx, db, ref)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -1506,13 +1576,6 @@ func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
|
|||||||
|
|
||||||
if len(htlcs) > 0 {
|
if len(htlcs) > 0 {
|
||||||
invoice.Htlcs = htlcs
|
invoice.Htlcs = htlcs
|
||||||
var amountPaid lnwire.MilliSatoshi
|
|
||||||
for _, htlc := range htlcs {
|
|
||||||
if htlc.State == HtlcStateSettled {
|
|
||||||
amountPaid += htlc.Amt
|
|
||||||
}
|
|
||||||
}
|
|
||||||
invoice.AmtPaid = amountPaid
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return hash, invoice, nil
|
return hash, invoice, nil
|
||||||
|
BIN
invoices/testdata/channel.db
vendored
Normal file
BIN
invoices/testdata/channel.db
vendored
Normal file
Binary file not shown.
@ -626,10 +626,6 @@ var allTestCases = []*lntest.TestCase{
|
|||||||
Name: "open channel locked balance",
|
Name: "open channel locked balance",
|
||||||
TestFunc: testOpenChannelLockedBalance,
|
TestFunc: testOpenChannelLockedBalance,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Name: "nativesql no migration",
|
|
||||||
TestFunc: testNativeSQLNoMigration,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
Name: "sweep cpfp anchor outgoing timeout",
|
Name: "sweep cpfp anchor outgoing timeout",
|
||||||
TestFunc: testSweepCPFPAnchorOutgoingTimeout,
|
TestFunc: testSweepCPFPAnchorOutgoingTimeout,
|
||||||
@ -682,6 +678,10 @@ var allTestCases = []*lntest.TestCase{
|
|||||||
Name: "quiescence",
|
Name: "quiescence",
|
||||||
TestFunc: testQuiescence,
|
TestFunc: testQuiescence,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "invoice migration",
|
||||||
|
TestFunc: testInvoiceMigration,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// appendPrefixed is used to add a prefix to each test name in the subtests
|
// appendPrefixed is used to add a prefix to each test name in the subtests
|
||||||
|
303
itest/lnd_invoice_migration_test.go
Normal file
303
itest/lnd_invoice_migration_test.go
Normal file
@ -0,0 +1,303 @@
|
|||||||
|
package itest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"path"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
"github.com/lightningnetwork/lnd/clock"
|
||||||
|
"github.com/lightningnetwork/lnd/invoices"
|
||||||
|
"github.com/lightningnetwork/lnd/kvdb"
|
||||||
|
"github.com/lightningnetwork/lnd/kvdb/postgres"
|
||||||
|
"github.com/lightningnetwork/lnd/kvdb/sqlbase"
|
||||||
|
"github.com/lightningnetwork/lnd/kvdb/sqlite"
|
||||||
|
"github.com/lightningnetwork/lnd/lncfg"
|
||||||
|
"github.com/lightningnetwork/lnd/lnrpc"
|
||||||
|
"github.com/lightningnetwork/lnd/lnrpc/routerrpc"
|
||||||
|
"github.com/lightningnetwork/lnd/lntest"
|
||||||
|
"github.com/lightningnetwork/lnd/lntest/node"
|
||||||
|
"github.com/lightningnetwork/lnd/sqldb"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func openChannelDB(ht *lntest.HarnessTest, hn *node.HarnessNode) *channeldb.DB {
|
||||||
|
sqlbase.Init(0)
|
||||||
|
var (
|
||||||
|
backend kvdb.Backend
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
switch hn.Cfg.DBBackend {
|
||||||
|
case node.BackendSqlite:
|
||||||
|
backend, err = kvdb.Open(
|
||||||
|
kvdb.SqliteBackendName,
|
||||||
|
ht.Context(),
|
||||||
|
&sqlite.Config{
|
||||||
|
Timeout: defaultTimeout,
|
||||||
|
BusyTimeout: defaultTimeout,
|
||||||
|
},
|
||||||
|
hn.Cfg.DBDir(), lncfg.SqliteChannelDBName,
|
||||||
|
lncfg.NSChannelDB,
|
||||||
|
)
|
||||||
|
require.NoError(ht, err)
|
||||||
|
|
||||||
|
case node.BackendPostgres:
|
||||||
|
backend, err = kvdb.Open(
|
||||||
|
kvdb.PostgresBackendName, ht.Context(),
|
||||||
|
&postgres.Config{
|
||||||
|
Dsn: hn.Cfg.PostgresDsn,
|
||||||
|
Timeout: defaultTimeout,
|
||||||
|
}, lncfg.NSChannelDB,
|
||||||
|
)
|
||||||
|
require.NoError(ht, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := channeldb.CreateWithBackend(backend)
|
||||||
|
require.NoError(ht, err)
|
||||||
|
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func openNativeSQLInvoiceDB(ht *lntest.HarnessTest,
|
||||||
|
hn *node.HarnessNode) invoices.InvoiceDB {
|
||||||
|
|
||||||
|
var db *sqldb.BaseDB
|
||||||
|
|
||||||
|
switch hn.Cfg.DBBackend {
|
||||||
|
case node.BackendSqlite:
|
||||||
|
sqliteStore, err := sqldb.NewSqliteStore(
|
||||||
|
&sqldb.SqliteConfig{
|
||||||
|
Timeout: defaultTimeout,
|
||||||
|
BusyTimeout: defaultTimeout,
|
||||||
|
},
|
||||||
|
path.Join(
|
||||||
|
hn.Cfg.DBDir(),
|
||||||
|
lncfg.SqliteNativeDBName,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
require.NoError(ht, err)
|
||||||
|
db = sqliteStore.BaseDB
|
||||||
|
|
||||||
|
case node.BackendPostgres:
|
||||||
|
postgresStore, err := sqldb.NewPostgresStore(
|
||||||
|
&sqldb.PostgresConfig{
|
||||||
|
Dsn: hn.Cfg.PostgresDsn,
|
||||||
|
Timeout: defaultTimeout,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
require.NoError(ht, err)
|
||||||
|
db = postgresStore.BaseDB
|
||||||
|
}
|
||||||
|
|
||||||
|
executor := sqldb.NewTransactionExecutor(
|
||||||
|
db, func(tx *sql.Tx) invoices.SQLInvoiceQueries {
|
||||||
|
return db.WithTx(tx)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return invoices.NewSQLStore(
|
||||||
|
executor, clock.NewDefaultClock(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// clampTime truncates the time of the passed invoice to the microsecond level.
|
||||||
|
func clampTime(invoice *invoices.Invoice) {
|
||||||
|
trunc := func(t time.Time) time.Time {
|
||||||
|
return t.Truncate(time.Microsecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
invoice.CreationDate = trunc(invoice.CreationDate)
|
||||||
|
|
||||||
|
if !invoice.SettleDate.IsZero() {
|
||||||
|
invoice.SettleDate = trunc(invoice.SettleDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
if invoice.IsAMP() {
|
||||||
|
for setID, ampState := range invoice.AMPState {
|
||||||
|
if ampState.SettleDate.IsZero() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ampState.SettleDate = trunc(ampState.SettleDate)
|
||||||
|
invoice.AMPState[setID] = ampState
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, htlc := range invoice.Htlcs {
|
||||||
|
if !htlc.AcceptTime.IsZero() {
|
||||||
|
htlc.AcceptTime = trunc(htlc.AcceptTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !htlc.ResolveTime.IsZero() {
|
||||||
|
htlc.ResolveTime = trunc(htlc.ResolveTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testInvoiceMigration tests that the invoice migration from the old KV store
|
||||||
|
// to the new native SQL store works as expected.
|
||||||
|
func testInvoiceMigration(ht *lntest.HarnessTest) {
|
||||||
|
alice := ht.NewNodeWithCoins("Alice", nil)
|
||||||
|
bob := ht.NewNodeWithCoins("Bob", []string{"--accept-amp"})
|
||||||
|
|
||||||
|
// Make sure we run the test with SQLite or Postgres.
|
||||||
|
if bob.Cfg.DBBackend != node.BackendSqlite &&
|
||||||
|
bob.Cfg.DBBackend != node.BackendPostgres {
|
||||||
|
|
||||||
|
ht.Skip("node not running with SQLite or Postgres")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip the test if the node is already running with native SQL.
|
||||||
|
if bob.Cfg.NativeSQL {
|
||||||
|
ht.Skip("node already running with native SQL")
|
||||||
|
}
|
||||||
|
|
||||||
|
ht.EnsureConnected(alice, bob)
|
||||||
|
cp := ht.OpenChannel(
|
||||||
|
alice, bob, lntest.OpenChannelParams{
|
||||||
|
Amt: 1000000,
|
||||||
|
PushAmt: 500000,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
// Alice and bob should have one channel open with each other now.
|
||||||
|
ht.AssertNodeNumChannels(alice, 1)
|
||||||
|
ht.AssertNodeNumChannels(bob, 1)
|
||||||
|
|
||||||
|
// Step 1: Add 10 normal invoices and pay 5 of them.
|
||||||
|
normalInvoices := make([]*lnrpc.AddInvoiceResponse, 10)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
invoice := &lnrpc.Invoice{
|
||||||
|
Value: int64(1000 + i*100), // Varying amounts
|
||||||
|
IsAmp: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := bob.RPC.AddInvoice(invoice)
|
||||||
|
normalInvoices[i] = resp
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, inv := range normalInvoices {
|
||||||
|
sendReq := &routerrpc.SendPaymentRequest{
|
||||||
|
PaymentRequest: inv.PaymentRequest,
|
||||||
|
TimeoutSeconds: 60,
|
||||||
|
FeeLimitMsat: noFeeLimitMsat,
|
||||||
|
}
|
||||||
|
|
||||||
|
ht.SendPaymentAssertSettled(alice, sendReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Add 10 AMP invoices and send multiple payments to 5 of them.
|
||||||
|
ampInvoices := make([]*lnrpc.AddInvoiceResponse, 10)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
invoice := &lnrpc.Invoice{
|
||||||
|
Value: int64(2000 + i*200), // Varying amounts
|
||||||
|
IsAmp: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := bob.RPC.AddInvoice(invoice)
|
||||||
|
ampInvoices[i] = resp
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select the first 5 invoices to send multiple AMP payments.
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
inv := ampInvoices[i]
|
||||||
|
|
||||||
|
// Send 3 payments to each.
|
||||||
|
for j := 0; j < 3; j++ {
|
||||||
|
payReq := &routerrpc.SendPaymentRequest{
|
||||||
|
PaymentRequest: inv.PaymentRequest,
|
||||||
|
TimeoutSeconds: 60,
|
||||||
|
FeeLimitMsat: noFeeLimitMsat,
|
||||||
|
Amp: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send a normal AMP payment first, then a spontaneous
|
||||||
|
// AMP payment.
|
||||||
|
ht.SendPaymentAssertSettled(alice, payReq)
|
||||||
|
|
||||||
|
// Generate an external payment address when attempting
|
||||||
|
// to pseudo-reuse an AMP invoice. When using an
|
||||||
|
// external payment address, we'll also expect an extra
|
||||||
|
// invoice to appear in the ListInvoices response, since
|
||||||
|
// a new invoice will be JIT inserted under a different
|
||||||
|
// payment address than the one in the invoice.
|
||||||
|
//
|
||||||
|
// NOTE: This will only work when the peer has
|
||||||
|
// spontaneous AMP payments enabled otherwise no invoice
|
||||||
|
// under a different payment_addr will be found.
|
||||||
|
payReq.PaymentAddr = ht.Random32Bytes()
|
||||||
|
ht.SendPaymentAssertSettled(alice, payReq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We can close the channel now.
|
||||||
|
ht.CloseChannel(alice, cp)
|
||||||
|
|
||||||
|
// Now stop Bob so we can open the DB for examination.
|
||||||
|
require.NoError(ht, bob.Stop())
|
||||||
|
|
||||||
|
// Open the KV channel DB.
|
||||||
|
db := openChannelDB(ht, bob)
|
||||||
|
|
||||||
|
query := invoices.InvoiceQuery{
|
||||||
|
IndexOffset: 0,
|
||||||
|
// As a sanity check, fetch more invoices than we have
|
||||||
|
// to ensure that we did not add any extra invoices.
|
||||||
|
NumMaxInvoices: 9999,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch all invoices and make sure we have 35 in total.
|
||||||
|
result1, err := db.QueryInvoices(ht.Context(), query)
|
||||||
|
require.NoError(ht, err)
|
||||||
|
require.Equal(ht, 35, len(result1.Invoices))
|
||||||
|
numInvoices := len(result1.Invoices)
|
||||||
|
|
||||||
|
bob.SetExtraArgs([]string{"--db.use-native-sql"})
|
||||||
|
|
||||||
|
// Now run the migration flow three times to ensure that each run is
|
||||||
|
// idempotent.
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
// Start bob with the native SQL flag set. This will trigger the
|
||||||
|
// migration to run.
|
||||||
|
require.NoError(ht, bob.Start(ht.Context()))
|
||||||
|
|
||||||
|
// At this point the migration should have completed and the
|
||||||
|
// node should be running with native SQL. Now we'll stop Bob
|
||||||
|
// again so we can safely examine the database.
|
||||||
|
require.NoError(ht, bob.Stop())
|
||||||
|
|
||||||
|
// Now we'll open the database with the native SQL backend and
|
||||||
|
// fetch the invoices again to ensure that they were migrated
|
||||||
|
// correctly.
|
||||||
|
sqlInvoiceDB := openNativeSQLInvoiceDB(ht, bob)
|
||||||
|
result2, err := sqlInvoiceDB.QueryInvoices(ht.Context(), query)
|
||||||
|
require.NoError(ht, err)
|
||||||
|
|
||||||
|
require.Equal(ht, numInvoices, len(result2.Invoices))
|
||||||
|
|
||||||
|
// Simply zero out the add index so we don't fail on that when
|
||||||
|
// comparing.
|
||||||
|
for i := 0; i < numInvoices; i++ {
|
||||||
|
result1.Invoices[i].AddIndex = 0
|
||||||
|
result2.Invoices[i].AddIndex = 0
|
||||||
|
|
||||||
|
// Clamp the precision to microseconds. Note that we
|
||||||
|
// need to override both invoices as the original
|
||||||
|
// invoice is coming from KV database, it was stored as
|
||||||
|
// a binary serialized Go time.Time value which has
|
||||||
|
// nanosecond precision. The migrated invoice is stored
|
||||||
|
// in SQL in PostgreSQL has microsecond precision while
|
||||||
|
// in SQLite it has nanosecond precision if using TEXT
|
||||||
|
// storage class.
|
||||||
|
clampTime(&result1.Invoices[i])
|
||||||
|
clampTime(&result2.Invoices[i])
|
||||||
|
require.Equal(
|
||||||
|
ht, result1.Invoices[i], result2.Invoices[i],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start Bob again so the test can complete.
|
||||||
|
require.NoError(ht, bob.Start(ht.Context()))
|
||||||
|
}
|
@ -1,7 +1,6 @@
|
|||||||
package itest
|
package itest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
@ -1243,44 +1242,6 @@ func testSignVerifyMessageWithAddr(ht *lntest.HarnessTest) {
|
|||||||
require.False(ht, respValid.Valid, "external signature did validate")
|
require.False(ht, respValid.Valid, "external signature did validate")
|
||||||
}
|
}
|
||||||
|
|
||||||
// testNativeSQLNoMigration tests that nodes that have invoices would not start
|
|
||||||
// up with native SQL enabled, as we don't currently support migration of KV
|
|
||||||
// invoices to the new SQL schema.
|
|
||||||
func testNativeSQLNoMigration(ht *lntest.HarnessTest) {
|
|
||||||
alice := ht.NewNode("Alice", nil)
|
|
||||||
|
|
||||||
// Make sure we run the test with SQLite or Postgres.
|
|
||||||
if alice.Cfg.DBBackend != node.BackendSqlite &&
|
|
||||||
alice.Cfg.DBBackend != node.BackendPostgres {
|
|
||||||
|
|
||||||
ht.Skip("node not running with SQLite or Postgres")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip the test if the node is already running with native SQL.
|
|
||||||
if alice.Cfg.NativeSQL {
|
|
||||||
ht.Skip("node already running with native SQL")
|
|
||||||
}
|
|
||||||
|
|
||||||
alice.RPC.AddInvoice(&lnrpc.Invoice{
|
|
||||||
Value: 10_000,
|
|
||||||
})
|
|
||||||
|
|
||||||
alice.SetExtraArgs([]string{"--db.use-native-sql"})
|
|
||||||
|
|
||||||
// Restart the node manually as we're really only interested in the
|
|
||||||
// startup error.
|
|
||||||
require.NoError(ht, alice.Stop())
|
|
||||||
require.NoError(ht, alice.StartLndCmd(context.Background()))
|
|
||||||
|
|
||||||
// We expect the node to fail to start up with native SQL enabled, as we
|
|
||||||
// have an invoice in the KV store.
|
|
||||||
require.Error(ht, alice.WaitForProcessExit())
|
|
||||||
|
|
||||||
// Reset the extra args and restart alice.
|
|
||||||
alice.SetExtraArgs(nil)
|
|
||||||
require.NoError(ht, alice.Start(ht.Context()))
|
|
||||||
}
|
|
||||||
|
|
||||||
// testSendSelectedCoins tests that we're able to properly send the selected
|
// testSendSelectedCoins tests that we're able to properly send the selected
|
||||||
// coins from the wallet to a single target address.
|
// coins from the wallet to a single target address.
|
||||||
func testSendSelectedCoins(ht *lntest.HarnessTest) {
|
func testSendSelectedCoins(ht *lntest.HarnessTest) {
|
||||||
|
19
lncfg/db.go
19
lncfg/db.go
@ -87,6 +87,8 @@ type DB struct {
|
|||||||
|
|
||||||
UseNativeSQL bool `long:"use-native-sql" description:"Use native SQL for tables that already support it."`
|
UseNativeSQL bool `long:"use-native-sql" description:"Use native SQL for tables that already support it."`
|
||||||
|
|
||||||
|
SkipSQLInvoiceMigration bool `long:"skip-sql-invoice-migration" description:"Do not migrate invoices stored in our key-value database to native SQL."`
|
||||||
|
|
||||||
NoGraphCache bool `long:"no-graph-cache" description:"Don't use the in-memory graph cache for path finding. Much slower but uses less RAM. Can only be used with a bolt database backend."`
|
NoGraphCache bool `long:"no-graph-cache" description:"Don't use the in-memory graph cache for path finding. Much slower but uses less RAM. Can only be used with a bolt database backend."`
|
||||||
|
|
||||||
PruneRevocation bool `long:"prune-revocation" description:"Run the optional migration that prunes the revocation logs to save disk space."`
|
PruneRevocation bool `long:"prune-revocation" description:"Run the optional migration that prunes the revocation logs to save disk space."`
|
||||||
@ -116,6 +118,7 @@ func DefaultDB() *DB {
|
|||||||
BusyTimeout: defaultSqliteBusyTimeout,
|
BusyTimeout: defaultSqliteBusyTimeout,
|
||||||
},
|
},
|
||||||
UseNativeSQL: false,
|
UseNativeSQL: false,
|
||||||
|
SkipSQLInvoiceMigration: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -231,10 +234,10 @@ type DatabaseBackends struct {
|
|||||||
// the underlying wallet database from.
|
// the underlying wallet database from.
|
||||||
WalletDB btcwallet.LoaderOption
|
WalletDB btcwallet.LoaderOption
|
||||||
|
|
||||||
// NativeSQLStore is a pointer to a native SQL store that can be used
|
// NativeSQLStore holds a reference to the native SQL store that can
|
||||||
// for native SQL queries for tables that already support it. This may
|
// be used for native SQL queries for tables that already support it.
|
||||||
// be nil if the use-native-sql flag was not set.
|
// This may be nil if the use-native-sql flag was not set.
|
||||||
NativeSQLStore *sqldb.BaseDB
|
NativeSQLStore sqldb.DB
|
||||||
|
|
||||||
// Remote indicates whether the database backends are remote, possibly
|
// Remote indicates whether the database backends are remote, possibly
|
||||||
// replicated instances or local bbolt or sqlite backed databases.
|
// replicated instances or local bbolt or sqlite backed databases.
|
||||||
@ -449,7 +452,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath,
|
|||||||
}
|
}
|
||||||
closeFuncs[NSWalletDB] = postgresWalletBackend.Close
|
closeFuncs[NSWalletDB] = postgresWalletBackend.Close
|
||||||
|
|
||||||
var nativeSQLStore *sqldb.BaseDB
|
var nativeSQLStore sqldb.DB
|
||||||
if db.UseNativeSQL {
|
if db.UseNativeSQL {
|
||||||
nativePostgresStore, err := sqldb.NewPostgresStore(
|
nativePostgresStore, err := sqldb.NewPostgresStore(
|
||||||
db.Postgres,
|
db.Postgres,
|
||||||
@ -459,7 +462,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath,
|
|||||||
"native postgres store: %v", err)
|
"native postgres store: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
nativeSQLStore = nativePostgresStore.BaseDB
|
nativeSQLStore = nativePostgresStore
|
||||||
closeFuncs[PostgresBackend] = nativePostgresStore.Close
|
closeFuncs[PostgresBackend] = nativePostgresStore.Close
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -571,7 +574,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath,
|
|||||||
}
|
}
|
||||||
closeFuncs[NSWalletDB] = sqliteWalletBackend.Close
|
closeFuncs[NSWalletDB] = sqliteWalletBackend.Close
|
||||||
|
|
||||||
var nativeSQLStore *sqldb.BaseDB
|
var nativeSQLStore sqldb.DB
|
||||||
if db.UseNativeSQL {
|
if db.UseNativeSQL {
|
||||||
nativeSQLiteStore, err := sqldb.NewSqliteStore(
|
nativeSQLiteStore, err := sqldb.NewSqliteStore(
|
||||||
db.Sqlite,
|
db.Sqlite,
|
||||||
@ -582,7 +585,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath,
|
|||||||
"native SQLite store: %v", err)
|
"native SQLite store: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
nativeSQLStore = nativeSQLiteStore.BaseDB
|
nativeSQLStore = nativeSQLiteStore
|
||||||
closeFuncs[SqliteBackend] = nativeSQLiteStore.Close
|
closeFuncs[SqliteBackend] = nativeSQLiteStore.Close
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1472,6 +1472,9 @@
|
|||||||
; own risk.
|
; own risk.
|
||||||
; db.use-native-sql=false
|
; db.use-native-sql=false
|
||||||
|
|
||||||
|
; If set to true, native SQL invoice migration will be skipped. Note that this
|
||||||
|
; option is intended for users who experience non-resolvable migration errors.
|
||||||
|
; db.skip-sql-invoice-migration=false
|
||||||
|
|
||||||
[etcd]
|
[etcd]
|
||||||
|
|
||||||
|
@ -2,12 +2,40 @@
|
|||||||
|
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
|
# restore_files is a function to restore original schema files.
|
||||||
|
restore_files() {
|
||||||
|
echo "Restoring SQLite bigint patch..."
|
||||||
|
for file in sqldb/sqlc/migrations/*.up.sql.bak; do
|
||||||
|
mv "$file" "${file%.bak}"
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Set trap to call restore_files on script exit. This makes sure the old files
|
||||||
|
# are always restored.
|
||||||
|
trap restore_files EXIT
|
||||||
|
|
||||||
# Directory of the script file, independent of where it's called from.
|
# Directory of the script file, independent of where it's called from.
|
||||||
DIR="$(cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd)"
|
DIR="$(cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd)"
|
||||||
# Use the user's cache directories
|
# Use the user's cache directories
|
||||||
GOCACHE=`go env GOCACHE`
|
GOCACHE=`go env GOCACHE`
|
||||||
GOMODCACHE=`go env GOMODCACHE`
|
GOMODCACHE=`go env GOMODCACHE`
|
||||||
|
|
||||||
|
# SQLite doesn't support "BIGINT PRIMARY KEY" for auto-incrementing primary
|
||||||
|
# keys, only "INTEGER PRIMARY KEY". Internally it uses 64-bit integers for
|
||||||
|
# numbers anyway, independent of the column type. So we can just use
|
||||||
|
# "INTEGER PRIMARY KEY" and it will work the same under the hood, giving us
|
||||||
|
# auto incrementing 64-bit integers.
|
||||||
|
# _BUT_, sqlc will generate Go code with int32 if we use "INTEGER PRIMARY KEY",
|
||||||
|
# even though we want int64. So before we run sqlc, we need to patch the
|
||||||
|
# source schema SQL files to use "BIGINT PRIMARY KEY" instead of "INTEGER
|
||||||
|
# PRIMARY KEY".
|
||||||
|
echo "Applying SQLite bigint patch..."
|
||||||
|
for file in sqldb/sqlc/migrations/*.up.sql; do
|
||||||
|
echo "Patching $file"
|
||||||
|
sed -i.bak -E 's/INTEGER PRIMARY KEY/BIGINT PRIMARY KEY/g' "$file"
|
||||||
|
done
|
||||||
|
|
||||||
echo "Generating sql models and queries in go..."
|
echo "Generating sql models and queries in go..."
|
||||||
|
|
||||||
docker run \
|
docker run \
|
||||||
|
@ -355,6 +355,18 @@ func (t *TransactionExecutor[Q]) ExecTx(ctx context.Context,
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DB is an interface that represents a generic SQL database. It provides
|
||||||
|
// methods to apply migrations and access the underlying database connection.
|
||||||
|
type DB interface {
|
||||||
|
// GetBaseDB returns the underlying BaseDB instance.
|
||||||
|
GetBaseDB() *BaseDB
|
||||||
|
|
||||||
|
// ApplyAllMigrations applies all migrations to the database including
|
||||||
|
// both sqlc and custom in-code migrations.
|
||||||
|
ApplyAllMigrations(ctx context.Context,
|
||||||
|
customMigrations []MigrationConfig) error
|
||||||
|
}
|
||||||
|
|
||||||
// BaseDB is the base database struct that each implementation can embed to
|
// BaseDB is the base database struct that each implementation can embed to
|
||||||
// gain some common functionality.
|
// gain some common functionality.
|
||||||
type BaseDB struct {
|
type BaseDB struct {
|
||||||
|
@ -2,22 +2,118 @@ package sqldb
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/btcsuite/btclog/v2"
|
"github.com/btcsuite/btclog/v2"
|
||||||
"github.com/golang-migrate/migrate/v4"
|
"github.com/golang-migrate/migrate/v4"
|
||||||
"github.com/golang-migrate/migrate/v4/database"
|
"github.com/golang-migrate/migrate/v4/database"
|
||||||
"github.com/golang-migrate/migrate/v4/source/httpfs"
|
"github.com/golang-migrate/migrate/v4/source/httpfs"
|
||||||
|
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// migrationConfig defines a list of migrations to be applied to the
|
||||||
|
// database. Each migration is assigned a version number, determining
|
||||||
|
// its execution order.
|
||||||
|
// The schema version, tracked by golang-migrate, ensures migrations are
|
||||||
|
// applied to the correct schema. For migrations involving only schema
|
||||||
|
// changes, the migration function can be left nil. For custom
|
||||||
|
// migrations an implemented migration function is required.
|
||||||
|
//
|
||||||
|
// NOTE: The migration function may have runtime dependencies, which
|
||||||
|
// must be injected during runtime.
|
||||||
|
migrationConfig = []MigrationConfig{
|
||||||
|
{
|
||||||
|
Name: "000001_invoices",
|
||||||
|
Version: 1,
|
||||||
|
SchemaVersion: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "000002_amp_invoices",
|
||||||
|
Version: 2,
|
||||||
|
SchemaVersion: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "000003_invoice_events",
|
||||||
|
Version: 3,
|
||||||
|
SchemaVersion: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "000004_invoice_expiry_fix",
|
||||||
|
Version: 4,
|
||||||
|
SchemaVersion: 4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "000005_migration_tracker",
|
||||||
|
Version: 5,
|
||||||
|
SchemaVersion: 5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "000006_invoice_migration",
|
||||||
|
Version: 6,
|
||||||
|
SchemaVersion: 6,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "kv_invoice_migration",
|
||||||
|
Version: 7,
|
||||||
|
SchemaVersion: 6,
|
||||||
|
// A migration function is may be attached to this
|
||||||
|
// migration to migrate KV invoices to the native SQL
|
||||||
|
// schema. This is optional and can be disabled by the
|
||||||
|
// user if necessary.
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// MigrationConfig is a configuration struct that describes SQL migrations. Each
|
||||||
|
// migration is associated with a specific schema version and a global database
|
||||||
|
// version. Migrations are applied in the order of their global database
|
||||||
|
// version. If a migration includes a non-nil MigrationFn, it is executed after
|
||||||
|
// the SQL schema has been migrated to the corresponding schema version.
|
||||||
|
type MigrationConfig struct {
|
||||||
|
// Name is the name of the migration.
|
||||||
|
Name string
|
||||||
|
|
||||||
|
// Version represents the "global" database version for this migration.
|
||||||
|
// Unlike the schema version tracked by golang-migrate, it encompasses
|
||||||
|
// all migrations, including those managed by golang-migrate as well
|
||||||
|
// as custom in-code migrations.
|
||||||
|
Version int
|
||||||
|
|
||||||
|
// SchemaVersion represents the schema version tracked by golang-migrate
|
||||||
|
// at which the migration is applied.
|
||||||
|
SchemaVersion int
|
||||||
|
|
||||||
|
// MigrationFn is the function executed for custom migrations at the
|
||||||
|
// specified version. It is used to handle migrations that cannot be
|
||||||
|
// performed through SQL alone. If set to nil, no custom migration is
|
||||||
|
// applied.
|
||||||
|
MigrationFn func(tx *sqlc.Queries) error
|
||||||
|
}
|
||||||
|
|
||||||
// MigrationTarget is a functional option that can be passed to applyMigrations
|
// MigrationTarget is a functional option that can be passed to applyMigrations
|
||||||
// to specify a target version to migrate to.
|
// to specify a target version to migrate to.
|
||||||
type MigrationTarget func(mig *migrate.Migrate) error
|
type MigrationTarget func(mig *migrate.Migrate) error
|
||||||
|
|
||||||
|
// MigrationExecutor is an interface that abstracts the migration functionality.
|
||||||
|
type MigrationExecutor interface {
|
||||||
|
// ExecuteMigrations runs database migrations up to the specified target
|
||||||
|
// version or all migrations if no target is specified. A migration may
|
||||||
|
// include a schema change, a custom migration function, or both.
|
||||||
|
// Developers must ensure that migrations are defined in the correct
|
||||||
|
// order. Migration details are stored in the global variable
|
||||||
|
// migrationConfig.
|
||||||
|
ExecuteMigrations(target MigrationTarget) error
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// TargetLatest is a MigrationTarget that migrates to the latest
|
// TargetLatest is a MigrationTarget that migrates to the latest
|
||||||
// version available.
|
// version available.
|
||||||
@ -34,6 +130,14 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// GetMigrations returns a copy of the migration configuration.
|
||||||
|
func GetMigrations() []MigrationConfig {
|
||||||
|
migrations := make([]MigrationConfig, len(migrationConfig))
|
||||||
|
copy(migrations, migrationConfig)
|
||||||
|
|
||||||
|
return migrations
|
||||||
|
}
|
||||||
|
|
||||||
// migrationLogger is a logger that wraps the passed btclog.Logger so it can be
|
// migrationLogger is a logger that wraps the passed btclog.Logger so it can be
|
||||||
// used to log migrations.
|
// used to log migrations.
|
||||||
type migrationLogger struct {
|
type migrationLogger struct {
|
||||||
@ -216,3 +320,117 @@ func (t *replacerFile) Close() error {
|
|||||||
// instance, so there's nothing to do for us here.
|
// instance, so there's nothing to do for us here.
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MigrationTxOptions is the implementation of the TxOptions interface for
|
||||||
|
// migration transactions.
|
||||||
|
type MigrationTxOptions struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadOnly returns false to indicate that migration transactions are not read
|
||||||
|
// only.
|
||||||
|
func (m *MigrationTxOptions) ReadOnly() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyMigrations applies the provided migrations to the database in sequence.
|
||||||
|
// It ensures migrations are executed in the correct order, applying both custom
|
||||||
|
// migration functions and SQL migrations as needed.
|
||||||
|
func ApplyMigrations(ctx context.Context, db *BaseDB,
|
||||||
|
migrator MigrationExecutor, migrations []MigrationConfig) error {
|
||||||
|
|
||||||
|
// Ensure that the migrations are sorted by version.
|
||||||
|
for i := 0; i < len(migrations); i++ {
|
||||||
|
if migrations[i].Version != i+1 {
|
||||||
|
return fmt.Errorf("migration version %d is out of "+
|
||||||
|
"order. Expected %d", migrations[i].Version,
|
||||||
|
i+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Construct a transaction executor to apply custom migrations.
|
||||||
|
executor := NewTransactionExecutor(db, func(tx *sql.Tx) *sqlc.Queries {
|
||||||
|
return db.WithTx(tx)
|
||||||
|
})
|
||||||
|
|
||||||
|
currentVersion := 0
|
||||||
|
version, err := db.GetDatabaseVersion(ctx)
|
||||||
|
if !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting current database "+
|
||||||
|
"version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
currentVersion = int(version)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, migration := range migrations {
|
||||||
|
if migration.Version <= currentVersion {
|
||||||
|
log.Infof("Skipping migration '%s' (version %d) as it "+
|
||||||
|
"has already been applied", migration.Name,
|
||||||
|
migration.Version)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Migrating SQL schema to version %d",
|
||||||
|
migration.SchemaVersion)
|
||||||
|
|
||||||
|
// Execute SQL schema migrations up to the target version.
|
||||||
|
err = migrator.ExecuteMigrations(
|
||||||
|
TargetVersion(uint(migration.SchemaVersion)),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error executing schema migrations "+
|
||||||
|
"to target version %d: %w",
|
||||||
|
migration.SchemaVersion, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var opts MigrationTxOptions
|
||||||
|
|
||||||
|
// Run the custom migration as a transaction to ensure
|
||||||
|
// atomicity. If successful, mark the migration as complete in
|
||||||
|
// the migration tracker table.
|
||||||
|
err = executor.ExecTx(ctx, &opts, func(tx *sqlc.Queries) error {
|
||||||
|
// Apply the migration function if one is provided.
|
||||||
|
if migration.MigrationFn != nil {
|
||||||
|
log.Infof("Applying custom migration '%v' "+
|
||||||
|
"(version %d) to schema version %d",
|
||||||
|
migration.Name, migration.Version,
|
||||||
|
migration.SchemaVersion)
|
||||||
|
|
||||||
|
err = migration.MigrationFn(tx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error applying "+
|
||||||
|
"migration '%v' (version %d) "+
|
||||||
|
"to schema version %d: %w",
|
||||||
|
migration.Name,
|
||||||
|
migration.Version,
|
||||||
|
migration.SchemaVersion, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Migration '%v' (version %d) "+
|
||||||
|
"applied ", migration.Name,
|
||||||
|
migration.Version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark the migration as complete by adding the version
|
||||||
|
// to the migration tracker table along with the current
|
||||||
|
// timestamp.
|
||||||
|
err = tx.SetMigration(ctx, sqlc.SetMigrationParams{
|
||||||
|
Version: int32(migration.Version),
|
||||||
|
MigrationTime: time.Now(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error setting migration "+
|
||||||
|
"version %d: %w", migration.Version,
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}, func() {})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -2,8 +2,15 @@ package sqldb
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang-migrate/migrate/v4"
|
||||||
|
"github.com/golang-migrate/migrate/v4/database"
|
||||||
|
pgx_migrate "github.com/golang-migrate/migrate/v4/database/pgx/v5"
|
||||||
|
sqlite_migrate "github.com/golang-migrate/migrate/v4/database/sqlite"
|
||||||
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@ -152,3 +159,296 @@ func testInvoiceExpiryMigration(t *testing.T, makeDB makeMigrationTestDB) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, expected, invoices)
|
require.Equal(t, expected, invoices)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestCustomMigration tests that a custom in-code migrations are correctly
|
||||||
|
// executed during the migration process.
|
||||||
|
func TestCustomMigration(t *testing.T) {
|
||||||
|
var customMigrationLog []string
|
||||||
|
|
||||||
|
logMigration := func(name string) {
|
||||||
|
customMigrationLog = append(customMigrationLog, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some migrations to use for both the failure and success tests. Note
|
||||||
|
// that the migrations are not in order to test that they are executed
|
||||||
|
// in the correct order.
|
||||||
|
migrations := []MigrationConfig{
|
||||||
|
{
|
||||||
|
Name: "1",
|
||||||
|
Version: 1,
|
||||||
|
SchemaVersion: 1,
|
||||||
|
MigrationFn: func(*sqlc.Queries) error {
|
||||||
|
logMigration("1")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "2",
|
||||||
|
Version: 2,
|
||||||
|
SchemaVersion: 1,
|
||||||
|
MigrationFn: func(*sqlc.Queries) error {
|
||||||
|
logMigration("2")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "3",
|
||||||
|
Version: 3,
|
||||||
|
SchemaVersion: 2,
|
||||||
|
MigrationFn: func(*sqlc.Queries) error {
|
||||||
|
logMigration("3")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
migrations []MigrationConfig
|
||||||
|
expectedSuccess bool
|
||||||
|
expectedMigrationLog []string
|
||||||
|
expectedSchemaVersion int
|
||||||
|
expectedVersion int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success",
|
||||||
|
migrations: migrations,
|
||||||
|
expectedSuccess: true,
|
||||||
|
expectedMigrationLog: []string{"1", "2", "3"},
|
||||||
|
expectedSchemaVersion: 2,
|
||||||
|
expectedVersion: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unordered migrations",
|
||||||
|
migrations: append([]MigrationConfig{
|
||||||
|
{
|
||||||
|
Name: "4",
|
||||||
|
Version: 4,
|
||||||
|
SchemaVersion: 3,
|
||||||
|
MigrationFn: func(*sqlc.Queries) error {
|
||||||
|
logMigration("4")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, migrations...),
|
||||||
|
expectedSuccess: false,
|
||||||
|
expectedMigrationLog: nil,
|
||||||
|
expectedSchemaVersion: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "failure of migration 4",
|
||||||
|
migrations: append(migrations, MigrationConfig{
|
||||||
|
Name: "4",
|
||||||
|
Version: 4,
|
||||||
|
SchemaVersion: 3,
|
||||||
|
MigrationFn: func(*sqlc.Queries) error {
|
||||||
|
return fmt.Errorf("migration 4 failed")
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
expectedSuccess: false,
|
||||||
|
expectedMigrationLog: []string{"1", "2", "3"},
|
||||||
|
// Since schema migration is a separate step we expect
|
||||||
|
// that migrating up to 3 succeeded.
|
||||||
|
expectedSchemaVersion: 3,
|
||||||
|
// We still remain on version 3 though.
|
||||||
|
expectedVersion: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success of migration 4",
|
||||||
|
migrations: append(migrations, MigrationConfig{
|
||||||
|
Name: "4",
|
||||||
|
Version: 4,
|
||||||
|
SchemaVersion: 3,
|
||||||
|
MigrationFn: func(*sqlc.Queries) error {
|
||||||
|
logMigration("4")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
expectedSuccess: true,
|
||||||
|
expectedMigrationLog: []string{"1", "2", "3", "4"},
|
||||||
|
expectedSchemaVersion: 3,
|
||||||
|
expectedVersion: 4,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctxb := context.Background()
|
||||||
|
for _, test := range tests {
|
||||||
|
// checkSchemaVersion checks the database schema version against
|
||||||
|
// the expected version.
|
||||||
|
getSchemaVersion := func(t *testing.T,
|
||||||
|
driver database.Driver, dbName string) int {
|
||||||
|
|
||||||
|
sqlMigrate, err := migrate.NewWithInstance(
|
||||||
|
"migrations", nil, dbName, driver,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
version, _, err := sqlMigrate.Version()
|
||||||
|
if err != migrate.ErrNilVersion {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(version)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("SQLite "+test.name, func(t *testing.T) {
|
||||||
|
customMigrationLog = nil
|
||||||
|
|
||||||
|
// First instantiate the database and run the migrations
|
||||||
|
// including the custom migrations.
|
||||||
|
t.Logf("Creating new SQLite DB for testing migrations")
|
||||||
|
|
||||||
|
dbFileName := filepath.Join(t.TempDir(), "tmp.db")
|
||||||
|
var (
|
||||||
|
db *SqliteStore
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
// Run the migration 3 times to test that the migrations
|
||||||
|
// are idempotent.
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
db, err = NewSqliteStore(&SqliteConfig{
|
||||||
|
SkipMigrations: false,
|
||||||
|
}, dbFileName)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
dbToCleanup := db.DB
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(
|
||||||
|
t, dbToCleanup.Close(),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
err = db.ApplyAllMigrations(
|
||||||
|
ctxb, test.migrations,
|
||||||
|
)
|
||||||
|
if test.expectedSuccess {
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
// Also repoen the DB without migrations
|
||||||
|
// so we can read versions.
|
||||||
|
db, err = NewSqliteStore(&SqliteConfig{
|
||||||
|
SkipMigrations: true,
|
||||||
|
}, dbFileName)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t,
|
||||||
|
test.expectedMigrationLog,
|
||||||
|
customMigrationLog,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create the migration executor to be able to
|
||||||
|
// query the current schema version.
|
||||||
|
driver, err := sqlite_migrate.WithInstance(
|
||||||
|
db.DB, &sqlite_migrate.Config{},
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(
|
||||||
|
t, test.expectedSchemaVersion,
|
||||||
|
getSchemaVersion(t, driver, ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check the migraton version in the database.
|
||||||
|
version, err := db.GetDatabaseVersion(ctxb)
|
||||||
|
if test.expectedSchemaVersion != 0 {
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
require.Equal(t, sql.ErrNoRows, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(
|
||||||
|
t, test.expectedVersion, int(version),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Postgres "+test.name, func(t *testing.T) {
|
||||||
|
customMigrationLog = nil
|
||||||
|
|
||||||
|
// First create a temporary Postgres database to run
|
||||||
|
// the migrations on.
|
||||||
|
fixture := NewTestPgFixture(
|
||||||
|
t, DefaultPostgresFixtureLifetime,
|
||||||
|
)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
fixture.TearDown(t)
|
||||||
|
})
|
||||||
|
|
||||||
|
dbName := randomDBName(t)
|
||||||
|
|
||||||
|
// Next instantiate the database and run the migrations
|
||||||
|
// including the custom migrations.
|
||||||
|
t.Logf("Creating new Postgres DB '%s' for testing "+
|
||||||
|
"migrations", dbName)
|
||||||
|
|
||||||
|
_, err := fixture.db.ExecContext(
|
||||||
|
context.Background(), "CREATE DATABASE "+dbName,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cfg := fixture.GetConfig(dbName)
|
||||||
|
var db *PostgresStore
|
||||||
|
|
||||||
|
// Run the migration 3 times to test that the migrations
|
||||||
|
// are idempotent.
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
cfg.SkipMigrations = false
|
||||||
|
db, err = NewPostgresStore(cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = db.ApplyAllMigrations(
|
||||||
|
ctxb, test.migrations,
|
||||||
|
)
|
||||||
|
if test.expectedSuccess {
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
// Also repoen the DB without migrations
|
||||||
|
// so we can read versions.
|
||||||
|
cfg.SkipMigrations = true
|
||||||
|
db, err = NewPostgresStore(cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t,
|
||||||
|
test.expectedMigrationLog,
|
||||||
|
customMigrationLog,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create the migration executor to be able to
|
||||||
|
// query the current version.
|
||||||
|
driver, err := pgx_migrate.WithInstance(
|
||||||
|
db.DB, &pgx_migrate.Config{},
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(
|
||||||
|
t, test.expectedSchemaVersion,
|
||||||
|
getSchemaVersion(t, driver, ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check the migraton version in the database.
|
||||||
|
version, err := db.GetDatabaseVersion(ctxb)
|
||||||
|
if test.expectedSchemaVersion != 0 {
|
||||||
|
require.NoError(t, err)
|
||||||
|
} else {
|
||||||
|
require.Equal(t, sql.ErrNoRows, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(
|
||||||
|
t, test.expectedVersion, int(version),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -2,7 +2,15 @@
|
|||||||
|
|
||||||
package sqldb
|
package sqldb
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// Make sure SqliteStore implements the DB interface.
|
||||||
|
_ DB = (*SqliteStore)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
// SqliteStore is a database store implementation that uses a sqlite backend.
|
// SqliteStore is a database store implementation that uses a sqlite backend.
|
||||||
type SqliteStore struct {
|
type SqliteStore struct {
|
||||||
@ -16,3 +24,17 @@ type SqliteStore struct {
|
|||||||
func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
|
func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
|
||||||
return nil, fmt.Errorf("SQLite backend not supported in WebAssembly")
|
return nil, fmt.Errorf("SQLite backend not supported in WebAssembly")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetBaseDB returns the underlying BaseDB instance for the SQLite store.
|
||||||
|
// It is a trivial helper method to comply with the sqldb.DB interface.
|
||||||
|
func (s *SqliteStore) GetBaseDB() *BaseDB {
|
||||||
|
return s.BaseDB
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyAllMigrations applies both the SQLC and custom in-code migrations to
|
||||||
|
// the SQLite database.
|
||||||
|
func (s *SqliteStore) ApplyAllMigrations(context.Context,
|
||||||
|
[]MigrationConfig) error {
|
||||||
|
|
||||||
|
return fmt.Errorf("SQLite backend not supported in WebAssembly")
|
||||||
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package sqldb
|
package sqldb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -28,10 +29,15 @@ var (
|
|||||||
// has some differences.
|
// has some differences.
|
||||||
postgresSchemaReplacements = map[string]string{
|
postgresSchemaReplacements = map[string]string{
|
||||||
"BLOB": "BYTEA",
|
"BLOB": "BYTEA",
|
||||||
"INTEGER PRIMARY KEY": "SERIAL PRIMARY KEY",
|
"INTEGER PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
|
||||||
"BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
|
|
||||||
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
|
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Make sure PostgresStore implements the MigrationExecutor interface.
|
||||||
|
_ MigrationExecutor = (*PostgresStore)(nil)
|
||||||
|
|
||||||
|
// Make sure PostgresStore implements the DB interface.
|
||||||
|
_ DB = (*PostgresStore)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
// replacePasswordInDSN takes a DSN string and returns it with the password
|
// replacePasswordInDSN takes a DSN string and returns it with the password
|
||||||
@ -92,40 +98,64 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) {
|
|||||||
}
|
}
|
||||||
log.Infof("Using SQL database '%s'", sanitizedDSN)
|
log.Infof("Using SQL database '%s'", sanitizedDSN)
|
||||||
|
|
||||||
rawDB, err := sql.Open("pgx", cfg.Dsn)
|
db, err := sql.Open("pgx", cfg.Dsn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure the migration tracker table exists before running migrations.
|
||||||
|
// This table tracks migration progress and ensures compatibility with
|
||||||
|
// SQLC query generation. If the table is already created by an SQLC
|
||||||
|
// migration, this operation becomes a no-op.
|
||||||
|
migrationTrackerSQL := `
|
||||||
|
CREATE TABLE IF NOT EXISTS migration_tracker (
|
||||||
|
version INTEGER UNIQUE NOT NULL,
|
||||||
|
migration_time TIMESTAMP NOT NULL
|
||||||
|
);`
|
||||||
|
|
||||||
|
_, err = db.Exec(migrationTrackerSQL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error creating migration tracker: %w",
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
maxConns := defaultMaxConns
|
maxConns := defaultMaxConns
|
||||||
if cfg.MaxConnections > 0 {
|
if cfg.MaxConnections > 0 {
|
||||||
maxConns = cfg.MaxConnections
|
maxConns = cfg.MaxConnections
|
||||||
}
|
}
|
||||||
|
|
||||||
rawDB.SetMaxOpenConns(maxConns)
|
db.SetMaxOpenConns(maxConns)
|
||||||
rawDB.SetMaxIdleConns(maxConns)
|
db.SetMaxIdleConns(maxConns)
|
||||||
rawDB.SetConnMaxLifetime(connIdleLifetime)
|
db.SetConnMaxLifetime(connIdleLifetime)
|
||||||
|
|
||||||
queries := sqlc.New(rawDB)
|
queries := sqlc.New(db)
|
||||||
|
|
||||||
s := &PostgresStore{
|
return &PostgresStore{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
BaseDB: &BaseDB{
|
BaseDB: &BaseDB{
|
||||||
DB: rawDB,
|
DB: db,
|
||||||
Queries: queries,
|
Queries: queries,
|
||||||
},
|
},
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetBaseDB returns the underlying BaseDB instance for the Postgres store.
|
||||||
|
// It is a trivial helper method to comply with the sqldb.DB interface.
|
||||||
|
func (s *PostgresStore) GetBaseDB() *BaseDB {
|
||||||
|
return s.BaseDB
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyAllMigrations applies both the SQLC and custom in-code migrations to the
|
||||||
|
// Postgres database.
|
||||||
|
func (s *PostgresStore) ApplyAllMigrations(ctx context.Context,
|
||||||
|
migrations []MigrationConfig) error {
|
||||||
|
|
||||||
// Execute migrations unless configured to skip them.
|
// Execute migrations unless configured to skip them.
|
||||||
if !cfg.SkipMigrations {
|
if s.cfg.SkipMigrations {
|
||||||
err := s.ExecuteMigrations(TargetLatest)
|
return nil
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error executing migrations: %w",
|
|
||||||
err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return s, nil
|
return ApplyMigrations(ctx, s.BaseDB, s, migrations)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteMigrations runs migrations for the Postgres database, depending on the
|
// ExecuteMigrations runs migrations for the Postgres database, depending on the
|
||||||
|
@ -59,7 +59,7 @@ func NewTestPgFixture(t *testing.T, expiry time.Duration) *TestPgFixture {
|
|||||||
"postgres",
|
"postgres",
|
||||||
"-c", "log_statement=all",
|
"-c", "log_statement=all",
|
||||||
"-c", "log_destination=stderr",
|
"-c", "log_destination=stderr",
|
||||||
"-c", "max_connections=1000",
|
"-c", "max_connections=5000",
|
||||||
},
|
},
|
||||||
}, func(config *docker.HostConfig) {
|
}, func(config *docker.HostConfig) {
|
||||||
// Set AutoRemove to true so that stopped container goes away
|
// Set AutoRemove to true so that stopped container goes away
|
||||||
@ -151,6 +151,10 @@ func NewTestPostgresDB(t *testing.T, fixture *TestPgFixture) *PostgresStore {
|
|||||||
store, err := NewPostgresStore(cfg)
|
store, err := NewPostgresStore(cfg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, store.ApplyAllMigrations(
|
||||||
|
context.Background(), GetMigrations()),
|
||||||
|
)
|
||||||
|
|
||||||
return store
|
return store
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -235,6 +235,35 @@ func (q *Queries) GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, err
|
|||||||
return invoice_id, err
|
return invoice_id, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const insertAMPSubInvoice = `-- name: InsertAMPSubInvoice :exec
|
||||||
|
INSERT INTO amp_sub_invoices (
|
||||||
|
set_id, state, created_at, settled_at, settle_index, invoice_id
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5, $6
|
||||||
|
)
|
||||||
|
`
|
||||||
|
|
||||||
|
type InsertAMPSubInvoiceParams struct {
|
||||||
|
SetID []byte
|
||||||
|
State int16
|
||||||
|
CreatedAt time.Time
|
||||||
|
SettledAt sql.NullTime
|
||||||
|
SettleIndex sql.NullInt64
|
||||||
|
InvoiceID int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, insertAMPSubInvoice,
|
||||||
|
arg.SetID,
|
||||||
|
arg.State,
|
||||||
|
arg.CreatedAt,
|
||||||
|
arg.SettledAt,
|
||||||
|
arg.SettleIndex,
|
||||||
|
arg.InvoiceID,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
const insertAMPSubInvoiceHTLC = `-- name: InsertAMPSubInvoiceHTLC :exec
|
const insertAMPSubInvoiceHTLC = `-- name: InsertAMPSubInvoiceHTLC :exec
|
||||||
INSERT INTO amp_sub_invoice_htlcs (
|
INSERT INTO amp_sub_invoice_htlcs (
|
||||||
invoice_id, set_id, htlc_id, root_share, child_index, hash, preimage
|
invoice_id, set_id, htlc_id, root_share, child_index, hash, preimage
|
||||||
|
@ -11,6 +11,15 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const clearKVInvoiceHashIndex = `-- name: ClearKVInvoiceHashIndex :exec
|
||||||
|
DELETE FROM invoice_payment_hashes
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) ClearKVInvoiceHashIndex(ctx context.Context) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, clearKVInvoiceHashIndex)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
const deleteCanceledInvoices = `-- name: DeleteCanceledInvoices :execresult
|
const deleteCanceledInvoices = `-- name: DeleteCanceledInvoices :execresult
|
||||||
DELETE
|
DELETE
|
||||||
FROM invoices
|
FROM invoices
|
||||||
@ -182,11 +191,8 @@ WHERE (
|
|||||||
i.hash = $3 OR
|
i.hash = $3 OR
|
||||||
$3 IS NULL
|
$3 IS NULL
|
||||||
) AND (
|
) AND (
|
||||||
i.preimage = $4 OR
|
i.payment_addr = $4 OR
|
||||||
$4 IS NULL
|
$4 IS NULL
|
||||||
) AND (
|
|
||||||
i.payment_addr = $5 OR
|
|
||||||
$5 IS NULL
|
|
||||||
)
|
)
|
||||||
GROUP BY i.id
|
GROUP BY i.id
|
||||||
LIMIT 2
|
LIMIT 2
|
||||||
@ -196,7 +202,6 @@ type GetInvoiceParams struct {
|
|||||||
SetID []byte
|
SetID []byte
|
||||||
AddIndex sql.NullInt64
|
AddIndex sql.NullInt64
|
||||||
Hash []byte
|
Hash []byte
|
||||||
Preimage []byte
|
|
||||||
PaymentAddr []byte
|
PaymentAddr []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -208,7 +213,6 @@ func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoi
|
|||||||
arg.SetID,
|
arg.SetID,
|
||||||
arg.AddIndex,
|
arg.AddIndex,
|
||||||
arg.Hash,
|
arg.Hash,
|
||||||
arg.Preimage,
|
|
||||||
arg.PaymentAddr,
|
arg.PaymentAddr,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -251,6 +255,38 @@ func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoi
|
|||||||
return items, nil
|
return items, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const getInvoiceByHash = `-- name: GetInvoiceByHash :one
|
||||||
|
SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at
|
||||||
|
FROM invoices i
|
||||||
|
WHERE i.hash = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetInvoiceByHash(ctx context.Context, hash []byte) (Invoice, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getInvoiceByHash, hash)
|
||||||
|
var i Invoice
|
||||||
|
err := row.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.Hash,
|
||||||
|
&i.Preimage,
|
||||||
|
&i.SettleIndex,
|
||||||
|
&i.SettledAt,
|
||||||
|
&i.Memo,
|
||||||
|
&i.AmountMsat,
|
||||||
|
&i.CltvDelta,
|
||||||
|
&i.Expiry,
|
||||||
|
&i.PaymentAddr,
|
||||||
|
&i.PaymentRequest,
|
||||||
|
&i.PaymentRequestHash,
|
||||||
|
&i.State,
|
||||||
|
&i.AmountPaidMsat,
|
||||||
|
&i.IsAmp,
|
||||||
|
&i.IsHodl,
|
||||||
|
&i.IsKeysend,
|
||||||
|
&i.CreatedAt,
|
||||||
|
)
|
||||||
|
return i, err
|
||||||
|
}
|
||||||
|
|
||||||
const getInvoiceBySetID = `-- name: GetInvoiceBySetID :many
|
const getInvoiceBySetID = `-- name: GetInvoiceBySetID :many
|
||||||
SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at
|
SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at
|
||||||
FROM invoices i
|
FROM invoices i
|
||||||
@ -405,6 +441,19 @@ func (q *Queries) GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]Invoi
|
|||||||
return items, nil
|
return items, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const getKVInvoicePaymentHashByAddIndex = `-- name: GetKVInvoicePaymentHashByAddIndex :one
|
||||||
|
SELECT hash
|
||||||
|
FROM invoice_payment_hashes
|
||||||
|
WHERE add_index = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) ([]byte, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getKVInvoicePaymentHashByAddIndex, addIndex)
|
||||||
|
var hash []byte
|
||||||
|
err := row.Scan(&hash)
|
||||||
|
return hash, err
|
||||||
|
}
|
||||||
|
|
||||||
const insertInvoice = `-- name: InsertInvoice :one
|
const insertInvoice = `-- name: InsertInvoice :one
|
||||||
INSERT INTO invoices (
|
INSERT INTO invoices (
|
||||||
hash, preimage, memo, amount_msat, cltv_delta, expiry, payment_addr,
|
hash, preimage, memo, amount_msat, cltv_delta, expiry, payment_addr,
|
||||||
@ -533,6 +582,79 @@ func (q *Queries) InsertInvoiceHTLCCustomRecord(ctx context.Context, arg InsertI
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const insertKVInvoiceKeyAndAddIndex = `-- name: InsertKVInvoiceKeyAndAddIndex :exec
|
||||||
|
INSERT INTO invoice_payment_hashes (
|
||||||
|
id, add_index
|
||||||
|
) VALUES (
|
||||||
|
$1, $2
|
||||||
|
)
|
||||||
|
`
|
||||||
|
|
||||||
|
type InsertKVInvoiceKeyAndAddIndexParams struct {
|
||||||
|
ID int64
|
||||||
|
AddIndex int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) InsertKVInvoiceKeyAndAddIndex(ctx context.Context, arg InsertKVInvoiceKeyAndAddIndexParams) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, insertKVInvoiceKeyAndAddIndex, arg.ID, arg.AddIndex)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const insertMigratedInvoice = `-- name: InsertMigratedInvoice :one
|
||||||
|
INSERT INTO invoices (
|
||||||
|
hash, preimage, settle_index, settled_at, memo, amount_msat, cltv_delta,
|
||||||
|
expiry, payment_addr, payment_request, payment_request_hash, state,
|
||||||
|
amount_paid_msat, is_amp, is_hodl, is_keysend, created_at
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17
|
||||||
|
) RETURNING id
|
||||||
|
`
|
||||||
|
|
||||||
|
type InsertMigratedInvoiceParams struct {
|
||||||
|
Hash []byte
|
||||||
|
Preimage []byte
|
||||||
|
SettleIndex sql.NullInt64
|
||||||
|
SettledAt sql.NullTime
|
||||||
|
Memo sql.NullString
|
||||||
|
AmountMsat int64
|
||||||
|
CltvDelta sql.NullInt32
|
||||||
|
Expiry int32
|
||||||
|
PaymentAddr []byte
|
||||||
|
PaymentRequest sql.NullString
|
||||||
|
PaymentRequestHash []byte
|
||||||
|
State int16
|
||||||
|
AmountPaidMsat int64
|
||||||
|
IsAmp bool
|
||||||
|
IsHodl bool
|
||||||
|
IsKeysend bool
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) InsertMigratedInvoice(ctx context.Context, arg InsertMigratedInvoiceParams) (int64, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, insertMigratedInvoice,
|
||||||
|
arg.Hash,
|
||||||
|
arg.Preimage,
|
||||||
|
arg.SettleIndex,
|
||||||
|
arg.SettledAt,
|
||||||
|
arg.Memo,
|
||||||
|
arg.AmountMsat,
|
||||||
|
arg.CltvDelta,
|
||||||
|
arg.Expiry,
|
||||||
|
arg.PaymentAddr,
|
||||||
|
arg.PaymentRequest,
|
||||||
|
arg.PaymentRequestHash,
|
||||||
|
arg.State,
|
||||||
|
arg.AmountPaidMsat,
|
||||||
|
arg.IsAmp,
|
||||||
|
arg.IsHodl,
|
||||||
|
arg.IsKeysend,
|
||||||
|
arg.CreatedAt,
|
||||||
|
)
|
||||||
|
var id int64
|
||||||
|
err := row.Scan(&id)
|
||||||
|
return id, err
|
||||||
|
}
|
||||||
|
|
||||||
const nextInvoiceSettleIndex = `-- name: NextInvoiceSettleIndex :one
|
const nextInvoiceSettleIndex = `-- name: NextInvoiceSettleIndex :one
|
||||||
UPDATE invoice_sequences SET current_value = current_value + 1
|
UPDATE invoice_sequences SET current_value = current_value + 1
|
||||||
WHERE name = 'settle_index'
|
WHERE name = 'settle_index'
|
||||||
@ -546,6 +668,22 @@ func (q *Queries) NextInvoiceSettleIndex(ctx context.Context) (int64, error) {
|
|||||||
return current_value, err
|
return current_value, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const setKVInvoicePaymentHash = `-- name: SetKVInvoicePaymentHash :exec
|
||||||
|
UPDATE invoice_payment_hashes
|
||||||
|
SET hash = $2
|
||||||
|
WHERE id = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
type SetKVInvoicePaymentHashParams struct {
|
||||||
|
ID int64
|
||||||
|
Hash []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) SetKVInvoicePaymentHash(ctx context.Context, arg SetKVInvoicePaymentHashParams) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, setKVInvoicePaymentHash, arg.ID, arg.Hash)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
const updateInvoiceAmountPaid = `-- name: UpdateInvoiceAmountPaid :execresult
|
const updateInvoiceAmountPaid = `-- name: UpdateInvoiceAmountPaid :execresult
|
||||||
UPDATE invoices
|
UPDATE invoices
|
||||||
SET amount_paid_msat = $2
|
SET amount_paid_msat = $2
|
||||||
|
60
sqldb/sqlc/migration.sql.go
Normal file
60
sqldb/sqlc/migration.sql.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
// Code generated by sqlc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// sqlc v1.25.0
|
||||||
|
// source: migration.sql
|
||||||
|
|
||||||
|
package sqlc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const getDatabaseVersion = `-- name: GetDatabaseVersion :one
|
||||||
|
SELECT
|
||||||
|
version
|
||||||
|
FROM
|
||||||
|
migration_tracker
|
||||||
|
ORDER BY
|
||||||
|
version DESC
|
||||||
|
LIMIT 1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetDatabaseVersion(ctx context.Context) (int32, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getDatabaseVersion)
|
||||||
|
var version int32
|
||||||
|
err := row.Scan(&version)
|
||||||
|
return version, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getMigration = `-- name: GetMigration :one
|
||||||
|
SELECT
|
||||||
|
migration_time
|
||||||
|
FROM
|
||||||
|
migration_tracker
|
||||||
|
WHERE
|
||||||
|
version = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetMigration(ctx context.Context, version int32) (time.Time, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getMigration, version)
|
||||||
|
var migration_time time.Time
|
||||||
|
err := row.Scan(&migration_time)
|
||||||
|
return migration_time, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const setMigration = `-- name: SetMigration :exec
|
||||||
|
INSERT INTO
|
||||||
|
migration_tracker (version, migration_time)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
`
|
||||||
|
|
||||||
|
type SetMigrationParams struct {
|
||||||
|
Version int32
|
||||||
|
MigrationTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) SetMigration(ctx context.Context, arg SetMigrationParams) error {
|
||||||
|
_, err := q.db.ExecContext(ctx, setMigration, arg.Version, arg.MigrationTime)
|
||||||
|
return err
|
||||||
|
}
|
@ -11,7 +11,7 @@ INSERT INTO invoice_sequences(name, current_value) VALUES ('settle_index', 0);
|
|||||||
-- invoices table contains all the information shared by all the invoice types.
|
-- invoices table contains all the information shared by all the invoice types.
|
||||||
CREATE TABLE IF NOT EXISTS invoices (
|
CREATE TABLE IF NOT EXISTS invoices (
|
||||||
-- The id of the invoice. Translates to the AddIndex.
|
-- The id of the invoice. Translates to the AddIndex.
|
||||||
id BIGINT PRIMARY KEY,
|
id INTEGER PRIMARY KEY,
|
||||||
|
|
||||||
-- The hash for this invoice. The invoice hash will always identify that
|
-- The hash for this invoice. The invoice hash will always identify that
|
||||||
-- invoice.
|
-- invoice.
|
||||||
@ -102,7 +102,7 @@ CREATE INDEX IF NOT EXISTS invoice_feature_invoice_id_idx ON invoice_features(in
|
|||||||
CREATE TABLE IF NOT EXISTS invoice_htlcs (
|
CREATE TABLE IF NOT EXISTS invoice_htlcs (
|
||||||
-- The id for this htlc. Used in foreign keys instead of the
|
-- The id for this htlc. Used in foreign keys instead of the
|
||||||
-- htlc_id/chan_id combination.
|
-- htlc_id/chan_id combination.
|
||||||
id BIGINT PRIMARY KEY,
|
id INTEGER PRIMARY KEY,
|
||||||
|
|
||||||
-- Short chan id indicating the htlc's origin. uint64 stored as text.
|
-- Short chan id indicating the htlc's origin. uint64 stored as text.
|
||||||
chan_id TEXT NOT NULL,
|
chan_id TEXT NOT NULL,
|
||||||
|
@ -29,7 +29,7 @@ VALUES
|
|||||||
-- AMP sub invoices. This table can be used to create a historical view of what
|
-- AMP sub invoices. This table can be used to create a historical view of what
|
||||||
-- happened to the node's invoices.
|
-- happened to the node's invoices.
|
||||||
CREATE TABLE IF NOT EXISTS invoice_events (
|
CREATE TABLE IF NOT EXISTS invoice_events (
|
||||||
id BIGINT PRIMARY KEY,
|
id INTEGER PRIMARY KEY,
|
||||||
|
|
||||||
-- added_at is the timestamp when this event was added.
|
-- added_at is the timestamp when this event was added.
|
||||||
added_at TIMESTAMP NOT NULL,
|
added_at TIMESTAMP NOT NULL,
|
||||||
|
1
sqldb/sqlc/migrations/000005_migration_tracker.down.sql
Normal file
1
sqldb/sqlc/migrations/000005_migration_tracker.down.sql
Normal file
@ -0,0 +1 @@
|
|||||||
|
DROP TABLE IF EXISTS migration_tracker;
|
17
sqldb/sqlc/migrations/000005_migration_tracker.up.sql
Normal file
17
sqldb/sqlc/migrations/000005_migration_tracker.up.sql
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
-- The migration_tracker table keeps track of migrations that have been applied
|
||||||
|
-- to the database. This table ensures that migrations are idempotent and are
|
||||||
|
-- only run once. It tracks a global database version that encompasses both
|
||||||
|
-- schema migrations handled by golang-migrate and custom in-code migrations
|
||||||
|
-- for more complex data conversions that cannot be expressed in pure SQL.
|
||||||
|
CREATE TABLE IF NOT EXISTS migration_tracker (
|
||||||
|
-- version is the global version of the migration. Note that we
|
||||||
|
-- intentionally don't set it as PRIMARY KEY as it'd auto increment on
|
||||||
|
-- SQLite and our sqlc workflow will replace it with an auto
|
||||||
|
-- incrementing SERIAL on Postgres too. UNIQUE achieves the same effect
|
||||||
|
-- without the auto increment.
|
||||||
|
version INTEGER UNIQUE NOT NULL,
|
||||||
|
|
||||||
|
-- migration_time is the timestamp at which the migration was run.
|
||||||
|
migration_time TIMESTAMP NOT NULL
|
||||||
|
);
|
||||||
|
|
1
sqldb/sqlc/migrations/000006_invoice_migration.down.sql
Normal file
1
sqldb/sqlc/migrations/000006_invoice_migration.down.sql
Normal file
@ -0,0 +1 @@
|
|||||||
|
DROP TABLE IF EXISTS invoice_payment_hashes;
|
17
sqldb/sqlc/migrations/000006_invoice_migration.up.sql
Normal file
17
sqldb/sqlc/migrations/000006_invoice_migration.up.sql
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
-- invoice_payment_hashes table contains the hash of the invoices. This table
|
||||||
|
-- is used during KV to SQL invoice migration as in our KV representation we
|
||||||
|
-- don't have a mapping from hash to add index.
|
||||||
|
CREATE TABLE IF NOT EXISTS invoice_payment_hashes (
|
||||||
|
-- id represents is the key of the invoice in the KV store.
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
|
||||||
|
-- add_index is the KV add index of the invoice.
|
||||||
|
add_index BIGINT NOT NULL,
|
||||||
|
|
||||||
|
-- hash is the payment hash for this invoice.
|
||||||
|
hash BLOB
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Create an indexes on the add_index and hash columns to speed up lookups.
|
||||||
|
CREATE INDEX IF NOT EXISTS invoice_payment_hashes_add_index_idx ON invoice_payment_hashes(add_index);
|
||||||
|
CREATE INDEX IF NOT EXISTS invoice_payment_hashes_hash_idx ON invoice_payment_hashes(hash);
|
@ -58,7 +58,7 @@ type InvoiceEvent struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type InvoiceEventType struct {
|
type InvoiceEventType struct {
|
||||||
ID int32
|
ID int64
|
||||||
Description string
|
Description string
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,7 +87,18 @@ type InvoiceHtlcCustomRecord struct {
|
|||||||
HtlcID int64
|
HtlcID int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type InvoicePaymentHash struct {
|
||||||
|
ID int64
|
||||||
|
AddIndex int64
|
||||||
|
Hash []byte
|
||||||
|
}
|
||||||
|
|
||||||
type InvoiceSequence struct {
|
type InvoiceSequence struct {
|
||||||
Name string
|
Name string
|
||||||
CurrentValue int64
|
CurrentValue int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MigrationTracker struct {
|
||||||
|
Version int32
|
||||||
|
MigrationTime time.Time
|
||||||
|
}
|
||||||
|
@ -7,9 +7,11 @@ package sqlc
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Querier interface {
|
type Querier interface {
|
||||||
|
ClearKVInvoiceHashIndex(ctx context.Context) error
|
||||||
DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
|
DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
|
||||||
DeleteInvoice(ctx context.Context, arg DeleteInvoiceParams) (sql.Result, error)
|
DeleteInvoice(ctx context.Context, arg DeleteInvoiceParams) (sql.Result, error)
|
||||||
FetchAMPSubInvoiceHTLCs(ctx context.Context, arg FetchAMPSubInvoiceHTLCsParams) ([]FetchAMPSubInvoiceHTLCsRow, error)
|
FetchAMPSubInvoiceHTLCs(ctx context.Context, arg FetchAMPSubInvoiceHTLCsParams) ([]FetchAMPSubInvoiceHTLCsRow, error)
|
||||||
@ -17,19 +19,26 @@ type Querier interface {
|
|||||||
FetchSettledAMPSubInvoices(ctx context.Context, arg FetchSettledAMPSubInvoicesParams) ([]FetchSettledAMPSubInvoicesRow, error)
|
FetchSettledAMPSubInvoices(ctx context.Context, arg FetchSettledAMPSubInvoicesParams) ([]FetchSettledAMPSubInvoicesRow, error)
|
||||||
FilterInvoices(ctx context.Context, arg FilterInvoicesParams) ([]Invoice, error)
|
FilterInvoices(ctx context.Context, arg FilterInvoicesParams) ([]Invoice, error)
|
||||||
GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error)
|
GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error)
|
||||||
|
GetDatabaseVersion(ctx context.Context) (int32, error)
|
||||||
// This method may return more than one invoice if filter using multiple fields
|
// This method may return more than one invoice if filter using multiple fields
|
||||||
// from different invoices. It is the caller's responsibility to ensure that
|
// from different invoices. It is the caller's responsibility to ensure that
|
||||||
// we bubble up an error in those cases.
|
// we bubble up an error in those cases.
|
||||||
GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error)
|
GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error)
|
||||||
|
GetInvoiceByHash(ctx context.Context, hash []byte) (Invoice, error)
|
||||||
GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error)
|
GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error)
|
||||||
GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]InvoiceFeature, error)
|
GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]InvoiceFeature, error)
|
||||||
GetInvoiceHTLCCustomRecords(ctx context.Context, invoiceID int64) ([]GetInvoiceHTLCCustomRecordsRow, error)
|
GetInvoiceHTLCCustomRecords(ctx context.Context, invoiceID int64) ([]GetInvoiceHTLCCustomRecordsRow, error)
|
||||||
GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]InvoiceHtlc, error)
|
GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]InvoiceHtlc, error)
|
||||||
|
GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) ([]byte, error)
|
||||||
|
GetMigration(ctx context.Context, version int32) (time.Time, error)
|
||||||
|
InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error
|
||||||
InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubInvoiceHTLCParams) error
|
InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubInvoiceHTLCParams) error
|
||||||
InsertInvoice(ctx context.Context, arg InsertInvoiceParams) (int64, error)
|
InsertInvoice(ctx context.Context, arg InsertInvoiceParams) (int64, error)
|
||||||
InsertInvoiceFeature(ctx context.Context, arg InsertInvoiceFeatureParams) error
|
InsertInvoiceFeature(ctx context.Context, arg InsertInvoiceFeatureParams) error
|
||||||
InsertInvoiceHTLC(ctx context.Context, arg InsertInvoiceHTLCParams) (int64, error)
|
InsertInvoiceHTLC(ctx context.Context, arg InsertInvoiceHTLCParams) (int64, error)
|
||||||
InsertInvoiceHTLCCustomRecord(ctx context.Context, arg InsertInvoiceHTLCCustomRecordParams) error
|
InsertInvoiceHTLCCustomRecord(ctx context.Context, arg InsertInvoiceHTLCCustomRecordParams) error
|
||||||
|
InsertKVInvoiceKeyAndAddIndex(ctx context.Context, arg InsertKVInvoiceKeyAndAddIndexParams) error
|
||||||
|
InsertMigratedInvoice(ctx context.Context, arg InsertMigratedInvoiceParams) (int64, error)
|
||||||
NextInvoiceSettleIndex(ctx context.Context) (int64, error)
|
NextInvoiceSettleIndex(ctx context.Context) (int64, error)
|
||||||
OnAMPSubInvoiceCanceled(ctx context.Context, arg OnAMPSubInvoiceCanceledParams) error
|
OnAMPSubInvoiceCanceled(ctx context.Context, arg OnAMPSubInvoiceCanceledParams) error
|
||||||
OnAMPSubInvoiceCreated(ctx context.Context, arg OnAMPSubInvoiceCreatedParams) error
|
OnAMPSubInvoiceCreated(ctx context.Context, arg OnAMPSubInvoiceCreatedParams) error
|
||||||
@ -37,6 +46,8 @@ type Querier interface {
|
|||||||
OnInvoiceCanceled(ctx context.Context, arg OnInvoiceCanceledParams) error
|
OnInvoiceCanceled(ctx context.Context, arg OnInvoiceCanceledParams) error
|
||||||
OnInvoiceCreated(ctx context.Context, arg OnInvoiceCreatedParams) error
|
OnInvoiceCreated(ctx context.Context, arg OnInvoiceCreatedParams) error
|
||||||
OnInvoiceSettled(ctx context.Context, arg OnInvoiceSettledParams) error
|
OnInvoiceSettled(ctx context.Context, arg OnInvoiceSettledParams) error
|
||||||
|
SetKVInvoicePaymentHash(ctx context.Context, arg SetKVInvoicePaymentHashParams) error
|
||||||
|
SetMigration(ctx context.Context, arg SetMigrationParams) error
|
||||||
UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context, arg UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result, error)
|
UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context, arg UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result, error)
|
||||||
UpdateAMPSubInvoiceState(ctx context.Context, arg UpdateAMPSubInvoiceStateParams) error
|
UpdateAMPSubInvoiceState(ctx context.Context, arg UpdateAMPSubInvoiceStateParams) error
|
||||||
UpdateInvoiceAmountPaid(ctx context.Context, arg UpdateInvoiceAmountPaidParams) (sql.Result, error)
|
UpdateInvoiceAmountPaid(ctx context.Context, arg UpdateInvoiceAmountPaidParams) (sql.Result, error)
|
||||||
|
@ -65,3 +65,11 @@ SET preimage = $5
|
|||||||
WHERE a.invoice_id = $1 AND a.set_id = $2 AND a.htlc_id = (
|
WHERE a.invoice_id = $1 AND a.set_id = $2 AND a.htlc_id = (
|
||||||
SELECT id FROM invoice_htlcs AS i WHERE i.chan_id = $3 AND i.htlc_id = $4
|
SELECT id FROM invoice_htlcs AS i WHERE i.chan_id = $3 AND i.htlc_id = $4
|
||||||
);
|
);
|
||||||
|
|
||||||
|
-- name: InsertAMPSubInvoice :exec
|
||||||
|
INSERT INTO amp_sub_invoices (
|
||||||
|
set_id, state, created_at, settled_at, settle_index, invoice_id
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5, $6
|
||||||
|
);
|
||||||
|
|
||||||
|
@ -7,6 +7,16 @@ INSERT INTO invoices (
|
|||||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15
|
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15
|
||||||
) RETURNING id;
|
) RETURNING id;
|
||||||
|
|
||||||
|
-- name: InsertMigratedInvoice :one
|
||||||
|
INSERT INTO invoices (
|
||||||
|
hash, preimage, settle_index, settled_at, memo, amount_msat, cltv_delta,
|
||||||
|
expiry, payment_addr, payment_request, payment_request_hash, state,
|
||||||
|
amount_paid_msat, is_amp, is_hodl, is_keysend, created_at
|
||||||
|
) VALUES (
|
||||||
|
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17
|
||||||
|
) RETURNING id;
|
||||||
|
|
||||||
|
|
||||||
-- name: InsertInvoiceFeature :exec
|
-- name: InsertInvoiceFeature :exec
|
||||||
INSERT INTO invoice_features (
|
INSERT INTO invoice_features (
|
||||||
invoice_id, feature
|
invoice_id, feature
|
||||||
@ -37,9 +47,6 @@ WHERE (
|
|||||||
) AND (
|
) AND (
|
||||||
i.hash = sqlc.narg('hash') OR
|
i.hash = sqlc.narg('hash') OR
|
||||||
sqlc.narg('hash') IS NULL
|
sqlc.narg('hash') IS NULL
|
||||||
) AND (
|
|
||||||
i.preimage = sqlc.narg('preimage') OR
|
|
||||||
sqlc.narg('preimage') IS NULL
|
|
||||||
) AND (
|
) AND (
|
||||||
i.payment_addr = sqlc.narg('payment_addr') OR
|
i.payment_addr = sqlc.narg('payment_addr') OR
|
||||||
sqlc.narg('payment_addr') IS NULL
|
sqlc.narg('payment_addr') IS NULL
|
||||||
@ -47,6 +54,11 @@ WHERE (
|
|||||||
GROUP BY i.id
|
GROUP BY i.id
|
||||||
LIMIT 2;
|
LIMIT 2;
|
||||||
|
|
||||||
|
-- name: GetInvoiceByHash :one
|
||||||
|
SELECT i.*
|
||||||
|
FROM invoices i
|
||||||
|
WHERE i.hash = $1;
|
||||||
|
|
||||||
-- name: GetInvoiceBySetID :many
|
-- name: GetInvoiceBySetID :many
|
||||||
SELECT i.*
|
SELECT i.*
|
||||||
FROM invoices i
|
FROM invoices i
|
||||||
@ -169,3 +181,23 @@ INSERT INTO invoice_htlc_custom_records (
|
|||||||
SELECT ihcr.htlc_id, key, value
|
SELECT ihcr.htlc_id, key, value
|
||||||
FROM invoice_htlcs ih JOIN invoice_htlc_custom_records ihcr ON ih.id=ihcr.htlc_id
|
FROM invoice_htlcs ih JOIN invoice_htlc_custom_records ihcr ON ih.id=ihcr.htlc_id
|
||||||
WHERE ih.invoice_id = $1;
|
WHERE ih.invoice_id = $1;
|
||||||
|
|
||||||
|
-- name: InsertKVInvoiceKeyAndAddIndex :exec
|
||||||
|
INSERT INTO invoice_payment_hashes (
|
||||||
|
id, add_index
|
||||||
|
) VALUES (
|
||||||
|
$1, $2
|
||||||
|
);
|
||||||
|
|
||||||
|
-- name: SetKVInvoicePaymentHash :exec
|
||||||
|
UPDATE invoice_payment_hashes
|
||||||
|
SET hash = $2
|
||||||
|
WHERE id = $1;
|
||||||
|
|
||||||
|
-- name: GetKVInvoicePaymentHashByAddIndex :one
|
||||||
|
SELECT hash
|
||||||
|
FROM invoice_payment_hashes
|
||||||
|
WHERE add_index = $1;
|
||||||
|
|
||||||
|
-- name: ClearKVInvoiceHashIndex :exec
|
||||||
|
DELETE FROM invoice_payment_hashes;
|
||||||
|
21
sqldb/sqlc/queries/migration.sql
Normal file
21
sqldb/sqlc/queries/migration.sql
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
-- name: SetMigration :exec
|
||||||
|
INSERT INTO
|
||||||
|
migration_tracker (version, migration_time)
|
||||||
|
VALUES ($1, $2);
|
||||||
|
|
||||||
|
-- name: GetMigration :one
|
||||||
|
SELECT
|
||||||
|
migration_time
|
||||||
|
FROM
|
||||||
|
migration_tracker
|
||||||
|
WHERE
|
||||||
|
version = $1;
|
||||||
|
|
||||||
|
-- name: GetDatabaseVersion :one
|
||||||
|
SELECT
|
||||||
|
version
|
||||||
|
FROM
|
||||||
|
migration_tracker
|
||||||
|
ORDER BY
|
||||||
|
version DESC
|
||||||
|
LIMIT 1;
|
@ -3,6 +3,7 @@
|
|||||||
package sqldb
|
package sqldb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -27,13 +28,16 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// sqliteSchemaReplacements is a map of schema strings that need to be
|
// sqliteSchemaReplacements maps schema strings to their SQLite
|
||||||
// replaced for sqlite. This is needed because sqlite doesn't directly
|
// compatible replacements. Currently, no replacements are needed as our
|
||||||
// support the BIGINT type for primary keys, so we need to replace it
|
// SQL schema definition files are designed for SQLite compatibility.
|
||||||
// with INTEGER.
|
sqliteSchemaReplacements = map[string]string{}
|
||||||
sqliteSchemaReplacements = map[string]string{
|
|
||||||
"BIGINT PRIMARY KEY": "INTEGER PRIMARY KEY",
|
// Make sure SqliteStore implements the MigrationExecutor interface.
|
||||||
}
|
_ MigrationExecutor = (*SqliteStore)(nil)
|
||||||
|
|
||||||
|
// Make sure SqliteStore implements the DB interface.
|
||||||
|
_ DB = (*SqliteStore)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
// SqliteStore is a database store implementation that uses a sqlite backend.
|
// SqliteStore is a database store implementation that uses a sqlite backend.
|
||||||
@ -102,6 +106,23 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create the migration tracker table before starting migrations to
|
||||||
|
// ensure it can be used to track migration progress. Note that a
|
||||||
|
// corresponding SQLC migration also creates this table, making this
|
||||||
|
// operation a no-op in that context. Its purpose is to ensure
|
||||||
|
// compatibility with SQLC query generation.
|
||||||
|
migrationTrackerSQL := `
|
||||||
|
CREATE TABLE IF NOT EXISTS migration_tracker (
|
||||||
|
version INTEGER UNIQUE NOT NULL,
|
||||||
|
migration_time TIMESTAMP NOT NULL
|
||||||
|
);`
|
||||||
|
|
||||||
|
_, err = db.Exec(migrationTrackerSQL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error creating migration tracker: %w",
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
db.SetMaxOpenConns(defaultMaxConns)
|
db.SetMaxOpenConns(defaultMaxConns)
|
||||||
db.SetMaxIdleConns(defaultMaxConns)
|
db.SetMaxIdleConns(defaultMaxConns)
|
||||||
db.SetConnMaxLifetime(connIdleLifetime)
|
db.SetConnMaxLifetime(connIdleLifetime)
|
||||||
@ -115,18 +136,28 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute migrations unless configured to skip them.
|
|
||||||
if !cfg.SkipMigrations {
|
|
||||||
if err := s.ExecuteMigrations(TargetLatest); err != nil {
|
|
||||||
return nil, fmt.Errorf("error executing migrations: "+
|
|
||||||
"%w", err)
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetBaseDB returns the underlying BaseDB instance for the SQLite store.
|
||||||
|
// It is a trivial helper method to comply with the sqldb.DB interface.
|
||||||
|
func (s *SqliteStore) GetBaseDB() *BaseDB {
|
||||||
|
return s.BaseDB
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyAllMigrations applies both the SQLC and custom in-code migrations to the
|
||||||
|
// SQLite database.
|
||||||
|
func (s *SqliteStore) ApplyAllMigrations(ctx context.Context,
|
||||||
|
migrations []MigrationConfig) error {
|
||||||
|
|
||||||
|
// Execute migrations unless configured to skip them.
|
||||||
|
if s.cfg.SkipMigrations {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ApplyMigrations(ctx, s.BaseDB, s, migrations)
|
||||||
|
}
|
||||||
|
|
||||||
// ExecuteMigrations runs migrations for the sqlite database, depending on the
|
// ExecuteMigrations runs migrations for the sqlite database, depending on the
|
||||||
// target given, either all migrations or up to a given version.
|
// target given, either all migrations or up to a given version.
|
||||||
func (s *SqliteStore) ExecuteMigrations(target MigrationTarget) error {
|
func (s *SqliteStore) ExecuteMigrations(target MigrationTarget) error {
|
||||||
@ -160,6 +191,10 @@ func NewTestSqliteDB(t *testing.T) *SqliteStore {
|
|||||||
}, dbFileName)
|
}, dbFileName)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, sqlDB.ApplyAllMigrations(
|
||||||
|
context.Background(), GetMigrations()),
|
||||||
|
)
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, sqlDB.DB.Close())
|
require.NoError(t, sqlDB.DB.Close())
|
||||||
})
|
})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user