mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-05-04 08:50:20 +02:00
Merge pull request #8831 from bhandras/sql-invoice-migration
invoices: migrate KV invoices to native SQL for users of KV SQL backends
This commit is contained in:
commit
6cabc74c20
@ -51,6 +51,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/rpcperms"
|
||||
"github.com/lightningnetwork/lnd/signal"
|
||||
"github.com/lightningnetwork/lnd/sqldb"
|
||||
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||
"github.com/lightningnetwork/lnd/sweep"
|
||||
"github.com/lightningnetwork/lnd/walletunlocker"
|
||||
"github.com/lightningnetwork/lnd/watchtower"
|
||||
@ -60,6 +61,16 @@ import (
|
||||
"gopkg.in/macaroon-bakery.v2/bakery"
|
||||
)
|
||||
|
||||
const (
|
||||
// invoiceMigrationBatchSize is the number of invoices that will be
|
||||
// migrated in a single batch.
|
||||
invoiceMigrationBatchSize = 1000
|
||||
|
||||
// invoiceMigration is the version of the migration that will be used to
|
||||
// migrate invoices from the kvdb to the sql database.
|
||||
invoiceMigration = 7
|
||||
)
|
||||
|
||||
// GrpcRegistrar is an interface that must be satisfied by an external subserver
|
||||
// that wants to be able to register its own gRPC server onto lnd's main
|
||||
// grpc.Server instance.
|
||||
@ -932,10 +943,10 @@ type DatabaseInstances struct {
|
||||
// the btcwallet's loader.
|
||||
WalletDB btcwallet.LoaderOption
|
||||
|
||||
// NativeSQLStore is a pointer to a native SQL store that can be used
|
||||
// for native SQL queries for tables that already support it. This may
|
||||
// be nil if the use-native-sql flag was not set.
|
||||
NativeSQLStore *sqldb.BaseDB
|
||||
// NativeSQLStore holds a reference to the native SQL store that can
|
||||
// be used for native SQL queries for tables that already support it.
|
||||
// This may be nil if the use-native-sql flag was not set.
|
||||
NativeSQLStore sqldb.DB
|
||||
}
|
||||
|
||||
// DefaultDatabaseBuilder is a type that builds the default database backends
|
||||
@ -1038,7 +1049,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
|
||||
if err != nil {
|
||||
cleanUp()
|
||||
|
||||
err := fmt.Errorf("unable to open graph DB: %w", err)
|
||||
err = fmt.Errorf("unable to open graph DB: %w", err)
|
||||
d.logger.Error(err)
|
||||
|
||||
return nil, nil, err
|
||||
@ -1072,51 +1083,69 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
|
||||
case err != nil:
|
||||
cleanUp()
|
||||
|
||||
err := fmt.Errorf("unable to open graph DB: %w", err)
|
||||
err = fmt.Errorf("unable to open graph DB: %w", err)
|
||||
d.logger.Error(err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Instantiate a native SQL invoice store if the flag is set.
|
||||
// Instantiate a native SQL store if the flag is set.
|
||||
if d.cfg.DB.UseNativeSQL {
|
||||
// KV invoice db resides in the same database as the channel
|
||||
// state DB. Let's query the database to see if we have any
|
||||
// invoices there. If we do, we won't allow the user to start
|
||||
// lnd with native SQL enabled, as we don't currently migrate
|
||||
// the invoices to the new database schema.
|
||||
invoiceSlice, err := dbs.ChanStateDB.QueryInvoices(
|
||||
ctx, invoices.InvoiceQuery{
|
||||
NumMaxInvoices: 1,
|
||||
},
|
||||
migrations := sqldb.GetMigrations()
|
||||
|
||||
// If the user has not explicitly disabled the SQL invoice
|
||||
// migration, attach the custom migration function to invoice
|
||||
// migration (version 7). Even if this custom migration is
|
||||
// disabled, the regular native SQL store migrations will still
|
||||
// run. If the database version is already above this custom
|
||||
// migration's version (7), it will be skipped permanently,
|
||||
// regardless of the flag.
|
||||
if !d.cfg.DB.SkipSQLInvoiceMigration {
|
||||
migrationFn := func(tx *sqlc.Queries) error {
|
||||
return invoices.MigrateInvoicesToSQL(
|
||||
ctx, dbs.ChanStateDB.Backend,
|
||||
dbs.ChanStateDB, tx,
|
||||
invoiceMigrationBatchSize,
|
||||
)
|
||||
}
|
||||
|
||||
// Make sure we attach the custom migration function to
|
||||
// the correct migration version.
|
||||
for i := 0; i < len(migrations); i++ {
|
||||
if migrations[i].Version != invoiceMigration {
|
||||
continue
|
||||
}
|
||||
|
||||
migrations[i].MigrationFn = migrationFn
|
||||
}
|
||||
}
|
||||
|
||||
// We need to apply all migrations to the native SQL store
|
||||
// before we can use it.
|
||||
err = dbs.NativeSQLStore.ApplyAllMigrations(ctx, migrations)
|
||||
if err != nil {
|
||||
cleanUp()
|
||||
d.logger.Errorf("Unable to query KV invoice DB: %v",
|
||||
err)
|
||||
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if len(invoiceSlice.Invoices) > 0 {
|
||||
cleanUp()
|
||||
err := fmt.Errorf("found invoices in the KV invoice " +
|
||||
"DB, migration to native SQL is not yet " +
|
||||
"supported")
|
||||
err = fmt.Errorf("faild to run migrations for the "+
|
||||
"native SQL store: %w", err)
|
||||
d.logger.Error(err)
|
||||
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// With the DB ready and migrations applied, we can now create
|
||||
// the base DB and transaction executor for the native SQL
|
||||
// invoice store.
|
||||
baseDB := dbs.NativeSQLStore.GetBaseDB()
|
||||
executor := sqldb.NewTransactionExecutor(
|
||||
dbs.NativeSQLStore,
|
||||
func(tx *sql.Tx) invoices.SQLInvoiceQueries {
|
||||
return dbs.NativeSQLStore.WithTx(tx)
|
||||
baseDB, func(tx *sql.Tx) invoices.SQLInvoiceQueries {
|
||||
return baseDB.WithTx(tx)
|
||||
},
|
||||
)
|
||||
|
||||
dbs.InvoiceDB = invoices.NewSQLStore(
|
||||
sqlInvoiceDB := invoices.NewSQLStore(
|
||||
executor, clock.NewDefaultClock(),
|
||||
)
|
||||
|
||||
dbs.InvoiceDB = sqlInvoiceDB
|
||||
} else {
|
||||
dbs.InvoiceDB = dbs.ChanStateDB
|
||||
}
|
||||
@ -1129,7 +1158,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
|
||||
if err != nil {
|
||||
cleanUp()
|
||||
|
||||
err := fmt.Errorf("unable to open %s database: %w",
|
||||
err = fmt.Errorf("unable to open %s database: %w",
|
||||
lncfg.NSTowerClientDB, err)
|
||||
d.logger.Error(err)
|
||||
return nil, nil, err
|
||||
@ -1144,7 +1173,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
|
||||
if err != nil {
|
||||
cleanUp()
|
||||
|
||||
err := fmt.Errorf("unable to open %s database: %w",
|
||||
err = fmt.Errorf("unable to open %s database: %w",
|
||||
lncfg.NSTowerServerDB, err)
|
||||
d.logger.Error(err)
|
||||
return nil, nil, err
|
||||
|
@ -264,6 +264,11 @@ The underlying functionality between those two options remain the same.
|
||||
transactions can run at once, increasing efficiency. Includes several bugfixes
|
||||
to allow this to work properly.
|
||||
|
||||
* [Migrate KV invoices to
|
||||
SQL](https://github.com/lightningnetwork/lnd/pull/8831) as part of a larger
|
||||
effort to support SQL databases natively in LND.
|
||||
|
||||
|
||||
## Code Health
|
||||
|
||||
* A code refactor that [moves all the graph related DB code out of the
|
||||
@ -292,6 +297,7 @@ The underlying functionality between those two options remain the same.
|
||||
|
||||
* Abdullahi Yunus
|
||||
* Alex Akselrod
|
||||
* Andras Banki-Horvath
|
||||
* Animesh Bilthare
|
||||
* Boris Nagaev
|
||||
* Carla Kirk-Cohen
|
||||
|
6
go.mod
6
go.mod
@ -138,7 +138,7 @@ require (
|
||||
github.com/opencontainers/image-spec v1.0.2 // indirect
|
||||
github.com/opencontainers/runc v1.1.12 // indirect
|
||||
github.com/ory/dockertest/v3 v3.10.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0
|
||||
github.com/prometheus/client_model v0.2.0 // indirect
|
||||
github.com/prometheus/common v0.26.0 // indirect
|
||||
github.com/prometheus/procfs v0.6.0 // indirect
|
||||
@ -207,6 +207,10 @@ replace github.com/gogo/protobuf => github.com/gogo/protobuf v1.3.2
|
||||
// allows us to specify that as an option.
|
||||
replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-display v1.30.0-hex-display
|
||||
|
||||
// Temporary replace until https://github.com/lightningnetwork/lnd/pull/8831 is
|
||||
// merged.
|
||||
replace github.com/lightningnetwork/lnd/sqldb => ./sqldb
|
||||
|
||||
// If you change this please also update docs/INSTALL.md and GO_VERSION in
|
||||
// Makefile (then run `make lint` to see where else it needs to be updated as
|
||||
// well).
|
||||
|
2
go.sum
2
go.sum
@ -464,8 +464,6 @@ github.com/lightningnetwork/lnd/kvdb v1.4.12 h1:Y0WY5Tbjyjn6eCYh068qkWur5oFtioJl
|
||||
github.com/lightningnetwork/lnd/kvdb v1.4.12/go.mod h1:hx9buNcxsZpZwh8m1sjTQwy2SOeBoWWOZ3RnOQkMsxI=
|
||||
github.com/lightningnetwork/lnd/queue v1.1.1 h1:99ovBlpM9B0FRCGYJo6RSFDlt8/vOkQQZznVb18iNMI=
|
||||
github.com/lightningnetwork/lnd/queue v1.1.1/go.mod h1:7A6nC1Qrm32FHuhx/mi1cieAiBZo5O6l8IBIoQxvkz4=
|
||||
github.com/lightningnetwork/lnd/sqldb v1.0.6 h1:LJdDSVdN33bVBIefsaJlPW9PDAm6GrXlyFucmzSJ3Ts=
|
||||
github.com/lightningnetwork/lnd/sqldb v1.0.6/go.mod h1:OG09zL/PHPaBJefp4HsPz2YLUJ+zIQHbpgCtLnOx8I4=
|
||||
github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM=
|
||||
github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA=
|
||||
github.com/lightningnetwork/lnd/tlv v1.3.0 h1:exS/KCPEgpOgviIttfiXAPaUqw2rHQrnUOpP7HPBPiY=
|
||||
|
@ -187,6 +187,11 @@ func (r InvoiceRef) Modifier() RefModifier {
|
||||
return r.refModifier
|
||||
}
|
||||
|
||||
// IsHashOnly returns true if the invoice ref only contains a payment hash.
|
||||
func (r InvoiceRef) IsHashOnly() bool {
|
||||
return r.payHash != nil && r.payAddr == nil && r.setID == nil
|
||||
}
|
||||
|
||||
// String returns a human-readable representation of an InvoiceRef.
|
||||
func (r InvoiceRef) String() string {
|
||||
var ids []string
|
||||
|
203
invoices/kv_sql_migration_test.go
Normal file
203
invoices/kv_sql_migration_test.go
Normal file
@ -0,0 +1,203 @@
|
||||
package invoices_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/clock"
|
||||
invpkg "github.com/lightningnetwork/lnd/invoices"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/kvdb/sqlbase"
|
||||
"github.com/lightningnetwork/lnd/kvdb/sqlite"
|
||||
"github.com/lightningnetwork/lnd/lncfg"
|
||||
"github.com/lightningnetwork/lnd/sqldb"
|
||||
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestMigrationWithChannelDB tests the migration of invoices from a bolt backed
|
||||
// channel.db to a SQL database. Note that this test does not attempt to be a
|
||||
// complete migration test for all invoice types but rather is added as a tool
|
||||
// for developers and users to debug invoice migration issues with an actual
|
||||
// channel.db file.
|
||||
func TestMigrationWithChannelDB(t *testing.T) {
|
||||
// First create a shared Postgres instance so we don't spawn a new
|
||||
// docker container for each test.
|
||||
pgFixture := sqldb.NewTestPgFixture(
|
||||
t, sqldb.DefaultPostgresFixtureLifetime,
|
||||
)
|
||||
t.Cleanup(func() {
|
||||
pgFixture.TearDown(t)
|
||||
})
|
||||
|
||||
makeSQLDB := func(t *testing.T, sqlite bool) (*invpkg.SQLStore,
|
||||
*sqldb.TransactionExecutor[*sqlc.Queries]) {
|
||||
|
||||
var db *sqldb.BaseDB
|
||||
if sqlite {
|
||||
db = sqldb.NewTestSqliteDB(t).BaseDB
|
||||
} else {
|
||||
db = sqldb.NewTestPostgresDB(t, pgFixture).BaseDB
|
||||
}
|
||||
|
||||
invoiceExecutor := sqldb.NewTransactionExecutor(
|
||||
db, func(tx *sql.Tx) invpkg.SQLInvoiceQueries {
|
||||
return db.WithTx(tx)
|
||||
},
|
||||
)
|
||||
|
||||
genericExecutor := sqldb.NewTransactionExecutor(
|
||||
db, func(tx *sql.Tx) *sqlc.Queries {
|
||||
return db.WithTx(tx)
|
||||
},
|
||||
)
|
||||
|
||||
testClock := clock.NewTestClock(time.Unix(1, 0))
|
||||
|
||||
return invpkg.NewSQLStore(invoiceExecutor, testClock),
|
||||
genericExecutor
|
||||
}
|
||||
|
||||
migrationTest := func(t *testing.T, kvStore *channeldb.DB,
|
||||
sqlite bool) {
|
||||
|
||||
sqlInvoiceStore, sqlStore := makeSQLDB(t, sqlite)
|
||||
ctxb := context.Background()
|
||||
|
||||
const batchSize = 11
|
||||
var opts sqldb.MigrationTxOptions
|
||||
err := sqlStore.ExecTx(
|
||||
ctxb, &opts, func(tx *sqlc.Queries) error {
|
||||
return invpkg.MigrateInvoicesToSQL(
|
||||
ctxb, kvStore.Backend, kvStore, tx,
|
||||
batchSize,
|
||||
)
|
||||
}, func() {},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// MigrateInvoices will check if the inserted invoice equals to
|
||||
// the migrated one, but as a sanity check, we'll also fetch the
|
||||
// invoices from the store and compare them to the original
|
||||
// invoices.
|
||||
query := invpkg.InvoiceQuery{
|
||||
IndexOffset: 0,
|
||||
// As a sanity check, fetch more invoices than we have
|
||||
// to ensure that we did not add any extra invoices.
|
||||
// Note that we don't really have a way to know the
|
||||
// exact number of invoices in the bolt db without first
|
||||
// iterating over all of them, but for test purposes
|
||||
// constant should be enough.
|
||||
NumMaxInvoices: 9999,
|
||||
}
|
||||
result1, err := kvStore.QueryInvoices(ctxb, query)
|
||||
require.NoError(t, err)
|
||||
numInvoices := len(result1.Invoices)
|
||||
|
||||
result2, err := sqlInvoiceStore.QueryInvoices(ctxb, query)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, numInvoices, len(result2.Invoices))
|
||||
|
||||
// Simply zero out the add index so we don't fail on that when
|
||||
// comparing.
|
||||
for i := 0; i < numInvoices; i++ {
|
||||
result1.Invoices[i].AddIndex = 0
|
||||
result2.Invoices[i].AddIndex = 0
|
||||
|
||||
// We need to override the timezone of the invoices as
|
||||
// the provided DB vs the test runners local time zone
|
||||
// might be different.
|
||||
invpkg.OverrideInvoiceTimeZone(&result1.Invoices[i])
|
||||
invpkg.OverrideInvoiceTimeZone(&result2.Invoices[i])
|
||||
|
||||
require.Equal(
|
||||
t, result1.Invoices[i], result2.Invoices[i],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
dbPath string
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
t.TempDir(),
|
||||
},
|
||||
{
|
||||
"testdata",
|
||||
"testdata",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var kvStore *channeldb.DB
|
||||
|
||||
// First check if we have a channel.sqlite file in the
|
||||
// testdata directory. If we do, we'll use that as the
|
||||
// channel db for the migration test.
|
||||
chanDBPath := path.Join(
|
||||
test.dbPath, lncfg.SqliteChannelDBName,
|
||||
)
|
||||
|
||||
// Just some sane defaults for the sqlite config.
|
||||
const (
|
||||
timeout = 5 * time.Second
|
||||
maxConns = 50
|
||||
)
|
||||
|
||||
sqliteConfig := &sqlite.Config{
|
||||
Timeout: timeout,
|
||||
BusyTimeout: timeout,
|
||||
MaxConnections: maxConns,
|
||||
}
|
||||
|
||||
if fileExists(chanDBPath) {
|
||||
sqlbase.Init(maxConns)
|
||||
|
||||
sqliteBackend, err := kvdb.Open(
|
||||
kvdb.SqliteBackendName,
|
||||
context.Background(),
|
||||
sqliteConfig, test.dbPath,
|
||||
lncfg.SqliteChannelDBName,
|
||||
lncfg.NSChannelDB,
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
kvStore, err = channeldb.CreateWithBackend(
|
||||
sqliteBackend,
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
kvStore = channeldb.OpenForTesting(
|
||||
t, test.dbPath,
|
||||
)
|
||||
}
|
||||
|
||||
t.Run("Postgres", func(t *testing.T) {
|
||||
migrationTest(t, kvStore, false)
|
||||
})
|
||||
|
||||
t.Run("SQLite", func(t *testing.T) {
|
||||
migrationTest(t, kvStore, true)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func fileExists(filename string) bool {
|
||||
info, err := os.Stat(filename)
|
||||
if os.IsNotExist(err) {
|
||||
return false
|
||||
}
|
||||
|
||||
return !info.IsDir()
|
||||
}
|
558
invoices/sql_migration.go
Normal file
558
invoices/sql_migration.go
Normal file
@ -0,0 +1,558 @@
|
||||
package invoices
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/graph/db/models"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/sqldb"
|
||||
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||
"github.com/pmezard/go-difflib/difflib"
|
||||
)
|
||||
|
||||
var (
|
||||
// invoiceBucket is the name of the bucket within the database that
|
||||
// stores all data related to invoices no matter their final state.
|
||||
// Within the invoice bucket, each invoice is keyed by its invoice ID
|
||||
// which is a monotonically increasing uint32.
|
||||
invoiceBucket = []byte("invoices")
|
||||
|
||||
// invoiceIndexBucket is the name of the sub-bucket within the
|
||||
// invoiceBucket which indexes all invoices by their payment hash. The
|
||||
// payment hash is the sha256 of the invoice's payment preimage. This
|
||||
// index is used to detect duplicates, and also to provide a fast path
|
||||
// for looking up incoming HTLCs to determine if we're able to settle
|
||||
// them fully.
|
||||
//
|
||||
// maps: payHash => invoiceKey
|
||||
invoiceIndexBucket = []byte("paymenthashes")
|
||||
|
||||
// numInvoicesKey is the name of key which houses the auto-incrementing
|
||||
// invoice ID which is essentially used as a primary key. With each
|
||||
// invoice inserted, the primary key is incremented by one. This key is
|
||||
// stored within the invoiceIndexBucket. Within the invoiceBucket
|
||||
// invoices are uniquely identified by the invoice ID.
|
||||
numInvoicesKey = []byte("nik")
|
||||
|
||||
// addIndexBucket is an index bucket that we'll use to create a
|
||||
// monotonically increasing set of add indexes. Each time we add a new
|
||||
// invoice, this sequence number will be incremented and then populated
|
||||
// within the new invoice.
|
||||
//
|
||||
// In addition to this sequence number, we map:
|
||||
//
|
||||
// addIndexNo => invoiceKey
|
||||
addIndexBucket = []byte("invoice-add-index")
|
||||
|
||||
// ErrMigrationMismatch is returned when the migrated invoice does not
|
||||
// match the original invoice.
|
||||
ErrMigrationMismatch = fmt.Errorf("migrated invoice does not match " +
|
||||
"original invoice")
|
||||
)
|
||||
|
||||
// createInvoiceHashIndex generates a hash index that contains payment hashes
|
||||
// for each invoice in the database. Retrieving the payment hash for certain
|
||||
// invoices, such as those created for spontaneous AMP payments, can be
|
||||
// challenging because the hash is not directly derivable from the invoice's
|
||||
// parameters and is stored separately in the `paymenthashes` bucket. This
|
||||
// bucket maps payment hashes to invoice keys, but for migration purposes, we
|
||||
// need the ability to query in the reverse direction. This function establishes
|
||||
// a new index in the SQL database that maps each invoice key to its
|
||||
// corresponding payment hash.
|
||||
func createInvoiceHashIndex(ctx context.Context, db kvdb.Backend,
|
||||
tx *sqlc.Queries) error {
|
||||
|
||||
return db.View(func(kvTx kvdb.RTx) error {
|
||||
invoices := kvTx.ReadBucket(invoiceBucket)
|
||||
if invoices == nil {
|
||||
return ErrNoInvoicesCreated
|
||||
}
|
||||
|
||||
invoiceIndex := invoices.NestedReadBucket(
|
||||
invoiceIndexBucket,
|
||||
)
|
||||
if invoiceIndex == nil {
|
||||
return ErrNoInvoicesCreated
|
||||
}
|
||||
|
||||
addIndex := invoices.NestedReadBucket(addIndexBucket)
|
||||
if addIndex == nil {
|
||||
return ErrNoInvoicesCreated
|
||||
}
|
||||
|
||||
// First, iterate over all elements in the add index bucket and
|
||||
// insert the add index value for the corresponding invoice key
|
||||
// in the payment_hashes table.
|
||||
err := addIndex.ForEach(func(k, v []byte) error {
|
||||
// The key is the add index, and the value is
|
||||
// the invoice key.
|
||||
addIndexNo := binary.BigEndian.Uint64(k)
|
||||
invoiceKey := binary.BigEndian.Uint32(v)
|
||||
|
||||
return tx.InsertKVInvoiceKeyAndAddIndex(ctx,
|
||||
sqlc.InsertKVInvoiceKeyAndAddIndexParams{
|
||||
ID: int64(invoiceKey),
|
||||
AddIndex: int64(addIndexNo),
|
||||
},
|
||||
)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Next, iterate over all hashes in the invoice index bucket and
|
||||
// set the hash to the corresponding the invoice key in the
|
||||
// payment_hashes table.
|
||||
return invoiceIndex.ForEach(func(k, v []byte) error {
|
||||
// Skip the special numInvoicesKey as that does
|
||||
// not point to a valid invoice.
|
||||
if bytes.Equal(k, numInvoicesKey) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// The key is the payment hash, and the value
|
||||
// is the invoice key.
|
||||
if len(k) != lntypes.HashSize {
|
||||
return fmt.Errorf("invalid payment "+
|
||||
"hash length: expected %v, "+
|
||||
"got %v", lntypes.HashSize,
|
||||
len(k))
|
||||
}
|
||||
|
||||
invoiceKey := binary.BigEndian.Uint32(v)
|
||||
|
||||
return tx.SetKVInvoicePaymentHash(ctx,
|
||||
sqlc.SetKVInvoicePaymentHashParams{
|
||||
ID: int64(invoiceKey),
|
||||
Hash: k,
|
||||
},
|
||||
)
|
||||
})
|
||||
}, func() {})
|
||||
}
|
||||
|
||||
// toInsertMigratedInvoiceParams creates the parameters for inserting a migrated
|
||||
// invoice into the SQL database. The parameters are derived from the original
|
||||
// invoice insert parameters.
|
||||
func toInsertMigratedInvoiceParams(
|
||||
params sqlc.InsertInvoiceParams) sqlc.InsertMigratedInvoiceParams {
|
||||
|
||||
return sqlc.InsertMigratedInvoiceParams{
|
||||
Hash: params.Hash,
|
||||
Preimage: params.Preimage,
|
||||
Memo: params.Memo,
|
||||
AmountMsat: params.AmountMsat,
|
||||
CltvDelta: params.CltvDelta,
|
||||
Expiry: params.Expiry,
|
||||
PaymentAddr: params.PaymentAddr,
|
||||
PaymentRequest: params.PaymentRequest,
|
||||
PaymentRequestHash: params.PaymentRequestHash,
|
||||
State: params.State,
|
||||
AmountPaidMsat: params.AmountPaidMsat,
|
||||
IsAmp: params.IsAmp,
|
||||
IsHodl: params.IsHodl,
|
||||
IsKeysend: params.IsKeysend,
|
||||
CreatedAt: params.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// MigrateSingleInvoice migrates a single invoice to the new SQL schema. Note
|
||||
// that perfect equality between the old and new schemas is not achievable, as
|
||||
// the invoice's add index cannot be mapped directly to its ID due to SQL’s
|
||||
// auto-incrementing primary key. The ID returned from the insert will instead
|
||||
// serve as the add index in the new schema.
|
||||
func MigrateSingleInvoice(ctx context.Context, tx SQLInvoiceQueries,
|
||||
invoice *Invoice, paymentHash lntypes.Hash) error {
|
||||
|
||||
insertInvoiceParams, err := makeInsertInvoiceParams(
|
||||
invoice, paymentHash,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Convert the insert invoice parameters to the migrated invoice insert
|
||||
// parameters.
|
||||
insertMigratedInvoiceParams := toInsertMigratedInvoiceParams(
|
||||
insertInvoiceParams,
|
||||
)
|
||||
|
||||
// If the invoice is settled, we'll also set the timestamp and the index
|
||||
// at which it was settled.
|
||||
if invoice.State == ContractSettled {
|
||||
if invoice.SettleIndex == 0 {
|
||||
return fmt.Errorf("settled invoice %s missing settle "+
|
||||
"index", paymentHash)
|
||||
}
|
||||
|
||||
if invoice.SettleDate.IsZero() {
|
||||
return fmt.Errorf("settled invoice %s missing settle "+
|
||||
"date", paymentHash)
|
||||
}
|
||||
|
||||
insertMigratedInvoiceParams.SettleIndex = sqldb.SQLInt64(
|
||||
invoice.SettleIndex,
|
||||
)
|
||||
insertMigratedInvoiceParams.SettledAt = sqldb.SQLTime(
|
||||
invoice.SettleDate.UTC(),
|
||||
)
|
||||
}
|
||||
|
||||
// First we need to insert the invoice itself so we can use the "add
|
||||
// index" which in this case is the auto incrementing primary key that
|
||||
// is returned from the insert.
|
||||
invoiceID, err := tx.InsertMigratedInvoice(
|
||||
ctx, insertMigratedInvoiceParams,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to insert invoice: %w", err)
|
||||
}
|
||||
|
||||
// Insert the invoice's features.
|
||||
for feature := range invoice.Terms.Features.Features() {
|
||||
params := sqlc.InsertInvoiceFeatureParams{
|
||||
InvoiceID: invoiceID,
|
||||
Feature: int32(feature),
|
||||
}
|
||||
|
||||
err := tx.InsertInvoiceFeature(ctx, params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to insert invoice "+
|
||||
"feature(%v): %w", feature, err)
|
||||
}
|
||||
}
|
||||
|
||||
sqlHtlcIDs := make(map[models.CircuitKey]int64)
|
||||
|
||||
// Now insert the HTLCs of the invoice. We'll also keep track of the SQL
|
||||
// ID of each HTLC so we can use it when inserting the AMP sub invoices.
|
||||
for circuitKey, htlc := range invoice.Htlcs {
|
||||
htlcParams := sqlc.InsertInvoiceHTLCParams{
|
||||
HtlcID: int64(circuitKey.HtlcID),
|
||||
ChanID: strconv.FormatUint(
|
||||
circuitKey.ChanID.ToUint64(), 10,
|
||||
),
|
||||
AmountMsat: int64(htlc.Amt),
|
||||
AcceptHeight: int32(htlc.AcceptHeight),
|
||||
AcceptTime: htlc.AcceptTime.UTC(),
|
||||
ExpiryHeight: int32(htlc.Expiry),
|
||||
State: int16(htlc.State),
|
||||
InvoiceID: invoiceID,
|
||||
}
|
||||
|
||||
// Leave the MPP amount as NULL if the MPP total amount is zero.
|
||||
if htlc.MppTotalAmt != 0 {
|
||||
htlcParams.TotalMppMsat = sqldb.SQLInt64(
|
||||
int64(htlc.MppTotalAmt),
|
||||
)
|
||||
}
|
||||
|
||||
// Leave the resolve time as NULL if the HTLC is not resolved.
|
||||
if !htlc.ResolveTime.IsZero() {
|
||||
htlcParams.ResolveTime = sqldb.SQLTime(
|
||||
htlc.ResolveTime.UTC(),
|
||||
)
|
||||
}
|
||||
|
||||
sqlID, err := tx.InsertInvoiceHTLC(ctx, htlcParams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to insert invoice htlc: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
sqlHtlcIDs[circuitKey] = sqlID
|
||||
|
||||
// Store custom records.
|
||||
for key, value := range htlc.CustomRecords {
|
||||
err = tx.InsertInvoiceHTLCCustomRecord(
|
||||
ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
|
||||
Key: int64(key),
|
||||
Value: value,
|
||||
HtlcID: sqlID,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !invoice.IsAMP() {
|
||||
return nil
|
||||
}
|
||||
|
||||
for setID, ampState := range invoice.AMPState {
|
||||
// Find the earliest HTLC of the AMP invoice, which will
|
||||
// be used as the creation date of this sub invoice.
|
||||
var createdAt time.Time
|
||||
for circuitKey := range ampState.InvoiceKeys {
|
||||
htlc := invoice.Htlcs[circuitKey]
|
||||
if createdAt.IsZero() {
|
||||
createdAt = htlc.AcceptTime.UTC()
|
||||
continue
|
||||
}
|
||||
|
||||
if createdAt.After(htlc.AcceptTime) {
|
||||
createdAt = htlc.AcceptTime.UTC()
|
||||
}
|
||||
}
|
||||
|
||||
params := sqlc.InsertAMPSubInvoiceParams{
|
||||
SetID: setID[:],
|
||||
State: int16(ampState.State),
|
||||
CreatedAt: createdAt,
|
||||
InvoiceID: invoiceID,
|
||||
}
|
||||
|
||||
if ampState.SettleIndex != 0 {
|
||||
if ampState.SettleDate.IsZero() {
|
||||
return fmt.Errorf("settled AMP sub invoice %x "+
|
||||
"missing settle date", setID)
|
||||
}
|
||||
|
||||
params.SettledAt = sqldb.SQLTime(
|
||||
ampState.SettleDate.UTC(),
|
||||
)
|
||||
|
||||
params.SettleIndex = sqldb.SQLInt64(
|
||||
ampState.SettleIndex,
|
||||
)
|
||||
}
|
||||
|
||||
err := tx.InsertAMPSubInvoice(ctx, params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to insert AMP sub invoice: "+
|
||||
"%w", err)
|
||||
}
|
||||
|
||||
// Now we can add the AMP HTLCs to the database.
|
||||
for circuitKey := range ampState.InvoiceKeys {
|
||||
htlc := invoice.Htlcs[circuitKey]
|
||||
rootShare := htlc.AMP.Record.RootShare()
|
||||
|
||||
sqlHtlcID, ok := sqlHtlcIDs[circuitKey]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing htlc for AMP htlc: "+
|
||||
"%v", circuitKey)
|
||||
}
|
||||
|
||||
params := sqlc.InsertAMPSubInvoiceHTLCParams{
|
||||
InvoiceID: invoiceID,
|
||||
SetID: setID[:],
|
||||
HtlcID: sqlHtlcID,
|
||||
RootShare: rootShare[:],
|
||||
ChildIndex: int64(htlc.AMP.Record.ChildIndex()),
|
||||
Hash: htlc.AMP.Hash[:],
|
||||
}
|
||||
|
||||
if htlc.AMP.Preimage != nil {
|
||||
params.Preimage = htlc.AMP.Preimage[:]
|
||||
}
|
||||
|
||||
err = tx.InsertAMPSubInvoiceHTLC(ctx, params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to insert AMP sub "+
|
||||
"invoice: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OverrideInvoiceTimeZone overrides the time zone of the invoice to the local
|
||||
// time zone and chops off the nanosecond part for comparison. This is needed
|
||||
// because KV database stores times as-is which as an unwanted side effect would
|
||||
// fail migration due to time comparison expecting both the original and
|
||||
// migrated invoices to be in the same local time zone and in microsecond
|
||||
// precision. Note that PostgreSQL stores times in microsecond precision while
|
||||
// SQLite can store times in nanosecond precision if using TEXT storage class.
|
||||
func OverrideInvoiceTimeZone(invoice *Invoice) {
|
||||
fixTime := func(t time.Time) time.Time {
|
||||
return t.In(time.Local).Truncate(time.Microsecond)
|
||||
}
|
||||
|
||||
invoice.CreationDate = fixTime(invoice.CreationDate)
|
||||
|
||||
if !invoice.SettleDate.IsZero() {
|
||||
invoice.SettleDate = fixTime(invoice.SettleDate)
|
||||
}
|
||||
|
||||
if invoice.IsAMP() {
|
||||
for setID, ampState := range invoice.AMPState {
|
||||
if ampState.SettleDate.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
ampState.SettleDate = fixTime(ampState.SettleDate)
|
||||
invoice.AMPState[setID] = ampState
|
||||
}
|
||||
}
|
||||
|
||||
for _, htlc := range invoice.Htlcs {
|
||||
if !htlc.AcceptTime.IsZero() {
|
||||
htlc.AcceptTime = fixTime(htlc.AcceptTime)
|
||||
}
|
||||
|
||||
if !htlc.ResolveTime.IsZero() {
|
||||
htlc.ResolveTime = fixTime(htlc.ResolveTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MigrateInvoicesToSQL runs the migration of all invoices from the KV database
|
||||
// to the SQL database. The migration is done in a single transaction to ensure
|
||||
// that all invoices are migrated or none at all. This function can be run
|
||||
// multiple times without causing any issues as it will check if the migration
|
||||
// has already been performed.
|
||||
func MigrateInvoicesToSQL(ctx context.Context, db kvdb.Backend,
|
||||
kvStore InvoiceDB, tx *sqlc.Queries, batchSize int) error {
|
||||
|
||||
log.Infof("Starting migration of invoices from KV to SQL")
|
||||
|
||||
offset := uint64(0)
|
||||
t0 := time.Now()
|
||||
|
||||
// Create the hash index which we will use to look up invoice
|
||||
// payment hashes by their add index during migration.
|
||||
err := createInvoiceHashIndex(ctx, db, tx)
|
||||
if err != nil && !errors.Is(err, ErrNoInvoicesCreated) {
|
||||
log.Errorf("Unable to create invoice hash index: %v",
|
||||
err)
|
||||
|
||||
return err
|
||||
}
|
||||
log.Debugf("Created SQL invoice hash index in %v", time.Since(t0))
|
||||
|
||||
total := 0
|
||||
// Now we can start migrating the invoices. We'll do this in
|
||||
// batches to reduce memory usage.
|
||||
for {
|
||||
t0 = time.Now()
|
||||
query := InvoiceQuery{
|
||||
IndexOffset: offset,
|
||||
NumMaxInvoices: uint64(batchSize),
|
||||
}
|
||||
|
||||
queryResult, err := kvStore.QueryInvoices(ctx, query)
|
||||
if err != nil && !errors.Is(err, ErrNoInvoicesCreated) {
|
||||
return fmt.Errorf("unable to query invoices: "+
|
||||
"%w", err)
|
||||
}
|
||||
|
||||
if len(queryResult.Invoices) == 0 {
|
||||
log.Infof("All invoices migrated")
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
err = migrateInvoices(ctx, tx, queryResult.Invoices)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
offset = queryResult.LastIndexOffset
|
||||
total += len(queryResult.Invoices)
|
||||
log.Debugf("Migrated %d KV invoices to SQL in %v\n", total,
|
||||
time.Since(t0))
|
||||
}
|
||||
|
||||
// Clean up the hash index as it's no longer needed.
|
||||
err = tx.ClearKVInvoiceHashIndex(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to clear invoice hash "+
|
||||
"index: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("Migration of %d invoices from KV to SQL completed", total)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateInvoices(ctx context.Context, tx *sqlc.Queries,
|
||||
invoices []Invoice) error {
|
||||
|
||||
for i, invoice := range invoices {
|
||||
var paymentHash lntypes.Hash
|
||||
if invoice.Terms.PaymentPreimage != nil {
|
||||
paymentHash = invoice.Terms.PaymentPreimage.Hash()
|
||||
} else {
|
||||
paymentHashBytes, err :=
|
||||
tx.GetKVInvoicePaymentHashByAddIndex(
|
||||
ctx, int64(invoice.AddIndex),
|
||||
)
|
||||
if err != nil {
|
||||
// This would be an unexpected inconsistency
|
||||
// in the kv database. We can't do much here
|
||||
// so we'll notify the user and continue.
|
||||
log.Warnf("Cannot migrate invoice, unable to "+
|
||||
"fetch payment hash (add_index=%v): %v",
|
||||
invoice.AddIndex, err)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
copy(paymentHash[:], paymentHashBytes)
|
||||
}
|
||||
|
||||
err := MigrateSingleInvoice(ctx, tx, &invoices[i], paymentHash)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to migrate invoice(%v): %w",
|
||||
paymentHash, err)
|
||||
}
|
||||
|
||||
migratedInvoice, err := fetchInvoice(
|
||||
ctx, tx, InvoiceRefByHash(paymentHash),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to fetch migrated "+
|
||||
"invoice(%v): %w", paymentHash, err)
|
||||
}
|
||||
|
||||
// Override the time zone for comparison. Note that we need to
|
||||
// override both invoices as the original invoice is coming from
|
||||
// KV database, it was stored as a binary serialized Go
|
||||
// time.Time value which has nanosecond precision but might have
|
||||
// been created in a different time zone. The migrated invoice
|
||||
// is stored in SQL in UTC and selected in the local time zone,
|
||||
// however in PostgreSQL it has microsecond precision while in
|
||||
// SQLite it has nanosecond precision if using TEXT storage
|
||||
// class.
|
||||
OverrideInvoiceTimeZone(&invoice)
|
||||
OverrideInvoiceTimeZone(migratedInvoice)
|
||||
|
||||
// Override the add index before checking for equality.
|
||||
migratedInvoice.AddIndex = invoice.AddIndex
|
||||
|
||||
if !reflect.DeepEqual(invoice, *migratedInvoice) {
|
||||
diff := difflib.UnifiedDiff{
|
||||
A: difflib.SplitLines(
|
||||
spew.Sdump(invoice),
|
||||
),
|
||||
B: difflib.SplitLines(
|
||||
spew.Sdump(migratedInvoice),
|
||||
),
|
||||
FromFile: "Expected",
|
||||
FromDate: "",
|
||||
ToFile: "Actual",
|
||||
ToDate: "",
|
||||
Context: 3,
|
||||
}
|
||||
diffText, _ := difflib.GetUnifiedDiffString(diff)
|
||||
|
||||
return fmt.Errorf("%w: %v.\n%v", ErrMigrationMismatch,
|
||||
paymentHash, diffText)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
421
invoices/sql_migration_test.go
Normal file
421
invoices/sql_migration_test.go
Normal file
@ -0,0 +1,421 @@
|
||||
package invoices
|
||||
|
||||
import (
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"database/sql"
|
||||
"math/rand"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lightningnetwork/lnd/clock"
|
||||
"github.com/lightningnetwork/lnd/graph/db/models"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/sqldb"
|
||||
"github.com/stretchr/testify/require"
|
||||
"pgregory.net/rapid"
|
||||
)
|
||||
|
||||
var (
|
||||
// testHtlcIDSequence is a global counter for generating unique HTLC
|
||||
// IDs.
|
||||
testHtlcIDSequence uint64
|
||||
)
|
||||
|
||||
// randomString generates a random string of a given length using rapid.
|
||||
func randomStringRapid(t *rapid.T, length int) string {
|
||||
// Define the character set for the string.
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" //nolint:ll
|
||||
|
||||
// Generate a string by selecting random characters from the charset.
|
||||
runes := make([]rune, length)
|
||||
for i := range runes {
|
||||
// Draw a random index and use it to select a character from the
|
||||
// charset.
|
||||
index := rapid.IntRange(0, len(charset)-1).Draw(t, "charIndex")
|
||||
runes[i] = rune(charset[index])
|
||||
}
|
||||
|
||||
return string(runes)
|
||||
}
|
||||
|
||||
// randTimeBetween generates a random time between min and max.
|
||||
func randTimeBetween(min, max time.Time) time.Time {
|
||||
var timeZones = []*time.Location{
|
||||
time.UTC,
|
||||
time.FixedZone("EST", -5*3600),
|
||||
time.FixedZone("MST", -7*3600),
|
||||
time.FixedZone("PST", -8*3600),
|
||||
time.FixedZone("CEST", 2*3600),
|
||||
}
|
||||
|
||||
// Ensure max is after min
|
||||
if max.Before(min) {
|
||||
min, max = max, min
|
||||
}
|
||||
|
||||
// Calculate the range in nanoseconds
|
||||
duration := max.Sub(min)
|
||||
randDuration := time.Duration(rand.Int63n(duration.Nanoseconds()))
|
||||
|
||||
// Generate the random time
|
||||
randomTime := min.Add(randDuration)
|
||||
|
||||
// Assign a random time zone
|
||||
randomTimeZone := timeZones[rand.Intn(len(timeZones))]
|
||||
|
||||
// Return the time in the random time zone
|
||||
return randomTime.In(randomTimeZone)
|
||||
}
|
||||
|
||||
// randTime generates a random time between 2009 and 2140.
|
||||
func randTime() time.Time {
|
||||
min := time.Date(2009, 1, 3, 0, 0, 0, 0, time.UTC)
|
||||
max := time.Date(2140, 1, 1, 0, 0, 0, 1000, time.UTC)
|
||||
|
||||
return randTimeBetween(min, max)
|
||||
}
|
||||
|
||||
func randInvoiceTime(invoice *Invoice) time.Time {
|
||||
return randTimeBetween(
|
||||
invoice.CreationDate,
|
||||
invoice.CreationDate.Add(invoice.Terms.Expiry),
|
||||
)
|
||||
}
|
||||
|
||||
// randHTLCRapid generates a random HTLC for an invoice using rapid to randomize
|
||||
// its parameters.
|
||||
func randHTLCRapid(t *rapid.T, invoice *Invoice, amt lnwire.MilliSatoshi) (
|
||||
models.CircuitKey, *InvoiceHTLC) {
|
||||
|
||||
htlc := &InvoiceHTLC{
|
||||
Amt: amt,
|
||||
AcceptHeight: rapid.Uint32Range(1, 999).Draw(t, "AcceptHeight"),
|
||||
AcceptTime: randInvoiceTime(invoice),
|
||||
Expiry: rapid.Uint32Range(1, 999).Draw(t, "Expiry"),
|
||||
}
|
||||
|
||||
// Set MPP total amount if MPP feature is enabled in the invoice.
|
||||
if invoice.Terms.Features.HasFeature(lnwire.MPPRequired) {
|
||||
htlc.MppTotalAmt = invoice.Terms.Value
|
||||
}
|
||||
|
||||
// Set the HTLC state and resolve time based on the invoice state.
|
||||
switch invoice.State {
|
||||
case ContractSettled:
|
||||
htlc.State = HtlcStateSettled
|
||||
htlc.ResolveTime = randInvoiceTime(invoice)
|
||||
|
||||
case ContractCanceled:
|
||||
htlc.State = HtlcStateCanceled
|
||||
htlc.ResolveTime = randInvoiceTime(invoice)
|
||||
|
||||
case ContractAccepted:
|
||||
htlc.State = HtlcStateAccepted
|
||||
}
|
||||
|
||||
// Add randomized custom records to the HTLC.
|
||||
htlc.CustomRecords = make(record.CustomSet)
|
||||
numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords")
|
||||
for i := 0; i < numRecords; i++ {
|
||||
key := rapid.Uint64Range(
|
||||
record.CustomTypeStart, 1000+record.CustomTypeStart,
|
||||
).Draw(t, "customRecordKey")
|
||||
value := []byte(randomStringRapid(t, 10))
|
||||
htlc.CustomRecords[key] = value
|
||||
}
|
||||
|
||||
// Generate a unique HTLC ID and assign it to a channel ID.
|
||||
htlcID := atomic.AddUint64(&testHtlcIDSequence, 1)
|
||||
randChanID := lnwire.NewShortChanIDFromInt(htlcID % 5)
|
||||
|
||||
circuitKey := models.CircuitKey{
|
||||
ChanID: randChanID,
|
||||
HtlcID: htlcID,
|
||||
}
|
||||
|
||||
return circuitKey, htlc
|
||||
}
|
||||
|
||||
// generateInvoiceHTLCsRapid generates all HTLCs for an invoice, including AMP
|
||||
// HTLCs if applicable, using rapid for randomization of HTLC count and
|
||||
// distribution.
|
||||
func generateInvoiceHTLCsRapid(t *rapid.T, invoice *Invoice) {
|
||||
mpp := invoice.Terms.Features.HasFeature(lnwire.MPPRequired)
|
||||
|
||||
// Use rapid to determine the number of HTLCs based on invoice state and
|
||||
// MPP feature.
|
||||
numHTLCs := 1
|
||||
if invoice.State == ContractOpen {
|
||||
numHTLCs = 0
|
||||
} else if mpp {
|
||||
numHTLCs = rapid.IntRange(1, 10).Draw(t, "numHTLCs")
|
||||
}
|
||||
|
||||
total := invoice.Terms.Value
|
||||
|
||||
// Distribute the total amount across the HTLCs, adding any remainder to
|
||||
// the last HTLC.
|
||||
if numHTLCs > 0 {
|
||||
amt := total / lnwire.MilliSatoshi(numHTLCs)
|
||||
remainder := total - amt*lnwire.MilliSatoshi(numHTLCs)
|
||||
|
||||
for i := 0; i < numHTLCs; i++ {
|
||||
if i == numHTLCs-1 {
|
||||
// Add remainder to the last HTLC.
|
||||
amt += remainder
|
||||
}
|
||||
|
||||
// Generate an HTLC with a random circuit key and add it
|
||||
// to the invoice.
|
||||
circuitKey, htlc := randHTLCRapid(t, invoice, amt)
|
||||
invoice.Htlcs[circuitKey] = htlc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// generateAMPHtlcsRapid generates AMP HTLCs for an invoice using rapid to
|
||||
// randomize various parameters of the HTLCs in the AMP set.
|
||||
func generateAMPHtlcsRapid(t *rapid.T, invoice *Invoice) {
|
||||
// Randomly determine the number of AMP sets (1 to 5).
|
||||
numSetIDs := rapid.IntRange(1, 5).Draw(t, "numSetIDs")
|
||||
settledIdx := uint64(1)
|
||||
|
||||
for i := 0; i < numSetIDs; i++ {
|
||||
var setID SetID
|
||||
_, err := crand.Read(setID[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
// Determine the number of HTLCs in this set (1 to 5).
|
||||
numHTLCs := rapid.IntRange(1, 5).Draw(t, "numHTLCs")
|
||||
total := invoice.Terms.Value
|
||||
invoiceKeys := make(map[CircuitKey]struct{})
|
||||
|
||||
// Calculate the amount per HTLC and account for remainder in
|
||||
// the final HTLC.
|
||||
amt := total / lnwire.MilliSatoshi(numHTLCs)
|
||||
remainder := total - amt*lnwire.MilliSatoshi(numHTLCs)
|
||||
|
||||
var htlcState HtlcState
|
||||
for j := 0; j < numHTLCs; j++ {
|
||||
if j == numHTLCs-1 {
|
||||
amt += remainder
|
||||
}
|
||||
|
||||
// Generate HTLC with randomized parameters.
|
||||
circuitKey, htlc := randHTLCRapid(t, invoice, amt)
|
||||
htlcState = htlc.State
|
||||
|
||||
var (
|
||||
rootShare, hash [32]byte
|
||||
preimage lntypes.Preimage
|
||||
)
|
||||
|
||||
// Randomize AMP data fields.
|
||||
_, err := crand.Read(rootShare[:])
|
||||
require.NoError(t, err)
|
||||
_, err = crand.Read(hash[:])
|
||||
require.NoError(t, err)
|
||||
_, err = crand.Read(preimage[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
record := record.NewAMP(rootShare, setID, uint32(j))
|
||||
|
||||
htlc.AMP = &InvoiceHtlcAMPData{
|
||||
Record: *record,
|
||||
Hash: hash,
|
||||
Preimage: &preimage,
|
||||
}
|
||||
|
||||
invoice.Htlcs[circuitKey] = htlc
|
||||
invoiceKeys[circuitKey] = struct{}{}
|
||||
}
|
||||
|
||||
ampState := InvoiceStateAMP{
|
||||
State: htlcState,
|
||||
InvoiceKeys: invoiceKeys,
|
||||
}
|
||||
if htlcState == HtlcStateSettled {
|
||||
ampState.SettleIndex = settledIdx
|
||||
ampState.SettleDate = randInvoiceTime(invoice)
|
||||
settledIdx++
|
||||
}
|
||||
|
||||
// Set the total amount paid if the AMP set is not canceled.
|
||||
if htlcState != HtlcStateCanceled {
|
||||
ampState.AmtPaid = invoice.Terms.Value
|
||||
}
|
||||
|
||||
invoice.AMPState[setID] = ampState
|
||||
}
|
||||
}
|
||||
|
||||
// TestMigrateSingleInvoiceRapid tests the migration of single invoices with
|
||||
// random data variations using rapid. This test generates a random invoice
|
||||
// configuration and ensures successful migration.
|
||||
//
|
||||
// NOTE: This test may need to be changed if the Invoice or any of the related
|
||||
// types are modified.
|
||||
func TestMigrateSingleInvoiceRapid(t *testing.T) {
|
||||
// Create a shared Postgres instance for efficient testing.
|
||||
pgFixture := sqldb.NewTestPgFixture(
|
||||
t, sqldb.DefaultPostgresFixtureLifetime,
|
||||
)
|
||||
t.Cleanup(func() {
|
||||
pgFixture.TearDown(t)
|
||||
})
|
||||
|
||||
makeSQLDB := func(t *testing.T, sqlite bool) *SQLStore {
|
||||
var db *sqldb.BaseDB
|
||||
if sqlite {
|
||||
db = sqldb.NewTestSqliteDB(t).BaseDB
|
||||
} else {
|
||||
db = sqldb.NewTestPostgresDB(t, pgFixture).BaseDB
|
||||
}
|
||||
|
||||
executor := sqldb.NewTransactionExecutor(
|
||||
db, func(tx *sql.Tx) SQLInvoiceQueries {
|
||||
return db.WithTx(tx)
|
||||
},
|
||||
)
|
||||
|
||||
testClock := clock.NewTestClock(time.Unix(1, 0))
|
||||
|
||||
return NewSQLStore(executor, testClock)
|
||||
}
|
||||
|
||||
// Define property-based test using rapid.
|
||||
rapid.Check(t, func(rt *rapid.T) {
|
||||
// Randomized feature flags for MPP and AMP.
|
||||
mpp := rapid.Bool().Draw(rt, "mpp")
|
||||
amp := rapid.Bool().Draw(rt, "amp")
|
||||
|
||||
for _, sqlite := range []bool{true, false} {
|
||||
store := makeSQLDB(t, sqlite)
|
||||
testMigrateSingleInvoiceRapid(rt, store, mpp, amp)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// testMigrateSingleInvoiceRapid is the primary function for the migration of a
|
||||
// single invoice with random data in a rapid-based test setup.
|
||||
func testMigrateSingleInvoiceRapid(t *rapid.T, store *SQLStore, mpp bool,
|
||||
amp bool) {
|
||||
|
||||
ctxb := context.Background()
|
||||
invoices := make(map[lntypes.Hash]*Invoice)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
invoice := generateTestInvoiceRapid(t, mpp, amp)
|
||||
var hash lntypes.Hash
|
||||
_, err := crand.Read(hash[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
invoices[hash] = invoice
|
||||
}
|
||||
|
||||
var ops SQLInvoiceQueriesTxOptions
|
||||
err := store.db.ExecTx(ctxb, &ops, func(tx SQLInvoiceQueries) error {
|
||||
for hash, invoice := range invoices {
|
||||
err := MigrateSingleInvoice(ctxb, tx, invoice, hash)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fetch and compare each migrated invoice from the store with the
|
||||
// original.
|
||||
for hash, invoice := range invoices {
|
||||
sqlInvoice, err := store.LookupInvoice(
|
||||
ctxb, InvoiceRefByHash(hash),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
invoice.AddIndex = sqlInvoice.AddIndex
|
||||
|
||||
OverrideInvoiceTimeZone(invoice)
|
||||
OverrideInvoiceTimeZone(&sqlInvoice)
|
||||
|
||||
require.Equal(t, *invoice, sqlInvoice)
|
||||
}
|
||||
}
|
||||
|
||||
// generateTestInvoiceRapid generates a random invoice with variations based on
|
||||
// mpp and amp flags.
|
||||
func generateTestInvoiceRapid(t *rapid.T, mpp bool, amp bool) *Invoice {
|
||||
var preimage lntypes.Preimage
|
||||
_, err := crand.Read(preimage[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
terms := ContractTerm{
|
||||
FinalCltvDelta: rapid.Int32Range(1, 1000).Draw(
|
||||
t, "FinalCltvDelta",
|
||||
),
|
||||
Expiry: time.Duration(
|
||||
rapid.IntRange(1, 4444).Draw(t, "Expiry"),
|
||||
) * time.Minute,
|
||||
PaymentPreimage: &preimage,
|
||||
Value: lnwire.MilliSatoshi(
|
||||
rapid.Int64Range(1, 9999999).Draw(t, "Value"),
|
||||
),
|
||||
PaymentAddr: [32]byte{},
|
||||
Features: lnwire.EmptyFeatureVector(),
|
||||
}
|
||||
|
||||
if amp {
|
||||
terms.Features.Set(lnwire.AMPRequired)
|
||||
} else if mpp {
|
||||
terms.Features.Set(lnwire.MPPRequired)
|
||||
}
|
||||
|
||||
created := randTime()
|
||||
|
||||
const maxContractState = 3
|
||||
state := ContractState(
|
||||
rapid.IntRange(0, maxContractState).Draw(t, "ContractState"),
|
||||
)
|
||||
var (
|
||||
settled time.Time
|
||||
settleIndex uint64
|
||||
)
|
||||
if state == ContractSettled {
|
||||
settled = randTimeBetween(created, created.Add(terms.Expiry))
|
||||
settleIndex = rapid.Uint64Range(1, 999).Draw(t, "SettleIndex")
|
||||
}
|
||||
|
||||
invoice := &Invoice{
|
||||
Memo: []byte(randomStringRapid(t, 10)),
|
||||
PaymentRequest: []byte(
|
||||
randomStringRapid(t, MaxPaymentRequestSize),
|
||||
),
|
||||
CreationDate: created,
|
||||
SettleDate: settled,
|
||||
Terms: terms,
|
||||
AddIndex: 0,
|
||||
SettleIndex: settleIndex,
|
||||
State: state,
|
||||
AMPState: make(map[SetID]InvoiceStateAMP),
|
||||
HodlInvoice: rapid.Bool().Draw(t, "HodlInvoice"),
|
||||
}
|
||||
|
||||
invoice.Htlcs = make(map[models.CircuitKey]*InvoiceHTLC)
|
||||
|
||||
if invoice.IsAMP() {
|
||||
generateAMPHtlcsRapid(t, invoice)
|
||||
} else {
|
||||
generateInvoiceHTLCsRapid(t, invoice)
|
||||
}
|
||||
|
||||
for _, htlc := range invoice.Htlcs {
|
||||
if htlc.State == HtlcStateSettled {
|
||||
invoice.AmtPaid += htlc.Amt
|
||||
}
|
||||
}
|
||||
|
||||
return invoice
|
||||
}
|
@ -32,6 +32,10 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
|
||||
InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
|
||||
error)
|
||||
|
||||
// TODO(bhandras): remove this once migrations have been separated out.
|
||||
InsertMigratedInvoice(ctx context.Context,
|
||||
arg sqlc.InsertMigratedInvoiceParams) (int64, error)
|
||||
|
||||
InsertInvoiceFeature(ctx context.Context,
|
||||
arg sqlc.InsertInvoiceFeatureParams) error
|
||||
|
||||
@ -47,6 +51,9 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
|
||||
GetInvoice(ctx context.Context,
|
||||
arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
|
||||
|
||||
GetInvoiceByHash(ctx context.Context, hash []byte) (sqlc.Invoice,
|
||||
error)
|
||||
|
||||
GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
|
||||
error)
|
||||
|
||||
@ -79,6 +86,10 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
|
||||
UpsertAMPSubInvoice(ctx context.Context,
|
||||
arg sqlc.UpsertAMPSubInvoiceParams) (sql.Result, error)
|
||||
|
||||
// TODO(bhandras): remove this once migrations have been separated out.
|
||||
InsertAMPSubInvoice(ctx context.Context,
|
||||
arg sqlc.InsertAMPSubInvoiceParams) error
|
||||
|
||||
UpdateAMPSubInvoiceState(ctx context.Context,
|
||||
arg sqlc.UpdateAMPSubInvoiceStateParams) error
|
||||
|
||||
@ -119,6 +130,19 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
|
||||
|
||||
OnAMPSubInvoiceSettled(ctx context.Context,
|
||||
arg sqlc.OnAMPSubInvoiceSettledParams) error
|
||||
|
||||
// Migration specific methods.
|
||||
// TODO(bhandras): remove this once migrations have been separated out.
|
||||
InsertKVInvoiceKeyAndAddIndex(ctx context.Context,
|
||||
arg sqlc.InsertKVInvoiceKeyAndAddIndexParams) error
|
||||
|
||||
SetKVInvoicePaymentHash(ctx context.Context,
|
||||
arg sqlc.SetKVInvoicePaymentHashParams) error
|
||||
|
||||
GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) (
|
||||
[]byte, error)
|
||||
|
||||
ClearKVInvoiceHashIndex(ctx context.Context) error
|
||||
}
|
||||
|
||||
var _ InvoiceDB = (*SQLStore)(nil)
|
||||
@ -200,6 +224,66 @@ func NewSQLStore(db BatchedSQLInvoiceQueries,
|
||||
}
|
||||
}
|
||||
|
||||
func makeInsertInvoiceParams(invoice *Invoice, paymentHash lntypes.Hash) (
|
||||
sqlc.InsertInvoiceParams, error) {
|
||||
|
||||
// Precompute the payment request hash so we can use it in the query.
|
||||
var paymentRequestHash []byte
|
||||
if len(invoice.PaymentRequest) > 0 {
|
||||
h := sha256.New()
|
||||
h.Write(invoice.PaymentRequest)
|
||||
paymentRequestHash = h.Sum(nil)
|
||||
}
|
||||
|
||||
params := sqlc.InsertInvoiceParams{
|
||||
Hash: paymentHash[:],
|
||||
AmountMsat: int64(invoice.Terms.Value),
|
||||
CltvDelta: sqldb.SQLInt32(
|
||||
invoice.Terms.FinalCltvDelta,
|
||||
),
|
||||
Expiry: int32(invoice.Terms.Expiry.Seconds()),
|
||||
// Note: keysend invoices don't have a payment request.
|
||||
PaymentRequest: sqldb.SQLStr(string(
|
||||
invoice.PaymentRequest),
|
||||
),
|
||||
PaymentRequestHash: paymentRequestHash,
|
||||
State: int16(invoice.State),
|
||||
AmountPaidMsat: int64(invoice.AmtPaid),
|
||||
IsAmp: invoice.IsAMP(),
|
||||
IsHodl: invoice.HodlInvoice,
|
||||
IsKeysend: invoice.IsKeysend(),
|
||||
CreatedAt: invoice.CreationDate.UTC(),
|
||||
}
|
||||
|
||||
if invoice.Memo != nil {
|
||||
// Store the memo as a nullable string in the database. Note
|
||||
// that for compatibility reasons, we store the value as a valid
|
||||
// string even if it's empty.
|
||||
params.Memo = sql.NullString{
|
||||
String: string(invoice.Memo),
|
||||
Valid: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Some invoices may not have a preimage, like in the case of HODL
|
||||
// invoices.
|
||||
if invoice.Terms.PaymentPreimage != nil {
|
||||
preimage := *invoice.Terms.PaymentPreimage
|
||||
if preimage == UnknownPreimage {
|
||||
return sqlc.InsertInvoiceParams{},
|
||||
errors.New("cannot use all-zeroes preimage")
|
||||
}
|
||||
params.Preimage = preimage[:]
|
||||
}
|
||||
|
||||
// Some non MPP payments may have the default (invalid) value.
|
||||
if invoice.Terms.PaymentAddr != BlankPayAddr {
|
||||
params.PaymentAddr = invoice.Terms.PaymentAddr[:]
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
// AddInvoice inserts the targeted invoice into the database. If the invoice has
|
||||
// *any* payment hashes which already exists within the database, then the
|
||||
// insertion will be aborted and rejected due to the strict policy banning any
|
||||
@ -220,55 +304,16 @@ func (i *SQLStore) AddInvoice(ctx context.Context,
|
||||
invoiceID int64
|
||||
)
|
||||
|
||||
// Precompute the payment request hash so we can use it in the query.
|
||||
var paymentRequestHash []byte
|
||||
if len(newInvoice.PaymentRequest) > 0 {
|
||||
h := sha256.New()
|
||||
h.Write(newInvoice.PaymentRequest)
|
||||
paymentRequestHash = h.Sum(nil)
|
||||
}
|
||||
|
||||
err := i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
|
||||
params := sqlc.InsertInvoiceParams{
|
||||
Hash: paymentHash[:],
|
||||
Memo: sqldb.SQLStr(string(newInvoice.Memo)),
|
||||
AmountMsat: int64(newInvoice.Terms.Value),
|
||||
// Note: BOLT12 invoices don't have a final cltv delta.
|
||||
CltvDelta: sqldb.SQLInt32(
|
||||
newInvoice.Terms.FinalCltvDelta,
|
||||
),
|
||||
Expiry: int32(newInvoice.Terms.Expiry.Seconds()),
|
||||
// Note: keysend invoices don't have a payment request.
|
||||
PaymentRequest: sqldb.SQLStr(string(
|
||||
newInvoice.PaymentRequest),
|
||||
),
|
||||
PaymentRequestHash: paymentRequestHash,
|
||||
State: int16(newInvoice.State),
|
||||
AmountPaidMsat: int64(newInvoice.AmtPaid),
|
||||
IsAmp: newInvoice.IsAMP(),
|
||||
IsHodl: newInvoice.HodlInvoice,
|
||||
IsKeysend: newInvoice.IsKeysend(),
|
||||
CreatedAt: newInvoice.CreationDate.UTC(),
|
||||
}
|
||||
|
||||
// Some invoices may not have a preimage, like in the case of
|
||||
// HODL invoices.
|
||||
if newInvoice.Terms.PaymentPreimage != nil {
|
||||
preimage := *newInvoice.Terms.PaymentPreimage
|
||||
if preimage == UnknownPreimage {
|
||||
return errors.New("cannot use all-zeroes " +
|
||||
"preimage")
|
||||
}
|
||||
params.Preimage = preimage[:]
|
||||
}
|
||||
|
||||
// Some non MPP payments may have the default (invalid) value.
|
||||
if newInvoice.Terms.PaymentAddr != BlankPayAddr {
|
||||
params.PaymentAddr = newInvoice.Terms.PaymentAddr[:]
|
||||
insertInvoiceParams, err := makeInsertInvoiceParams(
|
||||
newInvoice, paymentHash,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
|
||||
var err error
|
||||
invoiceID, err = db.InsertInvoice(ctx, params)
|
||||
invoiceID, err = db.InsertInvoice(ctx, insertInvoiceParams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to insert invoice: %w", err)
|
||||
}
|
||||
@ -312,22 +357,31 @@ func (i *SQLStore) AddInvoice(ctx context.Context,
|
||||
return newInvoice.AddIndex, nil
|
||||
}
|
||||
|
||||
// fetchInvoice fetches the common invoice data and the AMP state for the
|
||||
// invoice with the given reference.
|
||||
func (i *SQLStore) fetchInvoice(ctx context.Context,
|
||||
db SQLInvoiceQueries, ref InvoiceRef) (*Invoice, error) {
|
||||
// getInvoiceByRef fetches the invoice with the given reference. The reference
|
||||
// may be a payment hash, a payment address, or a set ID for an AMP sub invoice.
|
||||
func getInvoiceByRef(ctx context.Context,
|
||||
db SQLInvoiceQueries, ref InvoiceRef) (sqlc.Invoice, error) {
|
||||
|
||||
// If the reference is empty, we can't look up the invoice.
|
||||
if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
|
||||
return nil, ErrInvoiceNotFound
|
||||
return sqlc.Invoice{}, ErrInvoiceNotFound
|
||||
}
|
||||
|
||||
var (
|
||||
invoice *Invoice
|
||||
params sqlc.GetInvoiceParams
|
||||
)
|
||||
// If the reference is a hash only, we can look up the invoice directly
|
||||
// by the payment hash which is faster.
|
||||
if ref.IsHashOnly() {
|
||||
invoice, err := db.GetInvoiceByHash(ctx, ref.PayHash()[:])
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return sqlc.Invoice{}, ErrInvoiceNotFound
|
||||
}
|
||||
|
||||
return invoice, err
|
||||
}
|
||||
|
||||
// Otherwise the reference may include more fields, so we'll need to
|
||||
// assemble the query parameters based on the fields that are set.
|
||||
var params sqlc.GetInvoiceParams
|
||||
|
||||
// Given all invoices are uniquely identified by their payment hash,
|
||||
// we can use it to query a specific invoice.
|
||||
if ref.PayHash() != nil {
|
||||
params.Hash = ref.PayHash()[:]
|
||||
}
|
||||
@ -363,18 +417,34 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
|
||||
} else {
|
||||
rows, err = db.GetInvoice(ctx, params)
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(rows) == 0:
|
||||
return nil, ErrInvoiceNotFound
|
||||
return sqlc.Invoice{}, ErrInvoiceNotFound
|
||||
|
||||
case len(rows) > 1:
|
||||
// In case the reference is ambiguous, meaning it matches more
|
||||
// than one invoice, we'll return an error.
|
||||
return nil, fmt.Errorf("ambiguous invoice ref: %s: %s",
|
||||
ref.String(), spew.Sdump(rows))
|
||||
return sqlc.Invoice{}, fmt.Errorf("ambiguous invoice ref: "+
|
||||
"%s: %s", ref.String(), spew.Sdump(rows))
|
||||
|
||||
case err != nil:
|
||||
return nil, fmt.Errorf("unable to fetch invoice: %w", err)
|
||||
return sqlc.Invoice{}, fmt.Errorf("unable to fetch invoice: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
return rows[0], nil
|
||||
}
|
||||
|
||||
// fetchInvoice fetches the common invoice data and the AMP state for the
|
||||
// invoice with the given reference.
|
||||
func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, ref InvoiceRef) (
|
||||
*Invoice, error) {
|
||||
|
||||
// Fetch the invoice from the database.
|
||||
sqlInvoice, err := getInvoiceByRef(ctx, db, ref)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
@ -391,8 +461,8 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
|
||||
fetchAmpHtlcs = true
|
||||
|
||||
case HtlcSetOnlyModifier:
|
||||
// In this case we'll fetch all AMP HTLCs for the
|
||||
// specified set id.
|
||||
// In this case we'll fetch all AMP HTLCs for the specified set
|
||||
// id.
|
||||
if ref.SetID() == nil {
|
||||
return nil, fmt.Errorf("set ID is required to use " +
|
||||
"the HTLC set only modifier")
|
||||
@ -412,8 +482,8 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
|
||||
}
|
||||
|
||||
// Fetch the rest of the invoice data and fill the invoice struct.
|
||||
_, invoice, err = fetchInvoiceData(
|
||||
ctx, db, rows[0], setID, fetchAmpHtlcs,
|
||||
_, invoice, err := fetchInvoiceData(
|
||||
ctx, db, sqlInvoice, setID, fetchAmpHtlcs,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -616,7 +686,7 @@ func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
|
||||
|
||||
invoiceKeys[key] = struct{}{}
|
||||
|
||||
if htlc.State != HtlcStateCanceled { //nolint: ll
|
||||
if htlc.State != HtlcStateCanceled {
|
||||
amtPaid += htlc.Amt
|
||||
}
|
||||
}
|
||||
@ -646,7 +716,7 @@ func (i *SQLStore) LookupInvoice(ctx context.Context,
|
||||
|
||||
readTxOpt := NewSQLInvoiceQueryReadTx()
|
||||
txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
|
||||
invoice, err = i.fetchInvoice(ctx, db, ref)
|
||||
invoice, err = fetchInvoice(ctx, db, ref)
|
||||
|
||||
return err
|
||||
}, func() {})
|
||||
@ -1347,7 +1417,7 @@ func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
|
||||
ref.refModifier = HtlcSetOnlyModifier
|
||||
}
|
||||
|
||||
invoice, err := i.fetchInvoice(ctx, db, ref)
|
||||
invoice, err := fetchInvoice(ctx, db, ref)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1506,13 +1576,6 @@ func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
|
||||
|
||||
if len(htlcs) > 0 {
|
||||
invoice.Htlcs = htlcs
|
||||
var amountPaid lnwire.MilliSatoshi
|
||||
for _, htlc := range htlcs {
|
||||
if htlc.State == HtlcStateSettled {
|
||||
amountPaid += htlc.Amt
|
||||
}
|
||||
}
|
||||
invoice.AmtPaid = amountPaid
|
||||
}
|
||||
|
||||
return hash, invoice, nil
|
||||
|
BIN
invoices/testdata/channel.db
vendored
Normal file
BIN
invoices/testdata/channel.db
vendored
Normal file
Binary file not shown.
@ -626,10 +626,6 @@ var allTestCases = []*lntest.TestCase{
|
||||
Name: "open channel locked balance",
|
||||
TestFunc: testOpenChannelLockedBalance,
|
||||
},
|
||||
{
|
||||
Name: "nativesql no migration",
|
||||
TestFunc: testNativeSQLNoMigration,
|
||||
},
|
||||
{
|
||||
Name: "sweep cpfp anchor outgoing timeout",
|
||||
TestFunc: testSweepCPFPAnchorOutgoingTimeout,
|
||||
@ -682,6 +678,10 @@ var allTestCases = []*lntest.TestCase{
|
||||
Name: "quiescence",
|
||||
TestFunc: testQuiescence,
|
||||
},
|
||||
{
|
||||
Name: "invoice migration",
|
||||
TestFunc: testInvoiceMigration,
|
||||
},
|
||||
}
|
||||
|
||||
// appendPrefixed is used to add a prefix to each test name in the subtests
|
||||
|
303
itest/lnd_invoice_migration_test.go
Normal file
303
itest/lnd_invoice_migration_test.go
Normal file
@ -0,0 +1,303 @@
|
||||
package itest
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/clock"
|
||||
"github.com/lightningnetwork/lnd/invoices"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/kvdb/postgres"
|
||||
"github.com/lightningnetwork/lnd/kvdb/sqlbase"
|
||||
"github.com/lightningnetwork/lnd/kvdb/sqlite"
|
||||
"github.com/lightningnetwork/lnd/lncfg"
|
||||
"github.com/lightningnetwork/lnd/lnrpc"
|
||||
"github.com/lightningnetwork/lnd/lnrpc/routerrpc"
|
||||
"github.com/lightningnetwork/lnd/lntest"
|
||||
"github.com/lightningnetwork/lnd/lntest/node"
|
||||
"github.com/lightningnetwork/lnd/sqldb"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func openChannelDB(ht *lntest.HarnessTest, hn *node.HarnessNode) *channeldb.DB {
|
||||
sqlbase.Init(0)
|
||||
var (
|
||||
backend kvdb.Backend
|
||||
err error
|
||||
)
|
||||
|
||||
switch hn.Cfg.DBBackend {
|
||||
case node.BackendSqlite:
|
||||
backend, err = kvdb.Open(
|
||||
kvdb.SqliteBackendName,
|
||||
ht.Context(),
|
||||
&sqlite.Config{
|
||||
Timeout: defaultTimeout,
|
||||
BusyTimeout: defaultTimeout,
|
||||
},
|
||||
hn.Cfg.DBDir(), lncfg.SqliteChannelDBName,
|
||||
lncfg.NSChannelDB,
|
||||
)
|
||||
require.NoError(ht, err)
|
||||
|
||||
case node.BackendPostgres:
|
||||
backend, err = kvdb.Open(
|
||||
kvdb.PostgresBackendName, ht.Context(),
|
||||
&postgres.Config{
|
||||
Dsn: hn.Cfg.PostgresDsn,
|
||||
Timeout: defaultTimeout,
|
||||
}, lncfg.NSChannelDB,
|
||||
)
|
||||
require.NoError(ht, err)
|
||||
}
|
||||
|
||||
db, err := channeldb.CreateWithBackend(backend)
|
||||
require.NoError(ht, err)
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func openNativeSQLInvoiceDB(ht *lntest.HarnessTest,
|
||||
hn *node.HarnessNode) invoices.InvoiceDB {
|
||||
|
||||
var db *sqldb.BaseDB
|
||||
|
||||
switch hn.Cfg.DBBackend {
|
||||
case node.BackendSqlite:
|
||||
sqliteStore, err := sqldb.NewSqliteStore(
|
||||
&sqldb.SqliteConfig{
|
||||
Timeout: defaultTimeout,
|
||||
BusyTimeout: defaultTimeout,
|
||||
},
|
||||
path.Join(
|
||||
hn.Cfg.DBDir(),
|
||||
lncfg.SqliteNativeDBName,
|
||||
),
|
||||
)
|
||||
require.NoError(ht, err)
|
||||
db = sqliteStore.BaseDB
|
||||
|
||||
case node.BackendPostgres:
|
||||
postgresStore, err := sqldb.NewPostgresStore(
|
||||
&sqldb.PostgresConfig{
|
||||
Dsn: hn.Cfg.PostgresDsn,
|
||||
Timeout: defaultTimeout,
|
||||
},
|
||||
)
|
||||
require.NoError(ht, err)
|
||||
db = postgresStore.BaseDB
|
||||
}
|
||||
|
||||
executor := sqldb.NewTransactionExecutor(
|
||||
db, func(tx *sql.Tx) invoices.SQLInvoiceQueries {
|
||||
return db.WithTx(tx)
|
||||
},
|
||||
)
|
||||
|
||||
return invoices.NewSQLStore(
|
||||
executor, clock.NewDefaultClock(),
|
||||
)
|
||||
}
|
||||
|
||||
// clampTime truncates the time of the passed invoice to the microsecond level.
|
||||
func clampTime(invoice *invoices.Invoice) {
|
||||
trunc := func(t time.Time) time.Time {
|
||||
return t.Truncate(time.Microsecond)
|
||||
}
|
||||
|
||||
invoice.CreationDate = trunc(invoice.CreationDate)
|
||||
|
||||
if !invoice.SettleDate.IsZero() {
|
||||
invoice.SettleDate = trunc(invoice.SettleDate)
|
||||
}
|
||||
|
||||
if invoice.IsAMP() {
|
||||
for setID, ampState := range invoice.AMPState {
|
||||
if ampState.SettleDate.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
ampState.SettleDate = trunc(ampState.SettleDate)
|
||||
invoice.AMPState[setID] = ampState
|
||||
}
|
||||
}
|
||||
|
||||
for _, htlc := range invoice.Htlcs {
|
||||
if !htlc.AcceptTime.IsZero() {
|
||||
htlc.AcceptTime = trunc(htlc.AcceptTime)
|
||||
}
|
||||
|
||||
if !htlc.ResolveTime.IsZero() {
|
||||
htlc.ResolveTime = trunc(htlc.ResolveTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// testInvoiceMigration tests that the invoice migration from the old KV store
|
||||
// to the new native SQL store works as expected.
|
||||
func testInvoiceMigration(ht *lntest.HarnessTest) {
|
||||
alice := ht.NewNodeWithCoins("Alice", nil)
|
||||
bob := ht.NewNodeWithCoins("Bob", []string{"--accept-amp"})
|
||||
|
||||
// Make sure we run the test with SQLite or Postgres.
|
||||
if bob.Cfg.DBBackend != node.BackendSqlite &&
|
||||
bob.Cfg.DBBackend != node.BackendPostgres {
|
||||
|
||||
ht.Skip("node not running with SQLite or Postgres")
|
||||
}
|
||||
|
||||
// Skip the test if the node is already running with native SQL.
|
||||
if bob.Cfg.NativeSQL {
|
||||
ht.Skip("node already running with native SQL")
|
||||
}
|
||||
|
||||
ht.EnsureConnected(alice, bob)
|
||||
cp := ht.OpenChannel(
|
||||
alice, bob, lntest.OpenChannelParams{
|
||||
Amt: 1000000,
|
||||
PushAmt: 500000,
|
||||
},
|
||||
)
|
||||
|
||||
// Alice and bob should have one channel open with each other now.
|
||||
ht.AssertNodeNumChannels(alice, 1)
|
||||
ht.AssertNodeNumChannels(bob, 1)
|
||||
|
||||
// Step 1: Add 10 normal invoices and pay 5 of them.
|
||||
normalInvoices := make([]*lnrpc.AddInvoiceResponse, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
invoice := &lnrpc.Invoice{
|
||||
Value: int64(1000 + i*100), // Varying amounts
|
||||
IsAmp: false,
|
||||
}
|
||||
|
||||
resp := bob.RPC.AddInvoice(invoice)
|
||||
normalInvoices[i] = resp
|
||||
}
|
||||
|
||||
for _, inv := range normalInvoices {
|
||||
sendReq := &routerrpc.SendPaymentRequest{
|
||||
PaymentRequest: inv.PaymentRequest,
|
||||
TimeoutSeconds: 60,
|
||||
FeeLimitMsat: noFeeLimitMsat,
|
||||
}
|
||||
|
||||
ht.SendPaymentAssertSettled(alice, sendReq)
|
||||
}
|
||||
|
||||
// Step 2: Add 10 AMP invoices and send multiple payments to 5 of them.
|
||||
ampInvoices := make([]*lnrpc.AddInvoiceResponse, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
invoice := &lnrpc.Invoice{
|
||||
Value: int64(2000 + i*200), // Varying amounts
|
||||
IsAmp: true,
|
||||
}
|
||||
|
||||
resp := bob.RPC.AddInvoice(invoice)
|
||||
ampInvoices[i] = resp
|
||||
}
|
||||
|
||||
// Select the first 5 invoices to send multiple AMP payments.
|
||||
for i := 0; i < 5; i++ {
|
||||
inv := ampInvoices[i]
|
||||
|
||||
// Send 3 payments to each.
|
||||
for j := 0; j < 3; j++ {
|
||||
payReq := &routerrpc.SendPaymentRequest{
|
||||
PaymentRequest: inv.PaymentRequest,
|
||||
TimeoutSeconds: 60,
|
||||
FeeLimitMsat: noFeeLimitMsat,
|
||||
Amp: true,
|
||||
}
|
||||
|
||||
// Send a normal AMP payment first, then a spontaneous
|
||||
// AMP payment.
|
||||
ht.SendPaymentAssertSettled(alice, payReq)
|
||||
|
||||
// Generate an external payment address when attempting
|
||||
// to pseudo-reuse an AMP invoice. When using an
|
||||
// external payment address, we'll also expect an extra
|
||||
// invoice to appear in the ListInvoices response, since
|
||||
// a new invoice will be JIT inserted under a different
|
||||
// payment address than the one in the invoice.
|
||||
//
|
||||
// NOTE: This will only work when the peer has
|
||||
// spontaneous AMP payments enabled otherwise no invoice
|
||||
// under a different payment_addr will be found.
|
||||
payReq.PaymentAddr = ht.Random32Bytes()
|
||||
ht.SendPaymentAssertSettled(alice, payReq)
|
||||
}
|
||||
}
|
||||
|
||||
// We can close the channel now.
|
||||
ht.CloseChannel(alice, cp)
|
||||
|
||||
// Now stop Bob so we can open the DB for examination.
|
||||
require.NoError(ht, bob.Stop())
|
||||
|
||||
// Open the KV channel DB.
|
||||
db := openChannelDB(ht, bob)
|
||||
|
||||
query := invoices.InvoiceQuery{
|
||||
IndexOffset: 0,
|
||||
// As a sanity check, fetch more invoices than we have
|
||||
// to ensure that we did not add any extra invoices.
|
||||
NumMaxInvoices: 9999,
|
||||
}
|
||||
|
||||
// Fetch all invoices and make sure we have 35 in total.
|
||||
result1, err := db.QueryInvoices(ht.Context(), query)
|
||||
require.NoError(ht, err)
|
||||
require.Equal(ht, 35, len(result1.Invoices))
|
||||
numInvoices := len(result1.Invoices)
|
||||
|
||||
bob.SetExtraArgs([]string{"--db.use-native-sql"})
|
||||
|
||||
// Now run the migration flow three times to ensure that each run is
|
||||
// idempotent.
|
||||
for i := 0; i < 3; i++ {
|
||||
// Start bob with the native SQL flag set. This will trigger the
|
||||
// migration to run.
|
||||
require.NoError(ht, bob.Start(ht.Context()))
|
||||
|
||||
// At this point the migration should have completed and the
|
||||
// node should be running with native SQL. Now we'll stop Bob
|
||||
// again so we can safely examine the database.
|
||||
require.NoError(ht, bob.Stop())
|
||||
|
||||
// Now we'll open the database with the native SQL backend and
|
||||
// fetch the invoices again to ensure that they were migrated
|
||||
// correctly.
|
||||
sqlInvoiceDB := openNativeSQLInvoiceDB(ht, bob)
|
||||
result2, err := sqlInvoiceDB.QueryInvoices(ht.Context(), query)
|
||||
require.NoError(ht, err)
|
||||
|
||||
require.Equal(ht, numInvoices, len(result2.Invoices))
|
||||
|
||||
// Simply zero out the add index so we don't fail on that when
|
||||
// comparing.
|
||||
for i := 0; i < numInvoices; i++ {
|
||||
result1.Invoices[i].AddIndex = 0
|
||||
result2.Invoices[i].AddIndex = 0
|
||||
|
||||
// Clamp the precision to microseconds. Note that we
|
||||
// need to override both invoices as the original
|
||||
// invoice is coming from KV database, it was stored as
|
||||
// a binary serialized Go time.Time value which has
|
||||
// nanosecond precision. The migrated invoice is stored
|
||||
// in SQL in PostgreSQL has microsecond precision while
|
||||
// in SQLite it has nanosecond precision if using TEXT
|
||||
// storage class.
|
||||
clampTime(&result1.Invoices[i])
|
||||
clampTime(&result2.Invoices[i])
|
||||
require.Equal(
|
||||
ht, result1.Invoices[i], result2.Invoices[i],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Start Bob again so the test can complete.
|
||||
require.NoError(ht, bob.Start(ht.Context()))
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
package itest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
@ -1243,44 +1242,6 @@ func testSignVerifyMessageWithAddr(ht *lntest.HarnessTest) {
|
||||
require.False(ht, respValid.Valid, "external signature did validate")
|
||||
}
|
||||
|
||||
// testNativeSQLNoMigration tests that nodes that have invoices would not start
|
||||
// up with native SQL enabled, as we don't currently support migration of KV
|
||||
// invoices to the new SQL schema.
|
||||
func testNativeSQLNoMigration(ht *lntest.HarnessTest) {
|
||||
alice := ht.NewNode("Alice", nil)
|
||||
|
||||
// Make sure we run the test with SQLite or Postgres.
|
||||
if alice.Cfg.DBBackend != node.BackendSqlite &&
|
||||
alice.Cfg.DBBackend != node.BackendPostgres {
|
||||
|
||||
ht.Skip("node not running with SQLite or Postgres")
|
||||
}
|
||||
|
||||
// Skip the test if the node is already running with native SQL.
|
||||
if alice.Cfg.NativeSQL {
|
||||
ht.Skip("node already running with native SQL")
|
||||
}
|
||||
|
||||
alice.RPC.AddInvoice(&lnrpc.Invoice{
|
||||
Value: 10_000,
|
||||
})
|
||||
|
||||
alice.SetExtraArgs([]string{"--db.use-native-sql"})
|
||||
|
||||
// Restart the node manually as we're really only interested in the
|
||||
// startup error.
|
||||
require.NoError(ht, alice.Stop())
|
||||
require.NoError(ht, alice.StartLndCmd(context.Background()))
|
||||
|
||||
// We expect the node to fail to start up with native SQL enabled, as we
|
||||
// have an invoice in the KV store.
|
||||
require.Error(ht, alice.WaitForProcessExit())
|
||||
|
||||
// Reset the extra args and restart alice.
|
||||
alice.SetExtraArgs(nil)
|
||||
require.NoError(ht, alice.Start(ht.Context()))
|
||||
}
|
||||
|
||||
// testSendSelectedCoins tests that we're able to properly send the selected
|
||||
// coins from the wallet to a single target address.
|
||||
func testSendSelectedCoins(ht *lntest.HarnessTest) {
|
||||
|
19
lncfg/db.go
19
lncfg/db.go
@ -87,6 +87,8 @@ type DB struct {
|
||||
|
||||
UseNativeSQL bool `long:"use-native-sql" description:"Use native SQL for tables that already support it."`
|
||||
|
||||
SkipSQLInvoiceMigration bool `long:"skip-sql-invoice-migration" description:"Do not migrate invoices stored in our key-value database to native SQL."`
|
||||
|
||||
NoGraphCache bool `long:"no-graph-cache" description:"Don't use the in-memory graph cache for path finding. Much slower but uses less RAM. Can only be used with a bolt database backend."`
|
||||
|
||||
PruneRevocation bool `long:"prune-revocation" description:"Run the optional migration that prunes the revocation logs to save disk space."`
|
||||
@ -116,6 +118,7 @@ func DefaultDB() *DB {
|
||||
BusyTimeout: defaultSqliteBusyTimeout,
|
||||
},
|
||||
UseNativeSQL: false,
|
||||
SkipSQLInvoiceMigration: false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -231,10 +234,10 @@ type DatabaseBackends struct {
|
||||
// the underlying wallet database from.
|
||||
WalletDB btcwallet.LoaderOption
|
||||
|
||||
// NativeSQLStore is a pointer to a native SQL store that can be used
|
||||
// for native SQL queries for tables that already support it. This may
|
||||
// be nil if the use-native-sql flag was not set.
|
||||
NativeSQLStore *sqldb.BaseDB
|
||||
// NativeSQLStore holds a reference to the native SQL store that can
|
||||
// be used for native SQL queries for tables that already support it.
|
||||
// This may be nil if the use-native-sql flag was not set.
|
||||
NativeSQLStore sqldb.DB
|
||||
|
||||
// Remote indicates whether the database backends are remote, possibly
|
||||
// replicated instances or local bbolt or sqlite backed databases.
|
||||
@ -449,7 +452,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath,
|
||||
}
|
||||
closeFuncs[NSWalletDB] = postgresWalletBackend.Close
|
||||
|
||||
var nativeSQLStore *sqldb.BaseDB
|
||||
var nativeSQLStore sqldb.DB
|
||||
if db.UseNativeSQL {
|
||||
nativePostgresStore, err := sqldb.NewPostgresStore(
|
||||
db.Postgres,
|
||||
@ -459,7 +462,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath,
|
||||
"native postgres store: %v", err)
|
||||
}
|
||||
|
||||
nativeSQLStore = nativePostgresStore.BaseDB
|
||||
nativeSQLStore = nativePostgresStore
|
||||
closeFuncs[PostgresBackend] = nativePostgresStore.Close
|
||||
}
|
||||
|
||||
@ -571,7 +574,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath,
|
||||
}
|
||||
closeFuncs[NSWalletDB] = sqliteWalletBackend.Close
|
||||
|
||||
var nativeSQLStore *sqldb.BaseDB
|
||||
var nativeSQLStore sqldb.DB
|
||||
if db.UseNativeSQL {
|
||||
nativeSQLiteStore, err := sqldb.NewSqliteStore(
|
||||
db.Sqlite,
|
||||
@ -582,7 +585,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath,
|
||||
"native SQLite store: %v", err)
|
||||
}
|
||||
|
||||
nativeSQLStore = nativeSQLiteStore.BaseDB
|
||||
nativeSQLStore = nativeSQLiteStore
|
||||
closeFuncs[SqliteBackend] = nativeSQLiteStore.Close
|
||||
}
|
||||
|
||||
|
@ -1472,6 +1472,9 @@
|
||||
; own risk.
|
||||
; db.use-native-sql=false
|
||||
|
||||
; If set to true, native SQL invoice migration will be skipped. Note that this
|
||||
; option is intended for users who experience non-resolvable migration errors.
|
||||
; db.skip-sql-invoice-migration=false
|
||||
|
||||
[etcd]
|
||||
|
||||
|
@ -2,12 +2,40 @@
|
||||
|
||||
set -e
|
||||
|
||||
# restore_files is a function to restore original schema files.
|
||||
restore_files() {
|
||||
echo "Restoring SQLite bigint patch..."
|
||||
for file in sqldb/sqlc/migrations/*.up.sql.bak; do
|
||||
mv "$file" "${file%.bak}"
|
||||
done
|
||||
}
|
||||
|
||||
|
||||
# Set trap to call restore_files on script exit. This makes sure the old files
|
||||
# are always restored.
|
||||
trap restore_files EXIT
|
||||
|
||||
# Directory of the script file, independent of where it's called from.
|
||||
DIR="$(cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd)"
|
||||
# Use the user's cache directories
|
||||
GOCACHE=`go env GOCACHE`
|
||||
GOMODCACHE=`go env GOMODCACHE`
|
||||
|
||||
# SQLite doesn't support "BIGINT PRIMARY KEY" for auto-incrementing primary
|
||||
# keys, only "INTEGER PRIMARY KEY". Internally it uses 64-bit integers for
|
||||
# numbers anyway, independent of the column type. So we can just use
|
||||
# "INTEGER PRIMARY KEY" and it will work the same under the hood, giving us
|
||||
# auto incrementing 64-bit integers.
|
||||
# _BUT_, sqlc will generate Go code with int32 if we use "INTEGER PRIMARY KEY",
|
||||
# even though we want int64. So before we run sqlc, we need to patch the
|
||||
# source schema SQL files to use "BIGINT PRIMARY KEY" instead of "INTEGER
|
||||
# PRIMARY KEY".
|
||||
echo "Applying SQLite bigint patch..."
|
||||
for file in sqldb/sqlc/migrations/*.up.sql; do
|
||||
echo "Patching $file"
|
||||
sed -i.bak -E 's/INTEGER PRIMARY KEY/BIGINT PRIMARY KEY/g' "$file"
|
||||
done
|
||||
|
||||
echo "Generating sql models and queries in go..."
|
||||
|
||||
docker run \
|
||||
|
@ -355,6 +355,18 @@ func (t *TransactionExecutor[Q]) ExecTx(ctx context.Context,
|
||||
)
|
||||
}
|
||||
|
||||
// DB is an interface that represents a generic SQL database. It provides
|
||||
// methods to apply migrations and access the underlying database connection.
|
||||
type DB interface {
|
||||
// GetBaseDB returns the underlying BaseDB instance.
|
||||
GetBaseDB() *BaseDB
|
||||
|
||||
// ApplyAllMigrations applies all migrations to the database including
|
||||
// both sqlc and custom in-code migrations.
|
||||
ApplyAllMigrations(ctx context.Context,
|
||||
customMigrations []MigrationConfig) error
|
||||
}
|
||||
|
||||
// BaseDB is the base database struct that each implementation can embed to
|
||||
// gain some common functionality.
|
||||
type BaseDB struct {
|
||||
|
@ -2,22 +2,118 @@ package sqldb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btclog/v2"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
"github.com/golang-migrate/migrate/v4/source/httpfs"
|
||||
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||
)
|
||||
|
||||
var (
|
||||
// migrationConfig defines a list of migrations to be applied to the
|
||||
// database. Each migration is assigned a version number, determining
|
||||
// its execution order.
|
||||
// The schema version, tracked by golang-migrate, ensures migrations are
|
||||
// applied to the correct schema. For migrations involving only schema
|
||||
// changes, the migration function can be left nil. For custom
|
||||
// migrations an implemented migration function is required.
|
||||
//
|
||||
// NOTE: The migration function may have runtime dependencies, which
|
||||
// must be injected during runtime.
|
||||
migrationConfig = []MigrationConfig{
|
||||
{
|
||||
Name: "000001_invoices",
|
||||
Version: 1,
|
||||
SchemaVersion: 1,
|
||||
},
|
||||
{
|
||||
Name: "000002_amp_invoices",
|
||||
Version: 2,
|
||||
SchemaVersion: 2,
|
||||
},
|
||||
{
|
||||
Name: "000003_invoice_events",
|
||||
Version: 3,
|
||||
SchemaVersion: 3,
|
||||
},
|
||||
{
|
||||
Name: "000004_invoice_expiry_fix",
|
||||
Version: 4,
|
||||
SchemaVersion: 4,
|
||||
},
|
||||
{
|
||||
Name: "000005_migration_tracker",
|
||||
Version: 5,
|
||||
SchemaVersion: 5,
|
||||
},
|
||||
{
|
||||
Name: "000006_invoice_migration",
|
||||
Version: 6,
|
||||
SchemaVersion: 6,
|
||||
},
|
||||
{
|
||||
Name: "kv_invoice_migration",
|
||||
Version: 7,
|
||||
SchemaVersion: 6,
|
||||
// A migration function is may be attached to this
|
||||
// migration to migrate KV invoices to the native SQL
|
||||
// schema. This is optional and can be disabled by the
|
||||
// user if necessary.
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// MigrationConfig is a configuration struct that describes SQL migrations. Each
|
||||
// migration is associated with a specific schema version and a global database
|
||||
// version. Migrations are applied in the order of their global database
|
||||
// version. If a migration includes a non-nil MigrationFn, it is executed after
|
||||
// the SQL schema has been migrated to the corresponding schema version.
|
||||
type MigrationConfig struct {
|
||||
// Name is the name of the migration.
|
||||
Name string
|
||||
|
||||
// Version represents the "global" database version for this migration.
|
||||
// Unlike the schema version tracked by golang-migrate, it encompasses
|
||||
// all migrations, including those managed by golang-migrate as well
|
||||
// as custom in-code migrations.
|
||||
Version int
|
||||
|
||||
// SchemaVersion represents the schema version tracked by golang-migrate
|
||||
// at which the migration is applied.
|
||||
SchemaVersion int
|
||||
|
||||
// MigrationFn is the function executed for custom migrations at the
|
||||
// specified version. It is used to handle migrations that cannot be
|
||||
// performed through SQL alone. If set to nil, no custom migration is
|
||||
// applied.
|
||||
MigrationFn func(tx *sqlc.Queries) error
|
||||
}
|
||||
|
||||
// MigrationTarget is a functional option that can be passed to applyMigrations
|
||||
// to specify a target version to migrate to.
|
||||
type MigrationTarget func(mig *migrate.Migrate) error
|
||||
|
||||
// MigrationExecutor is an interface that abstracts the migration functionality.
|
||||
type MigrationExecutor interface {
|
||||
// ExecuteMigrations runs database migrations up to the specified target
|
||||
// version or all migrations if no target is specified. A migration may
|
||||
// include a schema change, a custom migration function, or both.
|
||||
// Developers must ensure that migrations are defined in the correct
|
||||
// order. Migration details are stored in the global variable
|
||||
// migrationConfig.
|
||||
ExecuteMigrations(target MigrationTarget) error
|
||||
}
|
||||
|
||||
var (
|
||||
// TargetLatest is a MigrationTarget that migrates to the latest
|
||||
// version available.
|
||||
@ -34,6 +130,14 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// GetMigrations returns a copy of the migration configuration.
|
||||
func GetMigrations() []MigrationConfig {
|
||||
migrations := make([]MigrationConfig, len(migrationConfig))
|
||||
copy(migrations, migrationConfig)
|
||||
|
||||
return migrations
|
||||
}
|
||||
|
||||
// migrationLogger is a logger that wraps the passed btclog.Logger so it can be
|
||||
// used to log migrations.
|
||||
type migrationLogger struct {
|
||||
@ -216,3 +320,117 @@ func (t *replacerFile) Close() error {
|
||||
// instance, so there's nothing to do for us here.
|
||||
return nil
|
||||
}
|
||||
|
||||
// MigrationTxOptions is the implementation of the TxOptions interface for
|
||||
// migration transactions.
|
||||
type MigrationTxOptions struct {
|
||||
}
|
||||
|
||||
// ReadOnly returns false to indicate that migration transactions are not read
|
||||
// only.
|
||||
func (m *MigrationTxOptions) ReadOnly() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ApplyMigrations applies the provided migrations to the database in sequence.
|
||||
// It ensures migrations are executed in the correct order, applying both custom
|
||||
// migration functions and SQL migrations as needed.
|
||||
func ApplyMigrations(ctx context.Context, db *BaseDB,
|
||||
migrator MigrationExecutor, migrations []MigrationConfig) error {
|
||||
|
||||
// Ensure that the migrations are sorted by version.
|
||||
for i := 0; i < len(migrations); i++ {
|
||||
if migrations[i].Version != i+1 {
|
||||
return fmt.Errorf("migration version %d is out of "+
|
||||
"order. Expected %d", migrations[i].Version,
|
||||
i+1)
|
||||
}
|
||||
}
|
||||
// Construct a transaction executor to apply custom migrations.
|
||||
executor := NewTransactionExecutor(db, func(tx *sql.Tx) *sqlc.Queries {
|
||||
return db.WithTx(tx)
|
||||
})
|
||||
|
||||
currentVersion := 0
|
||||
version, err := db.GetDatabaseVersion(ctx)
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting current database "+
|
||||
"version: %w", err)
|
||||
}
|
||||
|
||||
currentVersion = int(version)
|
||||
}
|
||||
|
||||
for _, migration := range migrations {
|
||||
if migration.Version <= currentVersion {
|
||||
log.Infof("Skipping migration '%s' (version %d) as it "+
|
||||
"has already been applied", migration.Name,
|
||||
migration.Version)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("Migrating SQL schema to version %d",
|
||||
migration.SchemaVersion)
|
||||
|
||||
// Execute SQL schema migrations up to the target version.
|
||||
err = migrator.ExecuteMigrations(
|
||||
TargetVersion(uint(migration.SchemaVersion)),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error executing schema migrations "+
|
||||
"to target version %d: %w",
|
||||
migration.SchemaVersion, err)
|
||||
}
|
||||
|
||||
var opts MigrationTxOptions
|
||||
|
||||
// Run the custom migration as a transaction to ensure
|
||||
// atomicity. If successful, mark the migration as complete in
|
||||
// the migration tracker table.
|
||||
err = executor.ExecTx(ctx, &opts, func(tx *sqlc.Queries) error {
|
||||
// Apply the migration function if one is provided.
|
||||
if migration.MigrationFn != nil {
|
||||
log.Infof("Applying custom migration '%v' "+
|
||||
"(version %d) to schema version %d",
|
||||
migration.Name, migration.Version,
|
||||
migration.SchemaVersion)
|
||||
|
||||
err = migration.MigrationFn(tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error applying "+
|
||||
"migration '%v' (version %d) "+
|
||||
"to schema version %d: %w",
|
||||
migration.Name,
|
||||
migration.Version,
|
||||
migration.SchemaVersion, err)
|
||||
}
|
||||
|
||||
log.Infof("Migration '%v' (version %d) "+
|
||||
"applied ", migration.Name,
|
||||
migration.Version)
|
||||
}
|
||||
|
||||
// Mark the migration as complete by adding the version
|
||||
// to the migration tracker table along with the current
|
||||
// timestamp.
|
||||
err = tx.SetMigration(ctx, sqlc.SetMigrationParams{
|
||||
Version: int32(migration.Version),
|
||||
MigrationTime: time.Now(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error setting migration "+
|
||||
"version %d: %w", migration.Version,
|
||||
err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -2,8 +2,15 @@ package sqldb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
pgx_migrate "github.com/golang-migrate/migrate/v4/database/pgx/v5"
|
||||
sqlite_migrate "github.com/golang-migrate/migrate/v4/database/sqlite"
|
||||
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -152,3 +159,296 @@ func testInvoiceExpiryMigration(t *testing.T, makeDB makeMigrationTestDB) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected, invoices)
|
||||
}
|
||||
|
||||
// TestCustomMigration tests that a custom in-code migrations are correctly
|
||||
// executed during the migration process.
|
||||
func TestCustomMigration(t *testing.T) {
|
||||
var customMigrationLog []string
|
||||
|
||||
logMigration := func(name string) {
|
||||
customMigrationLog = append(customMigrationLog, name)
|
||||
}
|
||||
|
||||
// Some migrations to use for both the failure and success tests. Note
|
||||
// that the migrations are not in order to test that they are executed
|
||||
// in the correct order.
|
||||
migrations := []MigrationConfig{
|
||||
{
|
||||
Name: "1",
|
||||
Version: 1,
|
||||
SchemaVersion: 1,
|
||||
MigrationFn: func(*sqlc.Queries) error {
|
||||
logMigration("1")
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "2",
|
||||
Version: 2,
|
||||
SchemaVersion: 1,
|
||||
MigrationFn: func(*sqlc.Queries) error {
|
||||
logMigration("2")
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "3",
|
||||
Version: 3,
|
||||
SchemaVersion: 2,
|
||||
MigrationFn: func(*sqlc.Queries) error {
|
||||
logMigration("3")
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
migrations []MigrationConfig
|
||||
expectedSuccess bool
|
||||
expectedMigrationLog []string
|
||||
expectedSchemaVersion int
|
||||
expectedVersion int
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
migrations: migrations,
|
||||
expectedSuccess: true,
|
||||
expectedMigrationLog: []string{"1", "2", "3"},
|
||||
expectedSchemaVersion: 2,
|
||||
expectedVersion: 3,
|
||||
},
|
||||
{
|
||||
name: "unordered migrations",
|
||||
migrations: append([]MigrationConfig{
|
||||
{
|
||||
Name: "4",
|
||||
Version: 4,
|
||||
SchemaVersion: 3,
|
||||
MigrationFn: func(*sqlc.Queries) error {
|
||||
logMigration("4")
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}, migrations...),
|
||||
expectedSuccess: false,
|
||||
expectedMigrationLog: nil,
|
||||
expectedSchemaVersion: 0,
|
||||
},
|
||||
{
|
||||
name: "failure of migration 4",
|
||||
migrations: append(migrations, MigrationConfig{
|
||||
Name: "4",
|
||||
Version: 4,
|
||||
SchemaVersion: 3,
|
||||
MigrationFn: func(*sqlc.Queries) error {
|
||||
return fmt.Errorf("migration 4 failed")
|
||||
},
|
||||
}),
|
||||
expectedSuccess: false,
|
||||
expectedMigrationLog: []string{"1", "2", "3"},
|
||||
// Since schema migration is a separate step we expect
|
||||
// that migrating up to 3 succeeded.
|
||||
expectedSchemaVersion: 3,
|
||||
// We still remain on version 3 though.
|
||||
expectedVersion: 3,
|
||||
},
|
||||
{
|
||||
name: "success of migration 4",
|
||||
migrations: append(migrations, MigrationConfig{
|
||||
Name: "4",
|
||||
Version: 4,
|
||||
SchemaVersion: 3,
|
||||
MigrationFn: func(*sqlc.Queries) error {
|
||||
logMigration("4")
|
||||
|
||||
return nil
|
||||
},
|
||||
}),
|
||||
expectedSuccess: true,
|
||||
expectedMigrationLog: []string{"1", "2", "3", "4"},
|
||||
expectedSchemaVersion: 3,
|
||||
expectedVersion: 4,
|
||||
},
|
||||
}
|
||||
|
||||
ctxb := context.Background()
|
||||
for _, test := range tests {
|
||||
// checkSchemaVersion checks the database schema version against
|
||||
// the expected version.
|
||||
getSchemaVersion := func(t *testing.T,
|
||||
driver database.Driver, dbName string) int {
|
||||
|
||||
sqlMigrate, err := migrate.NewWithInstance(
|
||||
"migrations", nil, dbName, driver,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
version, _, err := sqlMigrate.Version()
|
||||
if err != migrate.ErrNilVersion {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
return int(version)
|
||||
}
|
||||
|
||||
t.Run("SQLite "+test.name, func(t *testing.T) {
|
||||
customMigrationLog = nil
|
||||
|
||||
// First instantiate the database and run the migrations
|
||||
// including the custom migrations.
|
||||
t.Logf("Creating new SQLite DB for testing migrations")
|
||||
|
||||
dbFileName := filepath.Join(t.TempDir(), "tmp.db")
|
||||
var (
|
||||
db *SqliteStore
|
||||
err error
|
||||
)
|
||||
|
||||
// Run the migration 3 times to test that the migrations
|
||||
// are idempotent.
|
||||
for i := 0; i < 3; i++ {
|
||||
db, err = NewSqliteStore(&SqliteConfig{
|
||||
SkipMigrations: false,
|
||||
}, dbFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
dbToCleanup := db.DB
|
||||
t.Cleanup(func() {
|
||||
require.NoError(
|
||||
t, dbToCleanup.Close(),
|
||||
)
|
||||
})
|
||||
|
||||
err = db.ApplyAllMigrations(
|
||||
ctxb, test.migrations,
|
||||
)
|
||||
if test.expectedSuccess {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
|
||||
// Also repoen the DB without migrations
|
||||
// so we can read versions.
|
||||
db, err = NewSqliteStore(&SqliteConfig{
|
||||
SkipMigrations: true,
|
||||
}, dbFileName)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t,
|
||||
test.expectedMigrationLog,
|
||||
customMigrationLog,
|
||||
)
|
||||
|
||||
// Create the migration executor to be able to
|
||||
// query the current schema version.
|
||||
driver, err := sqlite_migrate.WithInstance(
|
||||
db.DB, &sqlite_migrate.Config{},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(
|
||||
t, test.expectedSchemaVersion,
|
||||
getSchemaVersion(t, driver, ""),
|
||||
)
|
||||
|
||||
// Check the migraton version in the database.
|
||||
version, err := db.GetDatabaseVersion(ctxb)
|
||||
if test.expectedSchemaVersion != 0 {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Equal(t, sql.ErrNoRows, err)
|
||||
}
|
||||
|
||||
require.Equal(
|
||||
t, test.expectedVersion, int(version),
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Postgres "+test.name, func(t *testing.T) {
|
||||
customMigrationLog = nil
|
||||
|
||||
// First create a temporary Postgres database to run
|
||||
// the migrations on.
|
||||
fixture := NewTestPgFixture(
|
||||
t, DefaultPostgresFixtureLifetime,
|
||||
)
|
||||
t.Cleanup(func() {
|
||||
fixture.TearDown(t)
|
||||
})
|
||||
|
||||
dbName := randomDBName(t)
|
||||
|
||||
// Next instantiate the database and run the migrations
|
||||
// including the custom migrations.
|
||||
t.Logf("Creating new Postgres DB '%s' for testing "+
|
||||
"migrations", dbName)
|
||||
|
||||
_, err := fixture.db.ExecContext(
|
||||
context.Background(), "CREATE DATABASE "+dbName,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := fixture.GetConfig(dbName)
|
||||
var db *PostgresStore
|
||||
|
||||
// Run the migration 3 times to test that the migrations
|
||||
// are idempotent.
|
||||
for i := 0; i < 3; i++ {
|
||||
cfg.SkipMigrations = false
|
||||
db, err = NewPostgresStore(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.ApplyAllMigrations(
|
||||
ctxb, test.migrations,
|
||||
)
|
||||
if test.expectedSuccess {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
|
||||
// Also repoen the DB without migrations
|
||||
// so we can read versions.
|
||||
cfg.SkipMigrations = true
|
||||
db, err = NewPostgresStore(cfg)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t,
|
||||
test.expectedMigrationLog,
|
||||
customMigrationLog,
|
||||
)
|
||||
|
||||
// Create the migration executor to be able to
|
||||
// query the current version.
|
||||
driver, err := pgx_migrate.WithInstance(
|
||||
db.DB, &pgx_migrate.Config{},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(
|
||||
t, test.expectedSchemaVersion,
|
||||
getSchemaVersion(t, driver, ""),
|
||||
)
|
||||
|
||||
// Check the migraton version in the database.
|
||||
version, err := db.GetDatabaseVersion(ctxb)
|
||||
if test.expectedSchemaVersion != 0 {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Equal(t, sql.ErrNoRows, err)
|
||||
}
|
||||
|
||||
require.Equal(
|
||||
t, test.expectedVersion, int(version),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -2,7 +2,15 @@
|
||||
|
||||
package sqldb
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
|
||||
// Make sure SqliteStore implements the DB interface.
|
||||
_ DB = (*SqliteStore)(nil)
|
||||
)
|
||||
|
||||
// SqliteStore is a database store implementation that uses a sqlite backend.
|
||||
type SqliteStore struct {
|
||||
@ -16,3 +24,17 @@ type SqliteStore struct {
|
||||
func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
|
||||
return nil, fmt.Errorf("SQLite backend not supported in WebAssembly")
|
||||
}
|
||||
|
||||
// GetBaseDB returns the underlying BaseDB instance for the SQLite store.
|
||||
// It is a trivial helper method to comply with the sqldb.DB interface.
|
||||
func (s *SqliteStore) GetBaseDB() *BaseDB {
|
||||
return s.BaseDB
|
||||
}
|
||||
|
||||
// ApplyAllMigrations applies both the SQLC and custom in-code migrations to
|
||||
// the SQLite database.
|
||||
func (s *SqliteStore) ApplyAllMigrations(context.Context,
|
||||
[]MigrationConfig) error {
|
||||
|
||||
return fmt.Errorf("SQLite backend not supported in WebAssembly")
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
@ -28,10 +29,15 @@ var (
|
||||
// has some differences.
|
||||
postgresSchemaReplacements = map[string]string{
|
||||
"BLOB": "BYTEA",
|
||||
"INTEGER PRIMARY KEY": "SERIAL PRIMARY KEY",
|
||||
"BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
|
||||
"INTEGER PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
|
||||
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
|
||||
}
|
||||
|
||||
// Make sure PostgresStore implements the MigrationExecutor interface.
|
||||
_ MigrationExecutor = (*PostgresStore)(nil)
|
||||
|
||||
// Make sure PostgresStore implements the DB interface.
|
||||
_ DB = (*PostgresStore)(nil)
|
||||
)
|
||||
|
||||
// replacePasswordInDSN takes a DSN string and returns it with the password
|
||||
@ -92,40 +98,64 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) {
|
||||
}
|
||||
log.Infof("Using SQL database '%s'", sanitizedDSN)
|
||||
|
||||
rawDB, err := sql.Open("pgx", cfg.Dsn)
|
||||
db, err := sql.Open("pgx", cfg.Dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure the migration tracker table exists before running migrations.
|
||||
// This table tracks migration progress and ensures compatibility with
|
||||
// SQLC query generation. If the table is already created by an SQLC
|
||||
// migration, this operation becomes a no-op.
|
||||
migrationTrackerSQL := `
|
||||
CREATE TABLE IF NOT EXISTS migration_tracker (
|
||||
version INTEGER UNIQUE NOT NULL,
|
||||
migration_time TIMESTAMP NOT NULL
|
||||
);`
|
||||
|
||||
_, err = db.Exec(migrationTrackerSQL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating migration tracker: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
maxConns := defaultMaxConns
|
||||
if cfg.MaxConnections > 0 {
|
||||
maxConns = cfg.MaxConnections
|
||||
}
|
||||
|
||||
rawDB.SetMaxOpenConns(maxConns)
|
||||
rawDB.SetMaxIdleConns(maxConns)
|
||||
rawDB.SetConnMaxLifetime(connIdleLifetime)
|
||||
db.SetMaxOpenConns(maxConns)
|
||||
db.SetMaxIdleConns(maxConns)
|
||||
db.SetConnMaxLifetime(connIdleLifetime)
|
||||
|
||||
queries := sqlc.New(rawDB)
|
||||
queries := sqlc.New(db)
|
||||
|
||||
s := &PostgresStore{
|
||||
return &PostgresStore{
|
||||
cfg: cfg,
|
||||
BaseDB: &BaseDB{
|
||||
DB: rawDB,
|
||||
DB: db,
|
||||
Queries: queries,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetBaseDB returns the underlying BaseDB instance for the Postgres store.
|
||||
// It is a trivial helper method to comply with the sqldb.DB interface.
|
||||
func (s *PostgresStore) GetBaseDB() *BaseDB {
|
||||
return s.BaseDB
|
||||
}
|
||||
|
||||
// ApplyAllMigrations applies both the SQLC and custom in-code migrations to the
|
||||
// Postgres database.
|
||||
func (s *PostgresStore) ApplyAllMigrations(ctx context.Context,
|
||||
migrations []MigrationConfig) error {
|
||||
|
||||
// Execute migrations unless configured to skip them.
|
||||
if !cfg.SkipMigrations {
|
||||
err := s.ExecuteMigrations(TargetLatest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error executing migrations: %w",
|
||||
err)
|
||||
}
|
||||
if s.cfg.SkipMigrations {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s, nil
|
||||
return ApplyMigrations(ctx, s.BaseDB, s, migrations)
|
||||
}
|
||||
|
||||
// ExecuteMigrations runs migrations for the Postgres database, depending on the
|
||||
|
@ -59,7 +59,7 @@ func NewTestPgFixture(t *testing.T, expiry time.Duration) *TestPgFixture {
|
||||
"postgres",
|
||||
"-c", "log_statement=all",
|
||||
"-c", "log_destination=stderr",
|
||||
"-c", "max_connections=1000",
|
||||
"-c", "max_connections=5000",
|
||||
},
|
||||
}, func(config *docker.HostConfig) {
|
||||
// Set AutoRemove to true so that stopped container goes away
|
||||
@ -151,6 +151,10 @@ func NewTestPostgresDB(t *testing.T, fixture *TestPgFixture) *PostgresStore {
|
||||
store, err := NewPostgresStore(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, store.ApplyAllMigrations(
|
||||
context.Background(), GetMigrations()),
|
||||
)
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
|
@ -235,6 +235,35 @@ func (q *Queries) GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, err
|
||||
return invoice_id, err
|
||||
}
|
||||
|
||||
const insertAMPSubInvoice = `-- name: InsertAMPSubInvoice :exec
|
||||
INSERT INTO amp_sub_invoices (
|
||||
set_id, state, created_at, settled_at, settle_index, invoice_id
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6
|
||||
)
|
||||
`
|
||||
|
||||
type InsertAMPSubInvoiceParams struct {
|
||||
SetID []byte
|
||||
State int16
|
||||
CreatedAt time.Time
|
||||
SettledAt sql.NullTime
|
||||
SettleIndex sql.NullInt64
|
||||
InvoiceID int64
|
||||
}
|
||||
|
||||
func (q *Queries) InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error {
|
||||
_, err := q.db.ExecContext(ctx, insertAMPSubInvoice,
|
||||
arg.SetID,
|
||||
arg.State,
|
||||
arg.CreatedAt,
|
||||
arg.SettledAt,
|
||||
arg.SettleIndex,
|
||||
arg.InvoiceID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
const insertAMPSubInvoiceHTLC = `-- name: InsertAMPSubInvoiceHTLC :exec
|
||||
INSERT INTO amp_sub_invoice_htlcs (
|
||||
invoice_id, set_id, htlc_id, root_share, child_index, hash, preimage
|
||||
|
@ -11,6 +11,15 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const clearKVInvoiceHashIndex = `-- name: ClearKVInvoiceHashIndex :exec
|
||||
DELETE FROM invoice_payment_hashes
|
||||
`
|
||||
|
||||
func (q *Queries) ClearKVInvoiceHashIndex(ctx context.Context) error {
|
||||
_, err := q.db.ExecContext(ctx, clearKVInvoiceHashIndex)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteCanceledInvoices = `-- name: DeleteCanceledInvoices :execresult
|
||||
DELETE
|
||||
FROM invoices
|
||||
@ -182,11 +191,8 @@ WHERE (
|
||||
i.hash = $3 OR
|
||||
$3 IS NULL
|
||||
) AND (
|
||||
i.preimage = $4 OR
|
||||
i.payment_addr = $4 OR
|
||||
$4 IS NULL
|
||||
) AND (
|
||||
i.payment_addr = $5 OR
|
||||
$5 IS NULL
|
||||
)
|
||||
GROUP BY i.id
|
||||
LIMIT 2
|
||||
@ -196,7 +202,6 @@ type GetInvoiceParams struct {
|
||||
SetID []byte
|
||||
AddIndex sql.NullInt64
|
||||
Hash []byte
|
||||
Preimage []byte
|
||||
PaymentAddr []byte
|
||||
}
|
||||
|
||||
@ -208,7 +213,6 @@ func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoi
|
||||
arg.SetID,
|
||||
arg.AddIndex,
|
||||
arg.Hash,
|
||||
arg.Preimage,
|
||||
arg.PaymentAddr,
|
||||
)
|
||||
if err != nil {
|
||||
@ -251,6 +255,38 @@ func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoi
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getInvoiceByHash = `-- name: GetInvoiceByHash :one
|
||||
SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at
|
||||
FROM invoices i
|
||||
WHERE i.hash = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetInvoiceByHash(ctx context.Context, hash []byte) (Invoice, error) {
|
||||
row := q.db.QueryRowContext(ctx, getInvoiceByHash, hash)
|
||||
var i Invoice
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Hash,
|
||||
&i.Preimage,
|
||||
&i.SettleIndex,
|
||||
&i.SettledAt,
|
||||
&i.Memo,
|
||||
&i.AmountMsat,
|
||||
&i.CltvDelta,
|
||||
&i.Expiry,
|
||||
&i.PaymentAddr,
|
||||
&i.PaymentRequest,
|
||||
&i.PaymentRequestHash,
|
||||
&i.State,
|
||||
&i.AmountPaidMsat,
|
||||
&i.IsAmp,
|
||||
&i.IsHodl,
|
||||
&i.IsKeysend,
|
||||
&i.CreatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getInvoiceBySetID = `-- name: GetInvoiceBySetID :many
|
||||
SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at
|
||||
FROM invoices i
|
||||
@ -405,6 +441,19 @@ func (q *Queries) GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]Invoi
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getKVInvoicePaymentHashByAddIndex = `-- name: GetKVInvoicePaymentHashByAddIndex :one
|
||||
SELECT hash
|
||||
FROM invoice_payment_hashes
|
||||
WHERE add_index = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) ([]byte, error) {
|
||||
row := q.db.QueryRowContext(ctx, getKVInvoicePaymentHashByAddIndex, addIndex)
|
||||
var hash []byte
|
||||
err := row.Scan(&hash)
|
||||
return hash, err
|
||||
}
|
||||
|
||||
const insertInvoice = `-- name: InsertInvoice :one
|
||||
INSERT INTO invoices (
|
||||
hash, preimage, memo, amount_msat, cltv_delta, expiry, payment_addr,
|
||||
@ -533,6 +582,79 @@ func (q *Queries) InsertInvoiceHTLCCustomRecord(ctx context.Context, arg InsertI
|
||||
return err
|
||||
}
|
||||
|
||||
const insertKVInvoiceKeyAndAddIndex = `-- name: InsertKVInvoiceKeyAndAddIndex :exec
|
||||
INSERT INTO invoice_payment_hashes (
|
||||
id, add_index
|
||||
) VALUES (
|
||||
$1, $2
|
||||
)
|
||||
`
|
||||
|
||||
type InsertKVInvoiceKeyAndAddIndexParams struct {
|
||||
ID int64
|
||||
AddIndex int64
|
||||
}
|
||||
|
||||
func (q *Queries) InsertKVInvoiceKeyAndAddIndex(ctx context.Context, arg InsertKVInvoiceKeyAndAddIndexParams) error {
|
||||
_, err := q.db.ExecContext(ctx, insertKVInvoiceKeyAndAddIndex, arg.ID, arg.AddIndex)
|
||||
return err
|
||||
}
|
||||
|
||||
const insertMigratedInvoice = `-- name: InsertMigratedInvoice :one
|
||||
INSERT INTO invoices (
|
||||
hash, preimage, settle_index, settled_at, memo, amount_msat, cltv_delta,
|
||||
expiry, payment_addr, payment_request, payment_request_hash, state,
|
||||
amount_paid_msat, is_amp, is_hodl, is_keysend, created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17
|
||||
) RETURNING id
|
||||
`
|
||||
|
||||
type InsertMigratedInvoiceParams struct {
|
||||
Hash []byte
|
||||
Preimage []byte
|
||||
SettleIndex sql.NullInt64
|
||||
SettledAt sql.NullTime
|
||||
Memo sql.NullString
|
||||
AmountMsat int64
|
||||
CltvDelta sql.NullInt32
|
||||
Expiry int32
|
||||
PaymentAddr []byte
|
||||
PaymentRequest sql.NullString
|
||||
PaymentRequestHash []byte
|
||||
State int16
|
||||
AmountPaidMsat int64
|
||||
IsAmp bool
|
||||
IsHodl bool
|
||||
IsKeysend bool
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
func (q *Queries) InsertMigratedInvoice(ctx context.Context, arg InsertMigratedInvoiceParams) (int64, error) {
|
||||
row := q.db.QueryRowContext(ctx, insertMigratedInvoice,
|
||||
arg.Hash,
|
||||
arg.Preimage,
|
||||
arg.SettleIndex,
|
||||
arg.SettledAt,
|
||||
arg.Memo,
|
||||
arg.AmountMsat,
|
||||
arg.CltvDelta,
|
||||
arg.Expiry,
|
||||
arg.PaymentAddr,
|
||||
arg.PaymentRequest,
|
||||
arg.PaymentRequestHash,
|
||||
arg.State,
|
||||
arg.AmountPaidMsat,
|
||||
arg.IsAmp,
|
||||
arg.IsHodl,
|
||||
arg.IsKeysend,
|
||||
arg.CreatedAt,
|
||||
)
|
||||
var id int64
|
||||
err := row.Scan(&id)
|
||||
return id, err
|
||||
}
|
||||
|
||||
const nextInvoiceSettleIndex = `-- name: NextInvoiceSettleIndex :one
|
||||
UPDATE invoice_sequences SET current_value = current_value + 1
|
||||
WHERE name = 'settle_index'
|
||||
@ -546,6 +668,22 @@ func (q *Queries) NextInvoiceSettleIndex(ctx context.Context) (int64, error) {
|
||||
return current_value, err
|
||||
}
|
||||
|
||||
const setKVInvoicePaymentHash = `-- name: SetKVInvoicePaymentHash :exec
|
||||
UPDATE invoice_payment_hashes
|
||||
SET hash = $2
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
type SetKVInvoicePaymentHashParams struct {
|
||||
ID int64
|
||||
Hash []byte
|
||||
}
|
||||
|
||||
func (q *Queries) SetKVInvoicePaymentHash(ctx context.Context, arg SetKVInvoicePaymentHashParams) error {
|
||||
_, err := q.db.ExecContext(ctx, setKVInvoicePaymentHash, arg.ID, arg.Hash)
|
||||
return err
|
||||
}
|
||||
|
||||
const updateInvoiceAmountPaid = `-- name: UpdateInvoiceAmountPaid :execresult
|
||||
UPDATE invoices
|
||||
SET amount_paid_msat = $2
|
||||
|
60
sqldb/sqlc/migration.sql.go
Normal file
60
sqldb/sqlc/migration.sql.go
Normal file
@ -0,0 +1,60 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.25.0
|
||||
// source: migration.sql
|
||||
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
const getDatabaseVersion = `-- name: GetDatabaseVersion :one
|
||||
SELECT
|
||||
version
|
||||
FROM
|
||||
migration_tracker
|
||||
ORDER BY
|
||||
version DESC
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
func (q *Queries) GetDatabaseVersion(ctx context.Context) (int32, error) {
|
||||
row := q.db.QueryRowContext(ctx, getDatabaseVersion)
|
||||
var version int32
|
||||
err := row.Scan(&version)
|
||||
return version, err
|
||||
}
|
||||
|
||||
const getMigration = `-- name: GetMigration :one
|
||||
SELECT
|
||||
migration_time
|
||||
FROM
|
||||
migration_tracker
|
||||
WHERE
|
||||
version = $1
|
||||
`
|
||||
|
||||
func (q *Queries) GetMigration(ctx context.Context, version int32) (time.Time, error) {
|
||||
row := q.db.QueryRowContext(ctx, getMigration, version)
|
||||
var migration_time time.Time
|
||||
err := row.Scan(&migration_time)
|
||||
return migration_time, err
|
||||
}
|
||||
|
||||
const setMigration = `-- name: SetMigration :exec
|
||||
INSERT INTO
|
||||
migration_tracker (version, migration_time)
|
||||
VALUES ($1, $2)
|
||||
`
|
||||
|
||||
type SetMigrationParams struct {
|
||||
Version int32
|
||||
MigrationTime time.Time
|
||||
}
|
||||
|
||||
func (q *Queries) SetMigration(ctx context.Context, arg SetMigrationParams) error {
|
||||
_, err := q.db.ExecContext(ctx, setMigration, arg.Version, arg.MigrationTime)
|
||||
return err
|
||||
}
|
@ -11,7 +11,7 @@ INSERT INTO invoice_sequences(name, current_value) VALUES ('settle_index', 0);
|
||||
-- invoices table contains all the information shared by all the invoice types.
|
||||
CREATE TABLE IF NOT EXISTS invoices (
|
||||
-- The id of the invoice. Translates to the AddIndex.
|
||||
id BIGINT PRIMARY KEY,
|
||||
id INTEGER PRIMARY KEY,
|
||||
|
||||
-- The hash for this invoice. The invoice hash will always identify that
|
||||
-- invoice.
|
||||
@ -102,7 +102,7 @@ CREATE INDEX IF NOT EXISTS invoice_feature_invoice_id_idx ON invoice_features(in
|
||||
CREATE TABLE IF NOT EXISTS invoice_htlcs (
|
||||
-- The id for this htlc. Used in foreign keys instead of the
|
||||
-- htlc_id/chan_id combination.
|
||||
id BIGINT PRIMARY KEY,
|
||||
id INTEGER PRIMARY KEY,
|
||||
|
||||
-- Short chan id indicating the htlc's origin. uint64 stored as text.
|
||||
chan_id TEXT NOT NULL,
|
||||
|
@ -29,7 +29,7 @@ VALUES
|
||||
-- AMP sub invoices. This table can be used to create a historical view of what
|
||||
-- happened to the node's invoices.
|
||||
CREATE TABLE IF NOT EXISTS invoice_events (
|
||||
id BIGINT PRIMARY KEY,
|
||||
id INTEGER PRIMARY KEY,
|
||||
|
||||
-- added_at is the timestamp when this event was added.
|
||||
added_at TIMESTAMP NOT NULL,
|
||||
|
1
sqldb/sqlc/migrations/000005_migration_tracker.down.sql
Normal file
1
sqldb/sqlc/migrations/000005_migration_tracker.down.sql
Normal file
@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS migration_tracker;
|
17
sqldb/sqlc/migrations/000005_migration_tracker.up.sql
Normal file
17
sqldb/sqlc/migrations/000005_migration_tracker.up.sql
Normal file
@ -0,0 +1,17 @@
|
||||
-- The migration_tracker table keeps track of migrations that have been applied
|
||||
-- to the database. This table ensures that migrations are idempotent and are
|
||||
-- only run once. It tracks a global database version that encompasses both
|
||||
-- schema migrations handled by golang-migrate and custom in-code migrations
|
||||
-- for more complex data conversions that cannot be expressed in pure SQL.
|
||||
CREATE TABLE IF NOT EXISTS migration_tracker (
|
||||
-- version is the global version of the migration. Note that we
|
||||
-- intentionally don't set it as PRIMARY KEY as it'd auto increment on
|
||||
-- SQLite and our sqlc workflow will replace it with an auto
|
||||
-- incrementing SERIAL on Postgres too. UNIQUE achieves the same effect
|
||||
-- without the auto increment.
|
||||
version INTEGER UNIQUE NOT NULL,
|
||||
|
||||
-- migration_time is the timestamp at which the migration was run.
|
||||
migration_time TIMESTAMP NOT NULL
|
||||
);
|
||||
|
1
sqldb/sqlc/migrations/000006_invoice_migration.down.sql
Normal file
1
sqldb/sqlc/migrations/000006_invoice_migration.down.sql
Normal file
@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS invoice_payment_hashes;
|
17
sqldb/sqlc/migrations/000006_invoice_migration.up.sql
Normal file
17
sqldb/sqlc/migrations/000006_invoice_migration.up.sql
Normal file
@ -0,0 +1,17 @@
|
||||
-- invoice_payment_hashes table contains the hash of the invoices. This table
|
||||
-- is used during KV to SQL invoice migration as in our KV representation we
|
||||
-- don't have a mapping from hash to add index.
|
||||
CREATE TABLE IF NOT EXISTS invoice_payment_hashes (
|
||||
-- id represents is the key of the invoice in the KV store.
|
||||
id INTEGER PRIMARY KEY,
|
||||
|
||||
-- add_index is the KV add index of the invoice.
|
||||
add_index BIGINT NOT NULL,
|
||||
|
||||
-- hash is the payment hash for this invoice.
|
||||
hash BLOB
|
||||
);
|
||||
|
||||
-- Create an indexes on the add_index and hash columns to speed up lookups.
|
||||
CREATE INDEX IF NOT EXISTS invoice_payment_hashes_add_index_idx ON invoice_payment_hashes(add_index);
|
||||
CREATE INDEX IF NOT EXISTS invoice_payment_hashes_hash_idx ON invoice_payment_hashes(hash);
|
@ -58,7 +58,7 @@ type InvoiceEvent struct {
|
||||
}
|
||||
|
||||
type InvoiceEventType struct {
|
||||
ID int32
|
||||
ID int64
|
||||
Description string
|
||||
}
|
||||
|
||||
@ -87,7 +87,18 @@ type InvoiceHtlcCustomRecord struct {
|
||||
HtlcID int64
|
||||
}
|
||||
|
||||
type InvoicePaymentHash struct {
|
||||
ID int64
|
||||
AddIndex int64
|
||||
Hash []byte
|
||||
}
|
||||
|
||||
type InvoiceSequence struct {
|
||||
Name string
|
||||
CurrentValue int64
|
||||
}
|
||||
|
||||
type MigrationTracker struct {
|
||||
Version int32
|
||||
MigrationTime time.Time
|
||||
}
|
||||
|
@ -7,9 +7,11 @@ package sqlc
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Querier interface {
|
||||
ClearKVInvoiceHashIndex(ctx context.Context) error
|
||||
DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
|
||||
DeleteInvoice(ctx context.Context, arg DeleteInvoiceParams) (sql.Result, error)
|
||||
FetchAMPSubInvoiceHTLCs(ctx context.Context, arg FetchAMPSubInvoiceHTLCsParams) ([]FetchAMPSubInvoiceHTLCsRow, error)
|
||||
@ -17,19 +19,26 @@ type Querier interface {
|
||||
FetchSettledAMPSubInvoices(ctx context.Context, arg FetchSettledAMPSubInvoicesParams) ([]FetchSettledAMPSubInvoicesRow, error)
|
||||
FilterInvoices(ctx context.Context, arg FilterInvoicesParams) ([]Invoice, error)
|
||||
GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error)
|
||||
GetDatabaseVersion(ctx context.Context) (int32, error)
|
||||
// This method may return more than one invoice if filter using multiple fields
|
||||
// from different invoices. It is the caller's responsibility to ensure that
|
||||
// we bubble up an error in those cases.
|
||||
GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error)
|
||||
GetInvoiceByHash(ctx context.Context, hash []byte) (Invoice, error)
|
||||
GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error)
|
||||
GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]InvoiceFeature, error)
|
||||
GetInvoiceHTLCCustomRecords(ctx context.Context, invoiceID int64) ([]GetInvoiceHTLCCustomRecordsRow, error)
|
||||
GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]InvoiceHtlc, error)
|
||||
GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) ([]byte, error)
|
||||
GetMigration(ctx context.Context, version int32) (time.Time, error)
|
||||
InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error
|
||||
InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubInvoiceHTLCParams) error
|
||||
InsertInvoice(ctx context.Context, arg InsertInvoiceParams) (int64, error)
|
||||
InsertInvoiceFeature(ctx context.Context, arg InsertInvoiceFeatureParams) error
|
||||
InsertInvoiceHTLC(ctx context.Context, arg InsertInvoiceHTLCParams) (int64, error)
|
||||
InsertInvoiceHTLCCustomRecord(ctx context.Context, arg InsertInvoiceHTLCCustomRecordParams) error
|
||||
InsertKVInvoiceKeyAndAddIndex(ctx context.Context, arg InsertKVInvoiceKeyAndAddIndexParams) error
|
||||
InsertMigratedInvoice(ctx context.Context, arg InsertMigratedInvoiceParams) (int64, error)
|
||||
NextInvoiceSettleIndex(ctx context.Context) (int64, error)
|
||||
OnAMPSubInvoiceCanceled(ctx context.Context, arg OnAMPSubInvoiceCanceledParams) error
|
||||
OnAMPSubInvoiceCreated(ctx context.Context, arg OnAMPSubInvoiceCreatedParams) error
|
||||
@ -37,6 +46,8 @@ type Querier interface {
|
||||
OnInvoiceCanceled(ctx context.Context, arg OnInvoiceCanceledParams) error
|
||||
OnInvoiceCreated(ctx context.Context, arg OnInvoiceCreatedParams) error
|
||||
OnInvoiceSettled(ctx context.Context, arg OnInvoiceSettledParams) error
|
||||
SetKVInvoicePaymentHash(ctx context.Context, arg SetKVInvoicePaymentHashParams) error
|
||||
SetMigration(ctx context.Context, arg SetMigrationParams) error
|
||||
UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context, arg UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result, error)
|
||||
UpdateAMPSubInvoiceState(ctx context.Context, arg UpdateAMPSubInvoiceStateParams) error
|
||||
UpdateInvoiceAmountPaid(ctx context.Context, arg UpdateInvoiceAmountPaidParams) (sql.Result, error)
|
||||
|
@ -65,3 +65,11 @@ SET preimage = $5
|
||||
WHERE a.invoice_id = $1 AND a.set_id = $2 AND a.htlc_id = (
|
||||
SELECT id FROM invoice_htlcs AS i WHERE i.chan_id = $3 AND i.htlc_id = $4
|
||||
);
|
||||
|
||||
-- name: InsertAMPSubInvoice :exec
|
||||
INSERT INTO amp_sub_invoices (
|
||||
set_id, state, created_at, settled_at, settle_index, invoice_id
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6
|
||||
);
|
||||
|
||||
|
@ -7,6 +7,16 @@ INSERT INTO invoices (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15
|
||||
) RETURNING id;
|
||||
|
||||
-- name: InsertMigratedInvoice :one
|
||||
INSERT INTO invoices (
|
||||
hash, preimage, settle_index, settled_at, memo, amount_msat, cltv_delta,
|
||||
expiry, payment_addr, payment_request, payment_request_hash, state,
|
||||
amount_paid_msat, is_amp, is_hodl, is_keysend, created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17
|
||||
) RETURNING id;
|
||||
|
||||
|
||||
-- name: InsertInvoiceFeature :exec
|
||||
INSERT INTO invoice_features (
|
||||
invoice_id, feature
|
||||
@ -37,9 +47,6 @@ WHERE (
|
||||
) AND (
|
||||
i.hash = sqlc.narg('hash') OR
|
||||
sqlc.narg('hash') IS NULL
|
||||
) AND (
|
||||
i.preimage = sqlc.narg('preimage') OR
|
||||
sqlc.narg('preimage') IS NULL
|
||||
) AND (
|
||||
i.payment_addr = sqlc.narg('payment_addr') OR
|
||||
sqlc.narg('payment_addr') IS NULL
|
||||
@ -47,6 +54,11 @@ WHERE (
|
||||
GROUP BY i.id
|
||||
LIMIT 2;
|
||||
|
||||
-- name: GetInvoiceByHash :one
|
||||
SELECT i.*
|
||||
FROM invoices i
|
||||
WHERE i.hash = $1;
|
||||
|
||||
-- name: GetInvoiceBySetID :many
|
||||
SELECT i.*
|
||||
FROM invoices i
|
||||
@ -169,3 +181,23 @@ INSERT INTO invoice_htlc_custom_records (
|
||||
SELECT ihcr.htlc_id, key, value
|
||||
FROM invoice_htlcs ih JOIN invoice_htlc_custom_records ihcr ON ih.id=ihcr.htlc_id
|
||||
WHERE ih.invoice_id = $1;
|
||||
|
||||
-- name: InsertKVInvoiceKeyAndAddIndex :exec
|
||||
INSERT INTO invoice_payment_hashes (
|
||||
id, add_index
|
||||
) VALUES (
|
||||
$1, $2
|
||||
);
|
||||
|
||||
-- name: SetKVInvoicePaymentHash :exec
|
||||
UPDATE invoice_payment_hashes
|
||||
SET hash = $2
|
||||
WHERE id = $1;
|
||||
|
||||
-- name: GetKVInvoicePaymentHashByAddIndex :one
|
||||
SELECT hash
|
||||
FROM invoice_payment_hashes
|
||||
WHERE add_index = $1;
|
||||
|
||||
-- name: ClearKVInvoiceHashIndex :exec
|
||||
DELETE FROM invoice_payment_hashes;
|
||||
|
21
sqldb/sqlc/queries/migration.sql
Normal file
21
sqldb/sqlc/queries/migration.sql
Normal file
@ -0,0 +1,21 @@
|
||||
-- name: SetMigration :exec
|
||||
INSERT INTO
|
||||
migration_tracker (version, migration_time)
|
||||
VALUES ($1, $2);
|
||||
|
||||
-- name: GetMigration :one
|
||||
SELECT
|
||||
migration_time
|
||||
FROM
|
||||
migration_tracker
|
||||
WHERE
|
||||
version = $1;
|
||||
|
||||
-- name: GetDatabaseVersion :one
|
||||
SELECT
|
||||
version
|
||||
FROM
|
||||
migration_tracker
|
||||
ORDER BY
|
||||
version DESC
|
||||
LIMIT 1;
|
@ -3,6 +3,7 @@
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
@ -27,13 +28,16 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
// sqliteSchemaReplacements is a map of schema strings that need to be
|
||||
// replaced for sqlite. This is needed because sqlite doesn't directly
|
||||
// support the BIGINT type for primary keys, so we need to replace it
|
||||
// with INTEGER.
|
||||
sqliteSchemaReplacements = map[string]string{
|
||||
"BIGINT PRIMARY KEY": "INTEGER PRIMARY KEY",
|
||||
}
|
||||
// sqliteSchemaReplacements maps schema strings to their SQLite
|
||||
// compatible replacements. Currently, no replacements are needed as our
|
||||
// SQL schema definition files are designed for SQLite compatibility.
|
||||
sqliteSchemaReplacements = map[string]string{}
|
||||
|
||||
// Make sure SqliteStore implements the MigrationExecutor interface.
|
||||
_ MigrationExecutor = (*SqliteStore)(nil)
|
||||
|
||||
// Make sure SqliteStore implements the DB interface.
|
||||
_ DB = (*SqliteStore)(nil)
|
||||
)
|
||||
|
||||
// SqliteStore is a database store implementation that uses a sqlite backend.
|
||||
@ -102,6 +106,23 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create the migration tracker table before starting migrations to
|
||||
// ensure it can be used to track migration progress. Note that a
|
||||
// corresponding SQLC migration also creates this table, making this
|
||||
// operation a no-op in that context. Its purpose is to ensure
|
||||
// compatibility with SQLC query generation.
|
||||
migrationTrackerSQL := `
|
||||
CREATE TABLE IF NOT EXISTS migration_tracker (
|
||||
version INTEGER UNIQUE NOT NULL,
|
||||
migration_time TIMESTAMP NOT NULL
|
||||
);`
|
||||
|
||||
_, err = db.Exec(migrationTrackerSQL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating migration tracker: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(defaultMaxConns)
|
||||
db.SetMaxIdleConns(defaultMaxConns)
|
||||
db.SetConnMaxLifetime(connIdleLifetime)
|
||||
@ -115,18 +136,28 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
|
||||
},
|
||||
}
|
||||
|
||||
// Execute migrations unless configured to skip them.
|
||||
if !cfg.SkipMigrations {
|
||||
if err := s.ExecuteMigrations(TargetLatest); err != nil {
|
||||
return nil, fmt.Errorf("error executing migrations: "+
|
||||
"%w", err)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// GetBaseDB returns the underlying BaseDB instance for the SQLite store.
|
||||
// It is a trivial helper method to comply with the sqldb.DB interface.
|
||||
func (s *SqliteStore) GetBaseDB() *BaseDB {
|
||||
return s.BaseDB
|
||||
}
|
||||
|
||||
// ApplyAllMigrations applies both the SQLC and custom in-code migrations to the
|
||||
// SQLite database.
|
||||
func (s *SqliteStore) ApplyAllMigrations(ctx context.Context,
|
||||
migrations []MigrationConfig) error {
|
||||
|
||||
// Execute migrations unless configured to skip them.
|
||||
if s.cfg.SkipMigrations {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ApplyMigrations(ctx, s.BaseDB, s, migrations)
|
||||
}
|
||||
|
||||
// ExecuteMigrations runs migrations for the sqlite database, depending on the
|
||||
// target given, either all migrations or up to a given version.
|
||||
func (s *SqliteStore) ExecuteMigrations(target MigrationTarget) error {
|
||||
@ -160,6 +191,10 @@ func NewTestSqliteDB(t *testing.T) *SqliteStore {
|
||||
}, dbFileName)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, sqlDB.ApplyAllMigrations(
|
||||
context.Background(), GetMigrations()),
|
||||
)
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, sqlDB.DB.Close())
|
||||
})
|
||||
|
Loading…
x
Reference in New Issue
Block a user