Merge pull request #8831 from bhandras/sql-invoice-migration

invoices: migrate KV invoices to native SQL for users of KV SQL backends
This commit is contained in:
Oliver Gugger 2025-01-23 05:48:25 -06:00 committed by GitHub
commit 6cabc74c20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
37 changed files with 2762 additions and 210 deletions

View File

@ -51,6 +51,7 @@ import (
"github.com/lightningnetwork/lnd/rpcperms"
"github.com/lightningnetwork/lnd/signal"
"github.com/lightningnetwork/lnd/sqldb"
"github.com/lightningnetwork/lnd/sqldb/sqlc"
"github.com/lightningnetwork/lnd/sweep"
"github.com/lightningnetwork/lnd/walletunlocker"
"github.com/lightningnetwork/lnd/watchtower"
@ -60,6 +61,16 @@ import (
"gopkg.in/macaroon-bakery.v2/bakery"
)
const (
// invoiceMigrationBatchSize is the number of invoices that will be
// migrated in a single batch.
invoiceMigrationBatchSize = 1000
// invoiceMigration is the version of the migration that will be used to
// migrate invoices from the kvdb to the sql database.
invoiceMigration = 7
)
// GrpcRegistrar is an interface that must be satisfied by an external subserver
// that wants to be able to register its own gRPC server onto lnd's main
// grpc.Server instance.
@ -932,10 +943,10 @@ type DatabaseInstances struct {
// the btcwallet's loader.
WalletDB btcwallet.LoaderOption
// NativeSQLStore is a pointer to a native SQL store that can be used
// for native SQL queries for tables that already support it. This may
// be nil if the use-native-sql flag was not set.
NativeSQLStore *sqldb.BaseDB
// NativeSQLStore holds a reference to the native SQL store that can
// be used for native SQL queries for tables that already support it.
// This may be nil if the use-native-sql flag was not set.
NativeSQLStore sqldb.DB
}
// DefaultDatabaseBuilder is a type that builds the default database backends
@ -1038,7 +1049,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
if err != nil {
cleanUp()
err := fmt.Errorf("unable to open graph DB: %w", err)
err = fmt.Errorf("unable to open graph DB: %w", err)
d.logger.Error(err)
return nil, nil, err
@ -1072,51 +1083,69 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
case err != nil:
cleanUp()
err := fmt.Errorf("unable to open graph DB: %w", err)
err = fmt.Errorf("unable to open graph DB: %w", err)
d.logger.Error(err)
return nil, nil, err
}
// Instantiate a native SQL invoice store if the flag is set.
// Instantiate a native SQL store if the flag is set.
if d.cfg.DB.UseNativeSQL {
// KV invoice db resides in the same database as the channel
// state DB. Let's query the database to see if we have any
// invoices there. If we do, we won't allow the user to start
// lnd with native SQL enabled, as we don't currently migrate
// the invoices to the new database schema.
invoiceSlice, err := dbs.ChanStateDB.QueryInvoices(
ctx, invoices.InvoiceQuery{
NumMaxInvoices: 1,
},
)
if err != nil {
cleanUp()
d.logger.Errorf("Unable to query KV invoice DB: %v",
err)
migrations := sqldb.GetMigrations()
return nil, nil, err
// If the user has not explicitly disabled the SQL invoice
// migration, attach the custom migration function to invoice
// migration (version 7). Even if this custom migration is
// disabled, the regular native SQL store migrations will still
// run. If the database version is already above this custom
// migration's version (7), it will be skipped permanently,
// regardless of the flag.
if !d.cfg.DB.SkipSQLInvoiceMigration {
migrationFn := func(tx *sqlc.Queries) error {
return invoices.MigrateInvoicesToSQL(
ctx, dbs.ChanStateDB.Backend,
dbs.ChanStateDB, tx,
invoiceMigrationBatchSize,
)
}
// Make sure we attach the custom migration function to
// the correct migration version.
for i := 0; i < len(migrations); i++ {
if migrations[i].Version != invoiceMigration {
continue
}
migrations[i].MigrationFn = migrationFn
}
}
if len(invoiceSlice.Invoices) > 0 {
// We need to apply all migrations to the native SQL store
// before we can use it.
err = dbs.NativeSQLStore.ApplyAllMigrations(ctx, migrations)
if err != nil {
cleanUp()
err := fmt.Errorf("found invoices in the KV invoice " +
"DB, migration to native SQL is not yet " +
"supported")
err = fmt.Errorf("faild to run migrations for the "+
"native SQL store: %w", err)
d.logger.Error(err)
return nil, nil, err
}
// With the DB ready and migrations applied, we can now create
// the base DB and transaction executor for the native SQL
// invoice store.
baseDB := dbs.NativeSQLStore.GetBaseDB()
executor := sqldb.NewTransactionExecutor(
dbs.NativeSQLStore,
func(tx *sql.Tx) invoices.SQLInvoiceQueries {
return dbs.NativeSQLStore.WithTx(tx)
baseDB, func(tx *sql.Tx) invoices.SQLInvoiceQueries {
return baseDB.WithTx(tx)
},
)
dbs.InvoiceDB = invoices.NewSQLStore(
sqlInvoiceDB := invoices.NewSQLStore(
executor, clock.NewDefaultClock(),
)
dbs.InvoiceDB = sqlInvoiceDB
} else {
dbs.InvoiceDB = dbs.ChanStateDB
}
@ -1129,7 +1158,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
if err != nil {
cleanUp()
err := fmt.Errorf("unable to open %s database: %w",
err = fmt.Errorf("unable to open %s database: %w",
lncfg.NSTowerClientDB, err)
d.logger.Error(err)
return nil, nil, err
@ -1144,7 +1173,7 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
if err != nil {
cleanUp()
err := fmt.Errorf("unable to open %s database: %w",
err = fmt.Errorf("unable to open %s database: %w",
lncfg.NSTowerServerDB, err)
d.logger.Error(err)
return nil, nil, err

View File

@ -264,6 +264,11 @@ The underlying functionality between those two options remain the same.
transactions can run at once, increasing efficiency. Includes several bugfixes
to allow this to work properly.
* [Migrate KV invoices to
SQL](https://github.com/lightningnetwork/lnd/pull/8831) as part of a larger
effort to support SQL databases natively in LND.
## Code Health
* A code refactor that [moves all the graph related DB code out of the
@ -292,6 +297,7 @@ The underlying functionality between those two options remain the same.
* Abdullahi Yunus
* Alex Akselrod
* Andras Banki-Horvath
* Animesh Bilthare
* Boris Nagaev
* Carla Kirk-Cohen

6
go.mod
View File

@ -138,7 +138,7 @@ require (
github.com/opencontainers/image-spec v1.0.2 // indirect
github.com/opencontainers/runc v1.1.12 // indirect
github.com/ory/dockertest/v3 v3.10.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/pmezard/go-difflib v1.0.0
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/common v0.26.0 // indirect
github.com/prometheus/procfs v0.6.0 // indirect
@ -207,6 +207,10 @@ replace github.com/gogo/protobuf => github.com/gogo/protobuf v1.3.2
// allows us to specify that as an option.
replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-display v1.30.0-hex-display
// Temporary replace until https://github.com/lightningnetwork/lnd/pull/8831 is
// merged.
replace github.com/lightningnetwork/lnd/sqldb => ./sqldb
// If you change this please also update docs/INSTALL.md and GO_VERSION in
// Makefile (then run `make lint` to see where else it needs to be updated as
// well).

2
go.sum
View File

@ -464,8 +464,6 @@ github.com/lightningnetwork/lnd/kvdb v1.4.12 h1:Y0WY5Tbjyjn6eCYh068qkWur5oFtioJl
github.com/lightningnetwork/lnd/kvdb v1.4.12/go.mod h1:hx9buNcxsZpZwh8m1sjTQwy2SOeBoWWOZ3RnOQkMsxI=
github.com/lightningnetwork/lnd/queue v1.1.1 h1:99ovBlpM9B0FRCGYJo6RSFDlt8/vOkQQZznVb18iNMI=
github.com/lightningnetwork/lnd/queue v1.1.1/go.mod h1:7A6nC1Qrm32FHuhx/mi1cieAiBZo5O6l8IBIoQxvkz4=
github.com/lightningnetwork/lnd/sqldb v1.0.6 h1:LJdDSVdN33bVBIefsaJlPW9PDAm6GrXlyFucmzSJ3Ts=
github.com/lightningnetwork/lnd/sqldb v1.0.6/go.mod h1:OG09zL/PHPaBJefp4HsPz2YLUJ+zIQHbpgCtLnOx8I4=
github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM=
github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA=
github.com/lightningnetwork/lnd/tlv v1.3.0 h1:exS/KCPEgpOgviIttfiXAPaUqw2rHQrnUOpP7HPBPiY=

View File

@ -187,6 +187,11 @@ func (r InvoiceRef) Modifier() RefModifier {
return r.refModifier
}
// IsHashOnly returns true if the invoice ref only contains a payment hash.
func (r InvoiceRef) IsHashOnly() bool {
return r.payHash != nil && r.payAddr == nil && r.setID == nil
}
// String returns a human-readable representation of an InvoiceRef.
func (r InvoiceRef) String() string {
var ids []string

View File

@ -0,0 +1,203 @@
package invoices_test
import (
"context"
"database/sql"
"os"
"path"
"testing"
"time"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
invpkg "github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/kvdb/sqlbase"
"github.com/lightningnetwork/lnd/kvdb/sqlite"
"github.com/lightningnetwork/lnd/lncfg"
"github.com/lightningnetwork/lnd/sqldb"
"github.com/lightningnetwork/lnd/sqldb/sqlc"
"github.com/stretchr/testify/require"
)
// TestMigrationWithChannelDB tests the migration of invoices from a bolt backed
// channel.db to a SQL database. Note that this test does not attempt to be a
// complete migration test for all invoice types but rather is added as a tool
// for developers and users to debug invoice migration issues with an actual
// channel.db file.
func TestMigrationWithChannelDB(t *testing.T) {
// First create a shared Postgres instance so we don't spawn a new
// docker container for each test.
pgFixture := sqldb.NewTestPgFixture(
t, sqldb.DefaultPostgresFixtureLifetime,
)
t.Cleanup(func() {
pgFixture.TearDown(t)
})
makeSQLDB := func(t *testing.T, sqlite bool) (*invpkg.SQLStore,
*sqldb.TransactionExecutor[*sqlc.Queries]) {
var db *sqldb.BaseDB
if sqlite {
db = sqldb.NewTestSqliteDB(t).BaseDB
} else {
db = sqldb.NewTestPostgresDB(t, pgFixture).BaseDB
}
invoiceExecutor := sqldb.NewTransactionExecutor(
db, func(tx *sql.Tx) invpkg.SQLInvoiceQueries {
return db.WithTx(tx)
},
)
genericExecutor := sqldb.NewTransactionExecutor(
db, func(tx *sql.Tx) *sqlc.Queries {
return db.WithTx(tx)
},
)
testClock := clock.NewTestClock(time.Unix(1, 0))
return invpkg.NewSQLStore(invoiceExecutor, testClock),
genericExecutor
}
migrationTest := func(t *testing.T, kvStore *channeldb.DB,
sqlite bool) {
sqlInvoiceStore, sqlStore := makeSQLDB(t, sqlite)
ctxb := context.Background()
const batchSize = 11
var opts sqldb.MigrationTxOptions
err := sqlStore.ExecTx(
ctxb, &opts, func(tx *sqlc.Queries) error {
return invpkg.MigrateInvoicesToSQL(
ctxb, kvStore.Backend, kvStore, tx,
batchSize,
)
}, func() {},
)
require.NoError(t, err)
// MigrateInvoices will check if the inserted invoice equals to
// the migrated one, but as a sanity check, we'll also fetch the
// invoices from the store and compare them to the original
// invoices.
query := invpkg.InvoiceQuery{
IndexOffset: 0,
// As a sanity check, fetch more invoices than we have
// to ensure that we did not add any extra invoices.
// Note that we don't really have a way to know the
// exact number of invoices in the bolt db without first
// iterating over all of them, but for test purposes
// constant should be enough.
NumMaxInvoices: 9999,
}
result1, err := kvStore.QueryInvoices(ctxb, query)
require.NoError(t, err)
numInvoices := len(result1.Invoices)
result2, err := sqlInvoiceStore.QueryInvoices(ctxb, query)
require.NoError(t, err)
require.Equal(t, numInvoices, len(result2.Invoices))
// Simply zero out the add index so we don't fail on that when
// comparing.
for i := 0; i < numInvoices; i++ {
result1.Invoices[i].AddIndex = 0
result2.Invoices[i].AddIndex = 0
// We need to override the timezone of the invoices as
// the provided DB vs the test runners local time zone
// might be different.
invpkg.OverrideInvoiceTimeZone(&result1.Invoices[i])
invpkg.OverrideInvoiceTimeZone(&result2.Invoices[i])
require.Equal(
t, result1.Invoices[i], result2.Invoices[i],
)
}
}
tests := []struct {
name string
dbPath string
}{
{
"empty",
t.TempDir(),
},
{
"testdata",
"testdata",
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
var kvStore *channeldb.DB
// First check if we have a channel.sqlite file in the
// testdata directory. If we do, we'll use that as the
// channel db for the migration test.
chanDBPath := path.Join(
test.dbPath, lncfg.SqliteChannelDBName,
)
// Just some sane defaults for the sqlite config.
const (
timeout = 5 * time.Second
maxConns = 50
)
sqliteConfig := &sqlite.Config{
Timeout: timeout,
BusyTimeout: timeout,
MaxConnections: maxConns,
}
if fileExists(chanDBPath) {
sqlbase.Init(maxConns)
sqliteBackend, err := kvdb.Open(
kvdb.SqliteBackendName,
context.Background(),
sqliteConfig, test.dbPath,
lncfg.SqliteChannelDBName,
lncfg.NSChannelDB,
)
require.NoError(t, err)
kvStore, err = channeldb.CreateWithBackend(
sqliteBackend,
)
require.NoError(t, err)
} else {
kvStore = channeldb.OpenForTesting(
t, test.dbPath,
)
}
t.Run("Postgres", func(t *testing.T) {
migrationTest(t, kvStore, false)
})
t.Run("SQLite", func(t *testing.T) {
migrationTest(t, kvStore, true)
})
})
}
}
func fileExists(filename string) bool {
info, err := os.Stat(filename)
if os.IsNotExist(err) {
return false
}
return !info.IsDir()
}

558
invoices/sql_migration.go Normal file
View File

@ -0,0 +1,558 @@
package invoices
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"reflect"
"strconv"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/sqldb"
"github.com/lightningnetwork/lnd/sqldb/sqlc"
"github.com/pmezard/go-difflib/difflib"
)
var (
// invoiceBucket is the name of the bucket within the database that
// stores all data related to invoices no matter their final state.
// Within the invoice bucket, each invoice is keyed by its invoice ID
// which is a monotonically increasing uint32.
invoiceBucket = []byte("invoices")
// invoiceIndexBucket is the name of the sub-bucket within the
// invoiceBucket which indexes all invoices by their payment hash. The
// payment hash is the sha256 of the invoice's payment preimage. This
// index is used to detect duplicates, and also to provide a fast path
// for looking up incoming HTLCs to determine if we're able to settle
// them fully.
//
// maps: payHash => invoiceKey
invoiceIndexBucket = []byte("paymenthashes")
// numInvoicesKey is the name of key which houses the auto-incrementing
// invoice ID which is essentially used as a primary key. With each
// invoice inserted, the primary key is incremented by one. This key is
// stored within the invoiceIndexBucket. Within the invoiceBucket
// invoices are uniquely identified by the invoice ID.
numInvoicesKey = []byte("nik")
// addIndexBucket is an index bucket that we'll use to create a
// monotonically increasing set of add indexes. Each time we add a new
// invoice, this sequence number will be incremented and then populated
// within the new invoice.
//
// In addition to this sequence number, we map:
//
// addIndexNo => invoiceKey
addIndexBucket = []byte("invoice-add-index")
// ErrMigrationMismatch is returned when the migrated invoice does not
// match the original invoice.
ErrMigrationMismatch = fmt.Errorf("migrated invoice does not match " +
"original invoice")
)
// createInvoiceHashIndex generates a hash index that contains payment hashes
// for each invoice in the database. Retrieving the payment hash for certain
// invoices, such as those created for spontaneous AMP payments, can be
// challenging because the hash is not directly derivable from the invoice's
// parameters and is stored separately in the `paymenthashes` bucket. This
// bucket maps payment hashes to invoice keys, but for migration purposes, we
// need the ability to query in the reverse direction. This function establishes
// a new index in the SQL database that maps each invoice key to its
// corresponding payment hash.
func createInvoiceHashIndex(ctx context.Context, db kvdb.Backend,
tx *sqlc.Queries) error {
return db.View(func(kvTx kvdb.RTx) error {
invoices := kvTx.ReadBucket(invoiceBucket)
if invoices == nil {
return ErrNoInvoicesCreated
}
invoiceIndex := invoices.NestedReadBucket(
invoiceIndexBucket,
)
if invoiceIndex == nil {
return ErrNoInvoicesCreated
}
addIndex := invoices.NestedReadBucket(addIndexBucket)
if addIndex == nil {
return ErrNoInvoicesCreated
}
// First, iterate over all elements in the add index bucket and
// insert the add index value for the corresponding invoice key
// in the payment_hashes table.
err := addIndex.ForEach(func(k, v []byte) error {
// The key is the add index, and the value is
// the invoice key.
addIndexNo := binary.BigEndian.Uint64(k)
invoiceKey := binary.BigEndian.Uint32(v)
return tx.InsertKVInvoiceKeyAndAddIndex(ctx,
sqlc.InsertKVInvoiceKeyAndAddIndexParams{
ID: int64(invoiceKey),
AddIndex: int64(addIndexNo),
},
)
})
if err != nil {
return err
}
// Next, iterate over all hashes in the invoice index bucket and
// set the hash to the corresponding the invoice key in the
// payment_hashes table.
return invoiceIndex.ForEach(func(k, v []byte) error {
// Skip the special numInvoicesKey as that does
// not point to a valid invoice.
if bytes.Equal(k, numInvoicesKey) {
return nil
}
// The key is the payment hash, and the value
// is the invoice key.
if len(k) != lntypes.HashSize {
return fmt.Errorf("invalid payment "+
"hash length: expected %v, "+
"got %v", lntypes.HashSize,
len(k))
}
invoiceKey := binary.BigEndian.Uint32(v)
return tx.SetKVInvoicePaymentHash(ctx,
sqlc.SetKVInvoicePaymentHashParams{
ID: int64(invoiceKey),
Hash: k,
},
)
})
}, func() {})
}
// toInsertMigratedInvoiceParams creates the parameters for inserting a migrated
// invoice into the SQL database. The parameters are derived from the original
// invoice insert parameters.
func toInsertMigratedInvoiceParams(
params sqlc.InsertInvoiceParams) sqlc.InsertMigratedInvoiceParams {
return sqlc.InsertMigratedInvoiceParams{
Hash: params.Hash,
Preimage: params.Preimage,
Memo: params.Memo,
AmountMsat: params.AmountMsat,
CltvDelta: params.CltvDelta,
Expiry: params.Expiry,
PaymentAddr: params.PaymentAddr,
PaymentRequest: params.PaymentRequest,
PaymentRequestHash: params.PaymentRequestHash,
State: params.State,
AmountPaidMsat: params.AmountPaidMsat,
IsAmp: params.IsAmp,
IsHodl: params.IsHodl,
IsKeysend: params.IsKeysend,
CreatedAt: params.CreatedAt,
}
}
// MigrateSingleInvoice migrates a single invoice to the new SQL schema. Note
// that perfect equality between the old and new schemas is not achievable, as
// the invoice's add index cannot be mapped directly to its ID due to SQLs
// auto-incrementing primary key. The ID returned from the insert will instead
// serve as the add index in the new schema.
func MigrateSingleInvoice(ctx context.Context, tx SQLInvoiceQueries,
invoice *Invoice, paymentHash lntypes.Hash) error {
insertInvoiceParams, err := makeInsertInvoiceParams(
invoice, paymentHash,
)
if err != nil {
return err
}
// Convert the insert invoice parameters to the migrated invoice insert
// parameters.
insertMigratedInvoiceParams := toInsertMigratedInvoiceParams(
insertInvoiceParams,
)
// If the invoice is settled, we'll also set the timestamp and the index
// at which it was settled.
if invoice.State == ContractSettled {
if invoice.SettleIndex == 0 {
return fmt.Errorf("settled invoice %s missing settle "+
"index", paymentHash)
}
if invoice.SettleDate.IsZero() {
return fmt.Errorf("settled invoice %s missing settle "+
"date", paymentHash)
}
insertMigratedInvoiceParams.SettleIndex = sqldb.SQLInt64(
invoice.SettleIndex,
)
insertMigratedInvoiceParams.SettledAt = sqldb.SQLTime(
invoice.SettleDate.UTC(),
)
}
// First we need to insert the invoice itself so we can use the "add
// index" which in this case is the auto incrementing primary key that
// is returned from the insert.
invoiceID, err := tx.InsertMigratedInvoice(
ctx, insertMigratedInvoiceParams,
)
if err != nil {
return fmt.Errorf("unable to insert invoice: %w", err)
}
// Insert the invoice's features.
for feature := range invoice.Terms.Features.Features() {
params := sqlc.InsertInvoiceFeatureParams{
InvoiceID: invoiceID,
Feature: int32(feature),
}
err := tx.InsertInvoiceFeature(ctx, params)
if err != nil {
return fmt.Errorf("unable to insert invoice "+
"feature(%v): %w", feature, err)
}
}
sqlHtlcIDs := make(map[models.CircuitKey]int64)
// Now insert the HTLCs of the invoice. We'll also keep track of the SQL
// ID of each HTLC so we can use it when inserting the AMP sub invoices.
for circuitKey, htlc := range invoice.Htlcs {
htlcParams := sqlc.InsertInvoiceHTLCParams{
HtlcID: int64(circuitKey.HtlcID),
ChanID: strconv.FormatUint(
circuitKey.ChanID.ToUint64(), 10,
),
AmountMsat: int64(htlc.Amt),
AcceptHeight: int32(htlc.AcceptHeight),
AcceptTime: htlc.AcceptTime.UTC(),
ExpiryHeight: int32(htlc.Expiry),
State: int16(htlc.State),
InvoiceID: invoiceID,
}
// Leave the MPP amount as NULL if the MPP total amount is zero.
if htlc.MppTotalAmt != 0 {
htlcParams.TotalMppMsat = sqldb.SQLInt64(
int64(htlc.MppTotalAmt),
)
}
// Leave the resolve time as NULL if the HTLC is not resolved.
if !htlc.ResolveTime.IsZero() {
htlcParams.ResolveTime = sqldb.SQLTime(
htlc.ResolveTime.UTC(),
)
}
sqlID, err := tx.InsertInvoiceHTLC(ctx, htlcParams)
if err != nil {
return fmt.Errorf("unable to insert invoice htlc: %w",
err)
}
sqlHtlcIDs[circuitKey] = sqlID
// Store custom records.
for key, value := range htlc.CustomRecords {
err = tx.InsertInvoiceHTLCCustomRecord(
ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
Key: int64(key),
Value: value,
HtlcID: sqlID,
},
)
if err != nil {
return err
}
}
}
if !invoice.IsAMP() {
return nil
}
for setID, ampState := range invoice.AMPState {
// Find the earliest HTLC of the AMP invoice, which will
// be used as the creation date of this sub invoice.
var createdAt time.Time
for circuitKey := range ampState.InvoiceKeys {
htlc := invoice.Htlcs[circuitKey]
if createdAt.IsZero() {
createdAt = htlc.AcceptTime.UTC()
continue
}
if createdAt.After(htlc.AcceptTime) {
createdAt = htlc.AcceptTime.UTC()
}
}
params := sqlc.InsertAMPSubInvoiceParams{
SetID: setID[:],
State: int16(ampState.State),
CreatedAt: createdAt,
InvoiceID: invoiceID,
}
if ampState.SettleIndex != 0 {
if ampState.SettleDate.IsZero() {
return fmt.Errorf("settled AMP sub invoice %x "+
"missing settle date", setID)
}
params.SettledAt = sqldb.SQLTime(
ampState.SettleDate.UTC(),
)
params.SettleIndex = sqldb.SQLInt64(
ampState.SettleIndex,
)
}
err := tx.InsertAMPSubInvoice(ctx, params)
if err != nil {
return fmt.Errorf("unable to insert AMP sub invoice: "+
"%w", err)
}
// Now we can add the AMP HTLCs to the database.
for circuitKey := range ampState.InvoiceKeys {
htlc := invoice.Htlcs[circuitKey]
rootShare := htlc.AMP.Record.RootShare()
sqlHtlcID, ok := sqlHtlcIDs[circuitKey]
if !ok {
return fmt.Errorf("missing htlc for AMP htlc: "+
"%v", circuitKey)
}
params := sqlc.InsertAMPSubInvoiceHTLCParams{
InvoiceID: invoiceID,
SetID: setID[:],
HtlcID: sqlHtlcID,
RootShare: rootShare[:],
ChildIndex: int64(htlc.AMP.Record.ChildIndex()),
Hash: htlc.AMP.Hash[:],
}
if htlc.AMP.Preimage != nil {
params.Preimage = htlc.AMP.Preimage[:]
}
err = tx.InsertAMPSubInvoiceHTLC(ctx, params)
if err != nil {
return fmt.Errorf("unable to insert AMP sub "+
"invoice: %w", err)
}
}
}
return nil
}
// OverrideInvoiceTimeZone overrides the time zone of the invoice to the local
// time zone and chops off the nanosecond part for comparison. This is needed
// because KV database stores times as-is which as an unwanted side effect would
// fail migration due to time comparison expecting both the original and
// migrated invoices to be in the same local time zone and in microsecond
// precision. Note that PostgreSQL stores times in microsecond precision while
// SQLite can store times in nanosecond precision if using TEXT storage class.
func OverrideInvoiceTimeZone(invoice *Invoice) {
fixTime := func(t time.Time) time.Time {
return t.In(time.Local).Truncate(time.Microsecond)
}
invoice.CreationDate = fixTime(invoice.CreationDate)
if !invoice.SettleDate.IsZero() {
invoice.SettleDate = fixTime(invoice.SettleDate)
}
if invoice.IsAMP() {
for setID, ampState := range invoice.AMPState {
if ampState.SettleDate.IsZero() {
continue
}
ampState.SettleDate = fixTime(ampState.SettleDate)
invoice.AMPState[setID] = ampState
}
}
for _, htlc := range invoice.Htlcs {
if !htlc.AcceptTime.IsZero() {
htlc.AcceptTime = fixTime(htlc.AcceptTime)
}
if !htlc.ResolveTime.IsZero() {
htlc.ResolveTime = fixTime(htlc.ResolveTime)
}
}
}
// MigrateInvoicesToSQL runs the migration of all invoices from the KV database
// to the SQL database. The migration is done in a single transaction to ensure
// that all invoices are migrated or none at all. This function can be run
// multiple times without causing any issues as it will check if the migration
// has already been performed.
func MigrateInvoicesToSQL(ctx context.Context, db kvdb.Backend,
kvStore InvoiceDB, tx *sqlc.Queries, batchSize int) error {
log.Infof("Starting migration of invoices from KV to SQL")
offset := uint64(0)
t0 := time.Now()
// Create the hash index which we will use to look up invoice
// payment hashes by their add index during migration.
err := createInvoiceHashIndex(ctx, db, tx)
if err != nil && !errors.Is(err, ErrNoInvoicesCreated) {
log.Errorf("Unable to create invoice hash index: %v",
err)
return err
}
log.Debugf("Created SQL invoice hash index in %v", time.Since(t0))
total := 0
// Now we can start migrating the invoices. We'll do this in
// batches to reduce memory usage.
for {
t0 = time.Now()
query := InvoiceQuery{
IndexOffset: offset,
NumMaxInvoices: uint64(batchSize),
}
queryResult, err := kvStore.QueryInvoices(ctx, query)
if err != nil && !errors.Is(err, ErrNoInvoicesCreated) {
return fmt.Errorf("unable to query invoices: "+
"%w", err)
}
if len(queryResult.Invoices) == 0 {
log.Infof("All invoices migrated")
break
}
err = migrateInvoices(ctx, tx, queryResult.Invoices)
if err != nil {
return err
}
offset = queryResult.LastIndexOffset
total += len(queryResult.Invoices)
log.Debugf("Migrated %d KV invoices to SQL in %v\n", total,
time.Since(t0))
}
// Clean up the hash index as it's no longer needed.
err = tx.ClearKVInvoiceHashIndex(ctx)
if err != nil {
return fmt.Errorf("unable to clear invoice hash "+
"index: %w", err)
}
log.Infof("Migration of %d invoices from KV to SQL completed", total)
return nil
}
func migrateInvoices(ctx context.Context, tx *sqlc.Queries,
invoices []Invoice) error {
for i, invoice := range invoices {
var paymentHash lntypes.Hash
if invoice.Terms.PaymentPreimage != nil {
paymentHash = invoice.Terms.PaymentPreimage.Hash()
} else {
paymentHashBytes, err :=
tx.GetKVInvoicePaymentHashByAddIndex(
ctx, int64(invoice.AddIndex),
)
if err != nil {
// This would be an unexpected inconsistency
// in the kv database. We can't do much here
// so we'll notify the user and continue.
log.Warnf("Cannot migrate invoice, unable to "+
"fetch payment hash (add_index=%v): %v",
invoice.AddIndex, err)
continue
}
copy(paymentHash[:], paymentHashBytes)
}
err := MigrateSingleInvoice(ctx, tx, &invoices[i], paymentHash)
if err != nil {
return fmt.Errorf("unable to migrate invoice(%v): %w",
paymentHash, err)
}
migratedInvoice, err := fetchInvoice(
ctx, tx, InvoiceRefByHash(paymentHash),
)
if err != nil {
return fmt.Errorf("unable to fetch migrated "+
"invoice(%v): %w", paymentHash, err)
}
// Override the time zone for comparison. Note that we need to
// override both invoices as the original invoice is coming from
// KV database, it was stored as a binary serialized Go
// time.Time value which has nanosecond precision but might have
// been created in a different time zone. The migrated invoice
// is stored in SQL in UTC and selected in the local time zone,
// however in PostgreSQL it has microsecond precision while in
// SQLite it has nanosecond precision if using TEXT storage
// class.
OverrideInvoiceTimeZone(&invoice)
OverrideInvoiceTimeZone(migratedInvoice)
// Override the add index before checking for equality.
migratedInvoice.AddIndex = invoice.AddIndex
if !reflect.DeepEqual(invoice, *migratedInvoice) {
diff := difflib.UnifiedDiff{
A: difflib.SplitLines(
spew.Sdump(invoice),
),
B: difflib.SplitLines(
spew.Sdump(migratedInvoice),
),
FromFile: "Expected",
FromDate: "",
ToFile: "Actual",
ToDate: "",
Context: 3,
}
diffText, _ := difflib.GetUnifiedDiffString(diff)
return fmt.Errorf("%w: %v.\n%v", ErrMigrationMismatch,
paymentHash, diffText)
}
}
return nil
}

View File

@ -0,0 +1,421 @@
package invoices
import (
"context"
crand "crypto/rand"
"database/sql"
"math/rand"
"sync/atomic"
"testing"
"time"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/sqldb"
"github.com/stretchr/testify/require"
"pgregory.net/rapid"
)
var (
// testHtlcIDSequence is a global counter for generating unique HTLC
// IDs.
testHtlcIDSequence uint64
)
// randomString generates a random string of a given length using rapid.
func randomStringRapid(t *rapid.T, length int) string {
// Define the character set for the string.
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" //nolint:ll
// Generate a string by selecting random characters from the charset.
runes := make([]rune, length)
for i := range runes {
// Draw a random index and use it to select a character from the
// charset.
index := rapid.IntRange(0, len(charset)-1).Draw(t, "charIndex")
runes[i] = rune(charset[index])
}
return string(runes)
}
// randTimeBetween generates a random time between min and max.
func randTimeBetween(min, max time.Time) time.Time {
var timeZones = []*time.Location{
time.UTC,
time.FixedZone("EST", -5*3600),
time.FixedZone("MST", -7*3600),
time.FixedZone("PST", -8*3600),
time.FixedZone("CEST", 2*3600),
}
// Ensure max is after min
if max.Before(min) {
min, max = max, min
}
// Calculate the range in nanoseconds
duration := max.Sub(min)
randDuration := time.Duration(rand.Int63n(duration.Nanoseconds()))
// Generate the random time
randomTime := min.Add(randDuration)
// Assign a random time zone
randomTimeZone := timeZones[rand.Intn(len(timeZones))]
// Return the time in the random time zone
return randomTime.In(randomTimeZone)
}
// randTime generates a random time between 2009 and 2140.
func randTime() time.Time {
min := time.Date(2009, 1, 3, 0, 0, 0, 0, time.UTC)
max := time.Date(2140, 1, 1, 0, 0, 0, 1000, time.UTC)
return randTimeBetween(min, max)
}
func randInvoiceTime(invoice *Invoice) time.Time {
return randTimeBetween(
invoice.CreationDate,
invoice.CreationDate.Add(invoice.Terms.Expiry),
)
}
// randHTLCRapid generates a random HTLC for an invoice using rapid to randomize
// its parameters.
func randHTLCRapid(t *rapid.T, invoice *Invoice, amt lnwire.MilliSatoshi) (
models.CircuitKey, *InvoiceHTLC) {
htlc := &InvoiceHTLC{
Amt: amt,
AcceptHeight: rapid.Uint32Range(1, 999).Draw(t, "AcceptHeight"),
AcceptTime: randInvoiceTime(invoice),
Expiry: rapid.Uint32Range(1, 999).Draw(t, "Expiry"),
}
// Set MPP total amount if MPP feature is enabled in the invoice.
if invoice.Terms.Features.HasFeature(lnwire.MPPRequired) {
htlc.MppTotalAmt = invoice.Terms.Value
}
// Set the HTLC state and resolve time based on the invoice state.
switch invoice.State {
case ContractSettled:
htlc.State = HtlcStateSettled
htlc.ResolveTime = randInvoiceTime(invoice)
case ContractCanceled:
htlc.State = HtlcStateCanceled
htlc.ResolveTime = randInvoiceTime(invoice)
case ContractAccepted:
htlc.State = HtlcStateAccepted
}
// Add randomized custom records to the HTLC.
htlc.CustomRecords = make(record.CustomSet)
numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords")
for i := 0; i < numRecords; i++ {
key := rapid.Uint64Range(
record.CustomTypeStart, 1000+record.CustomTypeStart,
).Draw(t, "customRecordKey")
value := []byte(randomStringRapid(t, 10))
htlc.CustomRecords[key] = value
}
// Generate a unique HTLC ID and assign it to a channel ID.
htlcID := atomic.AddUint64(&testHtlcIDSequence, 1)
randChanID := lnwire.NewShortChanIDFromInt(htlcID % 5)
circuitKey := models.CircuitKey{
ChanID: randChanID,
HtlcID: htlcID,
}
return circuitKey, htlc
}
// generateInvoiceHTLCsRapid generates all HTLCs for an invoice, including AMP
// HTLCs if applicable, using rapid for randomization of HTLC count and
// distribution.
func generateInvoiceHTLCsRapid(t *rapid.T, invoice *Invoice) {
mpp := invoice.Terms.Features.HasFeature(lnwire.MPPRequired)
// Use rapid to determine the number of HTLCs based on invoice state and
// MPP feature.
numHTLCs := 1
if invoice.State == ContractOpen {
numHTLCs = 0
} else if mpp {
numHTLCs = rapid.IntRange(1, 10).Draw(t, "numHTLCs")
}
total := invoice.Terms.Value
// Distribute the total amount across the HTLCs, adding any remainder to
// the last HTLC.
if numHTLCs > 0 {
amt := total / lnwire.MilliSatoshi(numHTLCs)
remainder := total - amt*lnwire.MilliSatoshi(numHTLCs)
for i := 0; i < numHTLCs; i++ {
if i == numHTLCs-1 {
// Add remainder to the last HTLC.
amt += remainder
}
// Generate an HTLC with a random circuit key and add it
// to the invoice.
circuitKey, htlc := randHTLCRapid(t, invoice, amt)
invoice.Htlcs[circuitKey] = htlc
}
}
}
// generateAMPHtlcsRapid generates AMP HTLCs for an invoice using rapid to
// randomize various parameters of the HTLCs in the AMP set.
func generateAMPHtlcsRapid(t *rapid.T, invoice *Invoice) {
// Randomly determine the number of AMP sets (1 to 5).
numSetIDs := rapid.IntRange(1, 5).Draw(t, "numSetIDs")
settledIdx := uint64(1)
for i := 0; i < numSetIDs; i++ {
var setID SetID
_, err := crand.Read(setID[:])
require.NoError(t, err)
// Determine the number of HTLCs in this set (1 to 5).
numHTLCs := rapid.IntRange(1, 5).Draw(t, "numHTLCs")
total := invoice.Terms.Value
invoiceKeys := make(map[CircuitKey]struct{})
// Calculate the amount per HTLC and account for remainder in
// the final HTLC.
amt := total / lnwire.MilliSatoshi(numHTLCs)
remainder := total - amt*lnwire.MilliSatoshi(numHTLCs)
var htlcState HtlcState
for j := 0; j < numHTLCs; j++ {
if j == numHTLCs-1 {
amt += remainder
}
// Generate HTLC with randomized parameters.
circuitKey, htlc := randHTLCRapid(t, invoice, amt)
htlcState = htlc.State
var (
rootShare, hash [32]byte
preimage lntypes.Preimage
)
// Randomize AMP data fields.
_, err := crand.Read(rootShare[:])
require.NoError(t, err)
_, err = crand.Read(hash[:])
require.NoError(t, err)
_, err = crand.Read(preimage[:])
require.NoError(t, err)
record := record.NewAMP(rootShare, setID, uint32(j))
htlc.AMP = &InvoiceHtlcAMPData{
Record: *record,
Hash: hash,
Preimage: &preimage,
}
invoice.Htlcs[circuitKey] = htlc
invoiceKeys[circuitKey] = struct{}{}
}
ampState := InvoiceStateAMP{
State: htlcState,
InvoiceKeys: invoiceKeys,
}
if htlcState == HtlcStateSettled {
ampState.SettleIndex = settledIdx
ampState.SettleDate = randInvoiceTime(invoice)
settledIdx++
}
// Set the total amount paid if the AMP set is not canceled.
if htlcState != HtlcStateCanceled {
ampState.AmtPaid = invoice.Terms.Value
}
invoice.AMPState[setID] = ampState
}
}
// TestMigrateSingleInvoiceRapid tests the migration of single invoices with
// random data variations using rapid. This test generates a random invoice
// configuration and ensures successful migration.
//
// NOTE: This test may need to be changed if the Invoice or any of the related
// types are modified.
func TestMigrateSingleInvoiceRapid(t *testing.T) {
// Create a shared Postgres instance for efficient testing.
pgFixture := sqldb.NewTestPgFixture(
t, sqldb.DefaultPostgresFixtureLifetime,
)
t.Cleanup(func() {
pgFixture.TearDown(t)
})
makeSQLDB := func(t *testing.T, sqlite bool) *SQLStore {
var db *sqldb.BaseDB
if sqlite {
db = sqldb.NewTestSqliteDB(t).BaseDB
} else {
db = sqldb.NewTestPostgresDB(t, pgFixture).BaseDB
}
executor := sqldb.NewTransactionExecutor(
db, func(tx *sql.Tx) SQLInvoiceQueries {
return db.WithTx(tx)
},
)
testClock := clock.NewTestClock(time.Unix(1, 0))
return NewSQLStore(executor, testClock)
}
// Define property-based test using rapid.
rapid.Check(t, func(rt *rapid.T) {
// Randomized feature flags for MPP and AMP.
mpp := rapid.Bool().Draw(rt, "mpp")
amp := rapid.Bool().Draw(rt, "amp")
for _, sqlite := range []bool{true, false} {
store := makeSQLDB(t, sqlite)
testMigrateSingleInvoiceRapid(rt, store, mpp, amp)
}
})
}
// testMigrateSingleInvoiceRapid is the primary function for the migration of a
// single invoice with random data in a rapid-based test setup.
func testMigrateSingleInvoiceRapid(t *rapid.T, store *SQLStore, mpp bool,
amp bool) {
ctxb := context.Background()
invoices := make(map[lntypes.Hash]*Invoice)
for i := 0; i < 100; i++ {
invoice := generateTestInvoiceRapid(t, mpp, amp)
var hash lntypes.Hash
_, err := crand.Read(hash[:])
require.NoError(t, err)
invoices[hash] = invoice
}
var ops SQLInvoiceQueriesTxOptions
err := store.db.ExecTx(ctxb, &ops, func(tx SQLInvoiceQueries) error {
for hash, invoice := range invoices {
err := MigrateSingleInvoice(ctxb, tx, invoice, hash)
require.NoError(t, err)
}
return nil
}, func() {})
require.NoError(t, err)
// Fetch and compare each migrated invoice from the store with the
// original.
for hash, invoice := range invoices {
sqlInvoice, err := store.LookupInvoice(
ctxb, InvoiceRefByHash(hash),
)
require.NoError(t, err)
invoice.AddIndex = sqlInvoice.AddIndex
OverrideInvoiceTimeZone(invoice)
OverrideInvoiceTimeZone(&sqlInvoice)
require.Equal(t, *invoice, sqlInvoice)
}
}
// generateTestInvoiceRapid generates a random invoice with variations based on
// mpp and amp flags.
func generateTestInvoiceRapid(t *rapid.T, mpp bool, amp bool) *Invoice {
var preimage lntypes.Preimage
_, err := crand.Read(preimage[:])
require.NoError(t, err)
terms := ContractTerm{
FinalCltvDelta: rapid.Int32Range(1, 1000).Draw(
t, "FinalCltvDelta",
),
Expiry: time.Duration(
rapid.IntRange(1, 4444).Draw(t, "Expiry"),
) * time.Minute,
PaymentPreimage: &preimage,
Value: lnwire.MilliSatoshi(
rapid.Int64Range(1, 9999999).Draw(t, "Value"),
),
PaymentAddr: [32]byte{},
Features: lnwire.EmptyFeatureVector(),
}
if amp {
terms.Features.Set(lnwire.AMPRequired)
} else if mpp {
terms.Features.Set(lnwire.MPPRequired)
}
created := randTime()
const maxContractState = 3
state := ContractState(
rapid.IntRange(0, maxContractState).Draw(t, "ContractState"),
)
var (
settled time.Time
settleIndex uint64
)
if state == ContractSettled {
settled = randTimeBetween(created, created.Add(terms.Expiry))
settleIndex = rapid.Uint64Range(1, 999).Draw(t, "SettleIndex")
}
invoice := &Invoice{
Memo: []byte(randomStringRapid(t, 10)),
PaymentRequest: []byte(
randomStringRapid(t, MaxPaymentRequestSize),
),
CreationDate: created,
SettleDate: settled,
Terms: terms,
AddIndex: 0,
SettleIndex: settleIndex,
State: state,
AMPState: make(map[SetID]InvoiceStateAMP),
HodlInvoice: rapid.Bool().Draw(t, "HodlInvoice"),
}
invoice.Htlcs = make(map[models.CircuitKey]*InvoiceHTLC)
if invoice.IsAMP() {
generateAMPHtlcsRapid(t, invoice)
} else {
generateInvoiceHTLCsRapid(t, invoice)
}
for _, htlc := range invoice.Htlcs {
if htlc.State == HtlcStateSettled {
invoice.AmtPaid += htlc.Amt
}
}
return invoice
}

View File

@ -32,6 +32,10 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
error)
// TODO(bhandras): remove this once migrations have been separated out.
InsertMigratedInvoice(ctx context.Context,
arg sqlc.InsertMigratedInvoiceParams) (int64, error)
InsertInvoiceFeature(ctx context.Context,
arg sqlc.InsertInvoiceFeatureParams) error
@ -47,6 +51,9 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
GetInvoice(ctx context.Context,
arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
GetInvoiceByHash(ctx context.Context, hash []byte) (sqlc.Invoice,
error)
GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
error)
@ -79,6 +86,10 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
UpsertAMPSubInvoice(ctx context.Context,
arg sqlc.UpsertAMPSubInvoiceParams) (sql.Result, error)
// TODO(bhandras): remove this once migrations have been separated out.
InsertAMPSubInvoice(ctx context.Context,
arg sqlc.InsertAMPSubInvoiceParams) error
UpdateAMPSubInvoiceState(ctx context.Context,
arg sqlc.UpdateAMPSubInvoiceStateParams) error
@ -119,6 +130,19 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
OnAMPSubInvoiceSettled(ctx context.Context,
arg sqlc.OnAMPSubInvoiceSettledParams) error
// Migration specific methods.
// TODO(bhandras): remove this once migrations have been separated out.
InsertKVInvoiceKeyAndAddIndex(ctx context.Context,
arg sqlc.InsertKVInvoiceKeyAndAddIndexParams) error
SetKVInvoicePaymentHash(ctx context.Context,
arg sqlc.SetKVInvoicePaymentHashParams) error
GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) (
[]byte, error)
ClearKVInvoiceHashIndex(ctx context.Context) error
}
var _ InvoiceDB = (*SQLStore)(nil)
@ -200,6 +224,66 @@ func NewSQLStore(db BatchedSQLInvoiceQueries,
}
}
func makeInsertInvoiceParams(invoice *Invoice, paymentHash lntypes.Hash) (
sqlc.InsertInvoiceParams, error) {
// Precompute the payment request hash so we can use it in the query.
var paymentRequestHash []byte
if len(invoice.PaymentRequest) > 0 {
h := sha256.New()
h.Write(invoice.PaymentRequest)
paymentRequestHash = h.Sum(nil)
}
params := sqlc.InsertInvoiceParams{
Hash: paymentHash[:],
AmountMsat: int64(invoice.Terms.Value),
CltvDelta: sqldb.SQLInt32(
invoice.Terms.FinalCltvDelta,
),
Expiry: int32(invoice.Terms.Expiry.Seconds()),
// Note: keysend invoices don't have a payment request.
PaymentRequest: sqldb.SQLStr(string(
invoice.PaymentRequest),
),
PaymentRequestHash: paymentRequestHash,
State: int16(invoice.State),
AmountPaidMsat: int64(invoice.AmtPaid),
IsAmp: invoice.IsAMP(),
IsHodl: invoice.HodlInvoice,
IsKeysend: invoice.IsKeysend(),
CreatedAt: invoice.CreationDate.UTC(),
}
if invoice.Memo != nil {
// Store the memo as a nullable string in the database. Note
// that for compatibility reasons, we store the value as a valid
// string even if it's empty.
params.Memo = sql.NullString{
String: string(invoice.Memo),
Valid: true,
}
}
// Some invoices may not have a preimage, like in the case of HODL
// invoices.
if invoice.Terms.PaymentPreimage != nil {
preimage := *invoice.Terms.PaymentPreimage
if preimage == UnknownPreimage {
return sqlc.InsertInvoiceParams{},
errors.New("cannot use all-zeroes preimage")
}
params.Preimage = preimage[:]
}
// Some non MPP payments may have the default (invalid) value.
if invoice.Terms.PaymentAddr != BlankPayAddr {
params.PaymentAddr = invoice.Terms.PaymentAddr[:]
}
return params, nil
}
// AddInvoice inserts the targeted invoice into the database. If the invoice has
// *any* payment hashes which already exists within the database, then the
// insertion will be aborted and rejected due to the strict policy banning any
@ -220,55 +304,16 @@ func (i *SQLStore) AddInvoice(ctx context.Context,
invoiceID int64
)
// Precompute the payment request hash so we can use it in the query.
var paymentRequestHash []byte
if len(newInvoice.PaymentRequest) > 0 {
h := sha256.New()
h.Write(newInvoice.PaymentRequest)
paymentRequestHash = h.Sum(nil)
insertInvoiceParams, err := makeInsertInvoiceParams(
newInvoice, paymentHash,
)
if err != nil {
return 0, err
}
err := i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
params := sqlc.InsertInvoiceParams{
Hash: paymentHash[:],
Memo: sqldb.SQLStr(string(newInvoice.Memo)),
AmountMsat: int64(newInvoice.Terms.Value),
// Note: BOLT12 invoices don't have a final cltv delta.
CltvDelta: sqldb.SQLInt32(
newInvoice.Terms.FinalCltvDelta,
),
Expiry: int32(newInvoice.Terms.Expiry.Seconds()),
// Note: keysend invoices don't have a payment request.
PaymentRequest: sqldb.SQLStr(string(
newInvoice.PaymentRequest),
),
PaymentRequestHash: paymentRequestHash,
State: int16(newInvoice.State),
AmountPaidMsat: int64(newInvoice.AmtPaid),
IsAmp: newInvoice.IsAMP(),
IsHodl: newInvoice.HodlInvoice,
IsKeysend: newInvoice.IsKeysend(),
CreatedAt: newInvoice.CreationDate.UTC(),
}
// Some invoices may not have a preimage, like in the case of
// HODL invoices.
if newInvoice.Terms.PaymentPreimage != nil {
preimage := *newInvoice.Terms.PaymentPreimage
if preimage == UnknownPreimage {
return errors.New("cannot use all-zeroes " +
"preimage")
}
params.Preimage = preimage[:]
}
// Some non MPP payments may have the default (invalid) value.
if newInvoice.Terms.PaymentAddr != BlankPayAddr {
params.PaymentAddr = newInvoice.Terms.PaymentAddr[:]
}
err = i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
var err error
invoiceID, err = db.InsertInvoice(ctx, params)
invoiceID, err = db.InsertInvoice(ctx, insertInvoiceParams)
if err != nil {
return fmt.Errorf("unable to insert invoice: %w", err)
}
@ -312,22 +357,31 @@ func (i *SQLStore) AddInvoice(ctx context.Context,
return newInvoice.AddIndex, nil
}
// fetchInvoice fetches the common invoice data and the AMP state for the
// invoice with the given reference.
func (i *SQLStore) fetchInvoice(ctx context.Context,
db SQLInvoiceQueries, ref InvoiceRef) (*Invoice, error) {
// getInvoiceByRef fetches the invoice with the given reference. The reference
// may be a payment hash, a payment address, or a set ID for an AMP sub invoice.
func getInvoiceByRef(ctx context.Context,
db SQLInvoiceQueries, ref InvoiceRef) (sqlc.Invoice, error) {
// If the reference is empty, we can't look up the invoice.
if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
return nil, ErrInvoiceNotFound
return sqlc.Invoice{}, ErrInvoiceNotFound
}
var (
invoice *Invoice
params sqlc.GetInvoiceParams
)
// If the reference is a hash only, we can look up the invoice directly
// by the payment hash which is faster.
if ref.IsHashOnly() {
invoice, err := db.GetInvoiceByHash(ctx, ref.PayHash()[:])
if errors.Is(err, sql.ErrNoRows) {
return sqlc.Invoice{}, ErrInvoiceNotFound
}
return invoice, err
}
// Otherwise the reference may include more fields, so we'll need to
// assemble the query parameters based on the fields that are set.
var params sqlc.GetInvoiceParams
// Given all invoices are uniquely identified by their payment hash,
// we can use it to query a specific invoice.
if ref.PayHash() != nil {
params.Hash = ref.PayHash()[:]
}
@ -363,18 +417,34 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
} else {
rows, err = db.GetInvoice(ctx, params)
}
switch {
case len(rows) == 0:
return nil, ErrInvoiceNotFound
return sqlc.Invoice{}, ErrInvoiceNotFound
case len(rows) > 1:
// In case the reference is ambiguous, meaning it matches more
// than one invoice, we'll return an error.
return nil, fmt.Errorf("ambiguous invoice ref: %s: %s",
ref.String(), spew.Sdump(rows))
return sqlc.Invoice{}, fmt.Errorf("ambiguous invoice ref: "+
"%s: %s", ref.String(), spew.Sdump(rows))
case err != nil:
return nil, fmt.Errorf("unable to fetch invoice: %w", err)
return sqlc.Invoice{}, fmt.Errorf("unable to fetch invoice: %w",
err)
}
return rows[0], nil
}
// fetchInvoice fetches the common invoice data and the AMP state for the
// invoice with the given reference.
func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, ref InvoiceRef) (
*Invoice, error) {
// Fetch the invoice from the database.
sqlInvoice, err := getInvoiceByRef(ctx, db, ref)
if err != nil {
return nil, err
}
var (
@ -391,8 +461,8 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
fetchAmpHtlcs = true
case HtlcSetOnlyModifier:
// In this case we'll fetch all AMP HTLCs for the
// specified set id.
// In this case we'll fetch all AMP HTLCs for the specified set
// id.
if ref.SetID() == nil {
return nil, fmt.Errorf("set ID is required to use " +
"the HTLC set only modifier")
@ -412,8 +482,8 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
}
// Fetch the rest of the invoice data and fill the invoice struct.
_, invoice, err = fetchInvoiceData(
ctx, db, rows[0], setID, fetchAmpHtlcs,
_, invoice, err := fetchInvoiceData(
ctx, db, sqlInvoice, setID, fetchAmpHtlcs,
)
if err != nil {
return nil, err
@ -616,7 +686,7 @@ func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
invoiceKeys[key] = struct{}{}
if htlc.State != HtlcStateCanceled { //nolint: ll
if htlc.State != HtlcStateCanceled {
amtPaid += htlc.Amt
}
}
@ -646,7 +716,7 @@ func (i *SQLStore) LookupInvoice(ctx context.Context,
readTxOpt := NewSQLInvoiceQueryReadTx()
txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
invoice, err = i.fetchInvoice(ctx, db, ref)
invoice, err = fetchInvoice(ctx, db, ref)
return err
}, func() {})
@ -1347,7 +1417,7 @@ func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
ref.refModifier = HtlcSetOnlyModifier
}
invoice, err := i.fetchInvoice(ctx, db, ref)
invoice, err := fetchInvoice(ctx, db, ref)
if err != nil {
return err
}
@ -1506,13 +1576,6 @@ func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
if len(htlcs) > 0 {
invoice.Htlcs = htlcs
var amountPaid lnwire.MilliSatoshi
for _, htlc := range htlcs {
if htlc.State == HtlcStateSettled {
amountPaid += htlc.Amt
}
}
invoice.AmtPaid = amountPaid
}
return hash, invoice, nil

BIN
invoices/testdata/channel.db vendored Normal file

Binary file not shown.

View File

@ -626,10 +626,6 @@ var allTestCases = []*lntest.TestCase{
Name: "open channel locked balance",
TestFunc: testOpenChannelLockedBalance,
},
{
Name: "nativesql no migration",
TestFunc: testNativeSQLNoMigration,
},
{
Name: "sweep cpfp anchor outgoing timeout",
TestFunc: testSweepCPFPAnchorOutgoingTimeout,
@ -682,6 +678,10 @@ var allTestCases = []*lntest.TestCase{
Name: "quiescence",
TestFunc: testQuiescence,
},
{
Name: "invoice migration",
TestFunc: testInvoiceMigration,
},
}
// appendPrefixed is used to add a prefix to each test name in the subtests

View File

@ -0,0 +1,303 @@
package itest
import (
"database/sql"
"path"
"time"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/kvdb/postgres"
"github.com/lightningnetwork/lnd/kvdb/sqlbase"
"github.com/lightningnetwork/lnd/kvdb/sqlite"
"github.com/lightningnetwork/lnd/lncfg"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnrpc/routerrpc"
"github.com/lightningnetwork/lnd/lntest"
"github.com/lightningnetwork/lnd/lntest/node"
"github.com/lightningnetwork/lnd/sqldb"
"github.com/stretchr/testify/require"
)
func openChannelDB(ht *lntest.HarnessTest, hn *node.HarnessNode) *channeldb.DB {
sqlbase.Init(0)
var (
backend kvdb.Backend
err error
)
switch hn.Cfg.DBBackend {
case node.BackendSqlite:
backend, err = kvdb.Open(
kvdb.SqliteBackendName,
ht.Context(),
&sqlite.Config{
Timeout: defaultTimeout,
BusyTimeout: defaultTimeout,
},
hn.Cfg.DBDir(), lncfg.SqliteChannelDBName,
lncfg.NSChannelDB,
)
require.NoError(ht, err)
case node.BackendPostgres:
backend, err = kvdb.Open(
kvdb.PostgresBackendName, ht.Context(),
&postgres.Config{
Dsn: hn.Cfg.PostgresDsn,
Timeout: defaultTimeout,
}, lncfg.NSChannelDB,
)
require.NoError(ht, err)
}
db, err := channeldb.CreateWithBackend(backend)
require.NoError(ht, err)
return db
}
func openNativeSQLInvoiceDB(ht *lntest.HarnessTest,
hn *node.HarnessNode) invoices.InvoiceDB {
var db *sqldb.BaseDB
switch hn.Cfg.DBBackend {
case node.BackendSqlite:
sqliteStore, err := sqldb.NewSqliteStore(
&sqldb.SqliteConfig{
Timeout: defaultTimeout,
BusyTimeout: defaultTimeout,
},
path.Join(
hn.Cfg.DBDir(),
lncfg.SqliteNativeDBName,
),
)
require.NoError(ht, err)
db = sqliteStore.BaseDB
case node.BackendPostgres:
postgresStore, err := sqldb.NewPostgresStore(
&sqldb.PostgresConfig{
Dsn: hn.Cfg.PostgresDsn,
Timeout: defaultTimeout,
},
)
require.NoError(ht, err)
db = postgresStore.BaseDB
}
executor := sqldb.NewTransactionExecutor(
db, func(tx *sql.Tx) invoices.SQLInvoiceQueries {
return db.WithTx(tx)
},
)
return invoices.NewSQLStore(
executor, clock.NewDefaultClock(),
)
}
// clampTime truncates the time of the passed invoice to the microsecond level.
func clampTime(invoice *invoices.Invoice) {
trunc := func(t time.Time) time.Time {
return t.Truncate(time.Microsecond)
}
invoice.CreationDate = trunc(invoice.CreationDate)
if !invoice.SettleDate.IsZero() {
invoice.SettleDate = trunc(invoice.SettleDate)
}
if invoice.IsAMP() {
for setID, ampState := range invoice.AMPState {
if ampState.SettleDate.IsZero() {
continue
}
ampState.SettleDate = trunc(ampState.SettleDate)
invoice.AMPState[setID] = ampState
}
}
for _, htlc := range invoice.Htlcs {
if !htlc.AcceptTime.IsZero() {
htlc.AcceptTime = trunc(htlc.AcceptTime)
}
if !htlc.ResolveTime.IsZero() {
htlc.ResolveTime = trunc(htlc.ResolveTime)
}
}
}
// testInvoiceMigration tests that the invoice migration from the old KV store
// to the new native SQL store works as expected.
func testInvoiceMigration(ht *lntest.HarnessTest) {
alice := ht.NewNodeWithCoins("Alice", nil)
bob := ht.NewNodeWithCoins("Bob", []string{"--accept-amp"})
// Make sure we run the test with SQLite or Postgres.
if bob.Cfg.DBBackend != node.BackendSqlite &&
bob.Cfg.DBBackend != node.BackendPostgres {
ht.Skip("node not running with SQLite or Postgres")
}
// Skip the test if the node is already running with native SQL.
if bob.Cfg.NativeSQL {
ht.Skip("node already running with native SQL")
}
ht.EnsureConnected(alice, bob)
cp := ht.OpenChannel(
alice, bob, lntest.OpenChannelParams{
Amt: 1000000,
PushAmt: 500000,
},
)
// Alice and bob should have one channel open with each other now.
ht.AssertNodeNumChannels(alice, 1)
ht.AssertNodeNumChannels(bob, 1)
// Step 1: Add 10 normal invoices and pay 5 of them.
normalInvoices := make([]*lnrpc.AddInvoiceResponse, 10)
for i := 0; i < 10; i++ {
invoice := &lnrpc.Invoice{
Value: int64(1000 + i*100), // Varying amounts
IsAmp: false,
}
resp := bob.RPC.AddInvoice(invoice)
normalInvoices[i] = resp
}
for _, inv := range normalInvoices {
sendReq := &routerrpc.SendPaymentRequest{
PaymentRequest: inv.PaymentRequest,
TimeoutSeconds: 60,
FeeLimitMsat: noFeeLimitMsat,
}
ht.SendPaymentAssertSettled(alice, sendReq)
}
// Step 2: Add 10 AMP invoices and send multiple payments to 5 of them.
ampInvoices := make([]*lnrpc.AddInvoiceResponse, 10)
for i := 0; i < 10; i++ {
invoice := &lnrpc.Invoice{
Value: int64(2000 + i*200), // Varying amounts
IsAmp: true,
}
resp := bob.RPC.AddInvoice(invoice)
ampInvoices[i] = resp
}
// Select the first 5 invoices to send multiple AMP payments.
for i := 0; i < 5; i++ {
inv := ampInvoices[i]
// Send 3 payments to each.
for j := 0; j < 3; j++ {
payReq := &routerrpc.SendPaymentRequest{
PaymentRequest: inv.PaymentRequest,
TimeoutSeconds: 60,
FeeLimitMsat: noFeeLimitMsat,
Amp: true,
}
// Send a normal AMP payment first, then a spontaneous
// AMP payment.
ht.SendPaymentAssertSettled(alice, payReq)
// Generate an external payment address when attempting
// to pseudo-reuse an AMP invoice. When using an
// external payment address, we'll also expect an extra
// invoice to appear in the ListInvoices response, since
// a new invoice will be JIT inserted under a different
// payment address than the one in the invoice.
//
// NOTE: This will only work when the peer has
// spontaneous AMP payments enabled otherwise no invoice
// under a different payment_addr will be found.
payReq.PaymentAddr = ht.Random32Bytes()
ht.SendPaymentAssertSettled(alice, payReq)
}
}
// We can close the channel now.
ht.CloseChannel(alice, cp)
// Now stop Bob so we can open the DB for examination.
require.NoError(ht, bob.Stop())
// Open the KV channel DB.
db := openChannelDB(ht, bob)
query := invoices.InvoiceQuery{
IndexOffset: 0,
// As a sanity check, fetch more invoices than we have
// to ensure that we did not add any extra invoices.
NumMaxInvoices: 9999,
}
// Fetch all invoices and make sure we have 35 in total.
result1, err := db.QueryInvoices(ht.Context(), query)
require.NoError(ht, err)
require.Equal(ht, 35, len(result1.Invoices))
numInvoices := len(result1.Invoices)
bob.SetExtraArgs([]string{"--db.use-native-sql"})
// Now run the migration flow three times to ensure that each run is
// idempotent.
for i := 0; i < 3; i++ {
// Start bob with the native SQL flag set. This will trigger the
// migration to run.
require.NoError(ht, bob.Start(ht.Context()))
// At this point the migration should have completed and the
// node should be running with native SQL. Now we'll stop Bob
// again so we can safely examine the database.
require.NoError(ht, bob.Stop())
// Now we'll open the database with the native SQL backend and
// fetch the invoices again to ensure that they were migrated
// correctly.
sqlInvoiceDB := openNativeSQLInvoiceDB(ht, bob)
result2, err := sqlInvoiceDB.QueryInvoices(ht.Context(), query)
require.NoError(ht, err)
require.Equal(ht, numInvoices, len(result2.Invoices))
// Simply zero out the add index so we don't fail on that when
// comparing.
for i := 0; i < numInvoices; i++ {
result1.Invoices[i].AddIndex = 0
result2.Invoices[i].AddIndex = 0
// Clamp the precision to microseconds. Note that we
// need to override both invoices as the original
// invoice is coming from KV database, it was stored as
// a binary serialized Go time.Time value which has
// nanosecond precision. The migrated invoice is stored
// in SQL in PostgreSQL has microsecond precision while
// in SQLite it has nanosecond precision if using TEXT
// storage class.
clampTime(&result1.Invoices[i])
clampTime(&result2.Invoices[i])
require.Equal(
ht, result1.Invoices[i], result2.Invoices[i],
)
}
}
// Start Bob again so the test can complete.
require.NoError(ht, bob.Start(ht.Context()))
}

View File

@ -1,7 +1,6 @@
package itest
import (
"context"
"encoding/hex"
"fmt"
"os"
@ -1243,44 +1242,6 @@ func testSignVerifyMessageWithAddr(ht *lntest.HarnessTest) {
require.False(ht, respValid.Valid, "external signature did validate")
}
// testNativeSQLNoMigration tests that nodes that have invoices would not start
// up with native SQL enabled, as we don't currently support migration of KV
// invoices to the new SQL schema.
func testNativeSQLNoMigration(ht *lntest.HarnessTest) {
alice := ht.NewNode("Alice", nil)
// Make sure we run the test with SQLite or Postgres.
if alice.Cfg.DBBackend != node.BackendSqlite &&
alice.Cfg.DBBackend != node.BackendPostgres {
ht.Skip("node not running with SQLite or Postgres")
}
// Skip the test if the node is already running with native SQL.
if alice.Cfg.NativeSQL {
ht.Skip("node already running with native SQL")
}
alice.RPC.AddInvoice(&lnrpc.Invoice{
Value: 10_000,
})
alice.SetExtraArgs([]string{"--db.use-native-sql"})
// Restart the node manually as we're really only interested in the
// startup error.
require.NoError(ht, alice.Stop())
require.NoError(ht, alice.StartLndCmd(context.Background()))
// We expect the node to fail to start up with native SQL enabled, as we
// have an invoice in the KV store.
require.Error(ht, alice.WaitForProcessExit())
// Reset the extra args and restart alice.
alice.SetExtraArgs(nil)
require.NoError(ht, alice.Start(ht.Context()))
}
// testSendSelectedCoins tests that we're able to properly send the selected
// coins from the wallet to a single target address.
func testSendSelectedCoins(ht *lntest.HarnessTest) {

View File

@ -87,6 +87,8 @@ type DB struct {
UseNativeSQL bool `long:"use-native-sql" description:"Use native SQL for tables that already support it."`
SkipSQLInvoiceMigration bool `long:"skip-sql-invoice-migration" description:"Do not migrate invoices stored in our key-value database to native SQL."`
NoGraphCache bool `long:"no-graph-cache" description:"Don't use the in-memory graph cache for path finding. Much slower but uses less RAM. Can only be used with a bolt database backend."`
PruneRevocation bool `long:"prune-revocation" description:"Run the optional migration that prunes the revocation logs to save disk space."`
@ -115,7 +117,8 @@ func DefaultDB() *DB {
MaxConnections: defaultSqliteMaxConnections,
BusyTimeout: defaultSqliteBusyTimeout,
},
UseNativeSQL: false,
UseNativeSQL: false,
SkipSQLInvoiceMigration: false,
}
}
@ -231,10 +234,10 @@ type DatabaseBackends struct {
// the underlying wallet database from.
WalletDB btcwallet.LoaderOption
// NativeSQLStore is a pointer to a native SQL store that can be used
// for native SQL queries for tables that already support it. This may
// be nil if the use-native-sql flag was not set.
NativeSQLStore *sqldb.BaseDB
// NativeSQLStore holds a reference to the native SQL store that can
// be used for native SQL queries for tables that already support it.
// This may be nil if the use-native-sql flag was not set.
NativeSQLStore sqldb.DB
// Remote indicates whether the database backends are remote, possibly
// replicated instances or local bbolt or sqlite backed databases.
@ -449,7 +452,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath,
}
closeFuncs[NSWalletDB] = postgresWalletBackend.Close
var nativeSQLStore *sqldb.BaseDB
var nativeSQLStore sqldb.DB
if db.UseNativeSQL {
nativePostgresStore, err := sqldb.NewPostgresStore(
db.Postgres,
@ -459,7 +462,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath,
"native postgres store: %v", err)
}
nativeSQLStore = nativePostgresStore.BaseDB
nativeSQLStore = nativePostgresStore
closeFuncs[PostgresBackend] = nativePostgresStore.Close
}
@ -571,7 +574,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath,
}
closeFuncs[NSWalletDB] = sqliteWalletBackend.Close
var nativeSQLStore *sqldb.BaseDB
var nativeSQLStore sqldb.DB
if db.UseNativeSQL {
nativeSQLiteStore, err := sqldb.NewSqliteStore(
db.Sqlite,
@ -582,7 +585,7 @@ func (db *DB) GetBackends(ctx context.Context, chanDBPath,
"native SQLite store: %v", err)
}
nativeSQLStore = nativeSQLiteStore.BaseDB
nativeSQLStore = nativeSQLiteStore
closeFuncs[SqliteBackend] = nativeSQLiteStore.Close
}

View File

@ -1472,6 +1472,9 @@
; own risk.
; db.use-native-sql=false
; If set to true, native SQL invoice migration will be skipped. Note that this
; option is intended for users who experience non-resolvable migration errors.
; db.skip-sql-invoice-migration=false
[etcd]

View File

@ -2,12 +2,40 @@
set -e
# restore_files is a function to restore original schema files.
restore_files() {
echo "Restoring SQLite bigint patch..."
for file in sqldb/sqlc/migrations/*.up.sql.bak; do
mv "$file" "${file%.bak}"
done
}
# Set trap to call restore_files on script exit. This makes sure the old files
# are always restored.
trap restore_files EXIT
# Directory of the script file, independent of where it's called from.
DIR="$(cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd)"
# Use the user's cache directories
GOCACHE=`go env GOCACHE`
GOMODCACHE=`go env GOMODCACHE`
# SQLite doesn't support "BIGINT PRIMARY KEY" for auto-incrementing primary
# keys, only "INTEGER PRIMARY KEY". Internally it uses 64-bit integers for
# numbers anyway, independent of the column type. So we can just use
# "INTEGER PRIMARY KEY" and it will work the same under the hood, giving us
# auto incrementing 64-bit integers.
# _BUT_, sqlc will generate Go code with int32 if we use "INTEGER PRIMARY KEY",
# even though we want int64. So before we run sqlc, we need to patch the
# source schema SQL files to use "BIGINT PRIMARY KEY" instead of "INTEGER
# PRIMARY KEY".
echo "Applying SQLite bigint patch..."
for file in sqldb/sqlc/migrations/*.up.sql; do
echo "Patching $file"
sed -i.bak -E 's/INTEGER PRIMARY KEY/BIGINT PRIMARY KEY/g' "$file"
done
echo "Generating sql models and queries in go..."
docker run \

View File

@ -355,6 +355,18 @@ func (t *TransactionExecutor[Q]) ExecTx(ctx context.Context,
)
}
// DB is an interface that represents a generic SQL database. It provides
// methods to apply migrations and access the underlying database connection.
type DB interface {
// GetBaseDB returns the underlying BaseDB instance.
GetBaseDB() *BaseDB
// ApplyAllMigrations applies all migrations to the database including
// both sqlc and custom in-code migrations.
ApplyAllMigrations(ctx context.Context,
customMigrations []MigrationConfig) error
}
// BaseDB is the base database struct that each implementation can embed to
// gain some common functionality.
type BaseDB struct {

View File

@ -2,22 +2,118 @@ package sqldb
import (
"bytes"
"context"
"database/sql"
"errors"
"fmt"
"io"
"io/fs"
"net/http"
"strings"
"time"
"github.com/btcsuite/btclog/v2"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/source/httpfs"
"github.com/lightningnetwork/lnd/sqldb/sqlc"
)
var (
// migrationConfig defines a list of migrations to be applied to the
// database. Each migration is assigned a version number, determining
// its execution order.
// The schema version, tracked by golang-migrate, ensures migrations are
// applied to the correct schema. For migrations involving only schema
// changes, the migration function can be left nil. For custom
// migrations an implemented migration function is required.
//
// NOTE: The migration function may have runtime dependencies, which
// must be injected during runtime.
migrationConfig = []MigrationConfig{
{
Name: "000001_invoices",
Version: 1,
SchemaVersion: 1,
},
{
Name: "000002_amp_invoices",
Version: 2,
SchemaVersion: 2,
},
{
Name: "000003_invoice_events",
Version: 3,
SchemaVersion: 3,
},
{
Name: "000004_invoice_expiry_fix",
Version: 4,
SchemaVersion: 4,
},
{
Name: "000005_migration_tracker",
Version: 5,
SchemaVersion: 5,
},
{
Name: "000006_invoice_migration",
Version: 6,
SchemaVersion: 6,
},
{
Name: "kv_invoice_migration",
Version: 7,
SchemaVersion: 6,
// A migration function is may be attached to this
// migration to migrate KV invoices to the native SQL
// schema. This is optional and can be disabled by the
// user if necessary.
},
}
)
// MigrationConfig is a configuration struct that describes SQL migrations. Each
// migration is associated with a specific schema version and a global database
// version. Migrations are applied in the order of their global database
// version. If a migration includes a non-nil MigrationFn, it is executed after
// the SQL schema has been migrated to the corresponding schema version.
type MigrationConfig struct {
// Name is the name of the migration.
Name string
// Version represents the "global" database version for this migration.
// Unlike the schema version tracked by golang-migrate, it encompasses
// all migrations, including those managed by golang-migrate as well
// as custom in-code migrations.
Version int
// SchemaVersion represents the schema version tracked by golang-migrate
// at which the migration is applied.
SchemaVersion int
// MigrationFn is the function executed for custom migrations at the
// specified version. It is used to handle migrations that cannot be
// performed through SQL alone. If set to nil, no custom migration is
// applied.
MigrationFn func(tx *sqlc.Queries) error
}
// MigrationTarget is a functional option that can be passed to applyMigrations
// to specify a target version to migrate to.
type MigrationTarget func(mig *migrate.Migrate) error
// MigrationExecutor is an interface that abstracts the migration functionality.
type MigrationExecutor interface {
// ExecuteMigrations runs database migrations up to the specified target
// version or all migrations if no target is specified. A migration may
// include a schema change, a custom migration function, or both.
// Developers must ensure that migrations are defined in the correct
// order. Migration details are stored in the global variable
// migrationConfig.
ExecuteMigrations(target MigrationTarget) error
}
var (
// TargetLatest is a MigrationTarget that migrates to the latest
// version available.
@ -34,6 +130,14 @@ var (
}
)
// GetMigrations returns a copy of the migration configuration.
func GetMigrations() []MigrationConfig {
migrations := make([]MigrationConfig, len(migrationConfig))
copy(migrations, migrationConfig)
return migrations
}
// migrationLogger is a logger that wraps the passed btclog.Logger so it can be
// used to log migrations.
type migrationLogger struct {
@ -216,3 +320,117 @@ func (t *replacerFile) Close() error {
// instance, so there's nothing to do for us here.
return nil
}
// MigrationTxOptions is the implementation of the TxOptions interface for
// migration transactions.
type MigrationTxOptions struct {
}
// ReadOnly returns false to indicate that migration transactions are not read
// only.
func (m *MigrationTxOptions) ReadOnly() bool {
return false
}
// ApplyMigrations applies the provided migrations to the database in sequence.
// It ensures migrations are executed in the correct order, applying both custom
// migration functions and SQL migrations as needed.
func ApplyMigrations(ctx context.Context, db *BaseDB,
migrator MigrationExecutor, migrations []MigrationConfig) error {
// Ensure that the migrations are sorted by version.
for i := 0; i < len(migrations); i++ {
if migrations[i].Version != i+1 {
return fmt.Errorf("migration version %d is out of "+
"order. Expected %d", migrations[i].Version,
i+1)
}
}
// Construct a transaction executor to apply custom migrations.
executor := NewTransactionExecutor(db, func(tx *sql.Tx) *sqlc.Queries {
return db.WithTx(tx)
})
currentVersion := 0
version, err := db.GetDatabaseVersion(ctx)
if !errors.Is(err, sql.ErrNoRows) {
if err != nil {
return fmt.Errorf("error getting current database "+
"version: %w", err)
}
currentVersion = int(version)
}
for _, migration := range migrations {
if migration.Version <= currentVersion {
log.Infof("Skipping migration '%s' (version %d) as it "+
"has already been applied", migration.Name,
migration.Version)
continue
}
log.Infof("Migrating SQL schema to version %d",
migration.SchemaVersion)
// Execute SQL schema migrations up to the target version.
err = migrator.ExecuteMigrations(
TargetVersion(uint(migration.SchemaVersion)),
)
if err != nil {
return fmt.Errorf("error executing schema migrations "+
"to target version %d: %w",
migration.SchemaVersion, err)
}
var opts MigrationTxOptions
// Run the custom migration as a transaction to ensure
// atomicity. If successful, mark the migration as complete in
// the migration tracker table.
err = executor.ExecTx(ctx, &opts, func(tx *sqlc.Queries) error {
// Apply the migration function if one is provided.
if migration.MigrationFn != nil {
log.Infof("Applying custom migration '%v' "+
"(version %d) to schema version %d",
migration.Name, migration.Version,
migration.SchemaVersion)
err = migration.MigrationFn(tx)
if err != nil {
return fmt.Errorf("error applying "+
"migration '%v' (version %d) "+
"to schema version %d: %w",
migration.Name,
migration.Version,
migration.SchemaVersion, err)
}
log.Infof("Migration '%v' (version %d) "+
"applied ", migration.Name,
migration.Version)
}
// Mark the migration as complete by adding the version
// to the migration tracker table along with the current
// timestamp.
err = tx.SetMigration(ctx, sqlc.SetMigrationParams{
Version: int32(migration.Version),
MigrationTime: time.Now(),
})
if err != nil {
return fmt.Errorf("error setting migration "+
"version %d: %w", migration.Version,
err)
}
return nil
}, func() {})
if err != nil {
return err
}
}
return nil
}

View File

@ -2,8 +2,15 @@ package sqldb
import (
"context"
"database/sql"
"fmt"
"path/filepath"
"testing"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
pgx_migrate "github.com/golang-migrate/migrate/v4/database/pgx/v5"
sqlite_migrate "github.com/golang-migrate/migrate/v4/database/sqlite"
"github.com/lightningnetwork/lnd/sqldb/sqlc"
"github.com/stretchr/testify/require"
)
@ -152,3 +159,296 @@ func testInvoiceExpiryMigration(t *testing.T, makeDB makeMigrationTestDB) {
require.NoError(t, err)
require.Equal(t, expected, invoices)
}
// TestCustomMigration tests that a custom in-code migrations are correctly
// executed during the migration process.
func TestCustomMigration(t *testing.T) {
var customMigrationLog []string
logMigration := func(name string) {
customMigrationLog = append(customMigrationLog, name)
}
// Some migrations to use for both the failure and success tests. Note
// that the migrations are not in order to test that they are executed
// in the correct order.
migrations := []MigrationConfig{
{
Name: "1",
Version: 1,
SchemaVersion: 1,
MigrationFn: func(*sqlc.Queries) error {
logMigration("1")
return nil
},
},
{
Name: "2",
Version: 2,
SchemaVersion: 1,
MigrationFn: func(*sqlc.Queries) error {
logMigration("2")
return nil
},
},
{
Name: "3",
Version: 3,
SchemaVersion: 2,
MigrationFn: func(*sqlc.Queries) error {
logMigration("3")
return nil
},
},
}
tests := []struct {
name string
migrations []MigrationConfig
expectedSuccess bool
expectedMigrationLog []string
expectedSchemaVersion int
expectedVersion int
}{
{
name: "success",
migrations: migrations,
expectedSuccess: true,
expectedMigrationLog: []string{"1", "2", "3"},
expectedSchemaVersion: 2,
expectedVersion: 3,
},
{
name: "unordered migrations",
migrations: append([]MigrationConfig{
{
Name: "4",
Version: 4,
SchemaVersion: 3,
MigrationFn: func(*sqlc.Queries) error {
logMigration("4")
return nil
},
},
}, migrations...),
expectedSuccess: false,
expectedMigrationLog: nil,
expectedSchemaVersion: 0,
},
{
name: "failure of migration 4",
migrations: append(migrations, MigrationConfig{
Name: "4",
Version: 4,
SchemaVersion: 3,
MigrationFn: func(*sqlc.Queries) error {
return fmt.Errorf("migration 4 failed")
},
}),
expectedSuccess: false,
expectedMigrationLog: []string{"1", "2", "3"},
// Since schema migration is a separate step we expect
// that migrating up to 3 succeeded.
expectedSchemaVersion: 3,
// We still remain on version 3 though.
expectedVersion: 3,
},
{
name: "success of migration 4",
migrations: append(migrations, MigrationConfig{
Name: "4",
Version: 4,
SchemaVersion: 3,
MigrationFn: func(*sqlc.Queries) error {
logMigration("4")
return nil
},
}),
expectedSuccess: true,
expectedMigrationLog: []string{"1", "2", "3", "4"},
expectedSchemaVersion: 3,
expectedVersion: 4,
},
}
ctxb := context.Background()
for _, test := range tests {
// checkSchemaVersion checks the database schema version against
// the expected version.
getSchemaVersion := func(t *testing.T,
driver database.Driver, dbName string) int {
sqlMigrate, err := migrate.NewWithInstance(
"migrations", nil, dbName, driver,
)
require.NoError(t, err)
version, _, err := sqlMigrate.Version()
if err != migrate.ErrNilVersion {
require.NoError(t, err)
}
return int(version)
}
t.Run("SQLite "+test.name, func(t *testing.T) {
customMigrationLog = nil
// First instantiate the database and run the migrations
// including the custom migrations.
t.Logf("Creating new SQLite DB for testing migrations")
dbFileName := filepath.Join(t.TempDir(), "tmp.db")
var (
db *SqliteStore
err error
)
// Run the migration 3 times to test that the migrations
// are idempotent.
for i := 0; i < 3; i++ {
db, err = NewSqliteStore(&SqliteConfig{
SkipMigrations: false,
}, dbFileName)
require.NoError(t, err)
dbToCleanup := db.DB
t.Cleanup(func() {
require.NoError(
t, dbToCleanup.Close(),
)
})
err = db.ApplyAllMigrations(
ctxb, test.migrations,
)
if test.expectedSuccess {
require.NoError(t, err)
} else {
require.Error(t, err)
// Also repoen the DB without migrations
// so we can read versions.
db, err = NewSqliteStore(&SqliteConfig{
SkipMigrations: true,
}, dbFileName)
require.NoError(t, err)
}
require.Equal(t,
test.expectedMigrationLog,
customMigrationLog,
)
// Create the migration executor to be able to
// query the current schema version.
driver, err := sqlite_migrate.WithInstance(
db.DB, &sqlite_migrate.Config{},
)
require.NoError(t, err)
require.Equal(
t, test.expectedSchemaVersion,
getSchemaVersion(t, driver, ""),
)
// Check the migraton version in the database.
version, err := db.GetDatabaseVersion(ctxb)
if test.expectedSchemaVersion != 0 {
require.NoError(t, err)
} else {
require.Equal(t, sql.ErrNoRows, err)
}
require.Equal(
t, test.expectedVersion, int(version),
)
}
})
t.Run("Postgres "+test.name, func(t *testing.T) {
customMigrationLog = nil
// First create a temporary Postgres database to run
// the migrations on.
fixture := NewTestPgFixture(
t, DefaultPostgresFixtureLifetime,
)
t.Cleanup(func() {
fixture.TearDown(t)
})
dbName := randomDBName(t)
// Next instantiate the database and run the migrations
// including the custom migrations.
t.Logf("Creating new Postgres DB '%s' for testing "+
"migrations", dbName)
_, err := fixture.db.ExecContext(
context.Background(), "CREATE DATABASE "+dbName,
)
require.NoError(t, err)
cfg := fixture.GetConfig(dbName)
var db *PostgresStore
// Run the migration 3 times to test that the migrations
// are idempotent.
for i := 0; i < 3; i++ {
cfg.SkipMigrations = false
db, err = NewPostgresStore(cfg)
require.NoError(t, err)
err = db.ApplyAllMigrations(
ctxb, test.migrations,
)
if test.expectedSuccess {
require.NoError(t, err)
} else {
require.Error(t, err)
// Also repoen the DB without migrations
// so we can read versions.
cfg.SkipMigrations = true
db, err = NewPostgresStore(cfg)
require.NoError(t, err)
}
require.Equal(t,
test.expectedMigrationLog,
customMigrationLog,
)
// Create the migration executor to be able to
// query the current version.
driver, err := pgx_migrate.WithInstance(
db.DB, &pgx_migrate.Config{},
)
require.NoError(t, err)
require.Equal(
t, test.expectedSchemaVersion,
getSchemaVersion(t, driver, ""),
)
// Check the migraton version in the database.
version, err := db.GetDatabaseVersion(ctxb)
if test.expectedSchemaVersion != 0 {
require.NoError(t, err)
} else {
require.Equal(t, sql.ErrNoRows, err)
}
require.Equal(
t, test.expectedVersion, int(version),
)
}
})
}
}

View File

@ -2,7 +2,15 @@
package sqldb
import "fmt"
import (
"context"
"fmt"
)
var (
// Make sure SqliteStore implements the DB interface.
_ DB = (*SqliteStore)(nil)
)
// SqliteStore is a database store implementation that uses a sqlite backend.
type SqliteStore struct {
@ -16,3 +24,17 @@ type SqliteStore struct {
func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
return nil, fmt.Errorf("SQLite backend not supported in WebAssembly")
}
// GetBaseDB returns the underlying BaseDB instance for the SQLite store.
// It is a trivial helper method to comply with the sqldb.DB interface.
func (s *SqliteStore) GetBaseDB() *BaseDB {
return s.BaseDB
}
// ApplyAllMigrations applies both the SQLC and custom in-code migrations to
// the SQLite database.
func (s *SqliteStore) ApplyAllMigrations(context.Context,
[]MigrationConfig) error {
return fmt.Errorf("SQLite backend not supported in WebAssembly")
}

View File

@ -1,6 +1,7 @@
package sqldb
import (
"context"
"database/sql"
"fmt"
"net/url"
@ -28,10 +29,15 @@ var (
// has some differences.
postgresSchemaReplacements = map[string]string{
"BLOB": "BYTEA",
"INTEGER PRIMARY KEY": "SERIAL PRIMARY KEY",
"BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
"INTEGER PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
}
// Make sure PostgresStore implements the MigrationExecutor interface.
_ MigrationExecutor = (*PostgresStore)(nil)
// Make sure PostgresStore implements the DB interface.
_ DB = (*PostgresStore)(nil)
)
// replacePasswordInDSN takes a DSN string and returns it with the password
@ -92,40 +98,64 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) {
}
log.Infof("Using SQL database '%s'", sanitizedDSN)
rawDB, err := sql.Open("pgx", cfg.Dsn)
db, err := sql.Open("pgx", cfg.Dsn)
if err != nil {
return nil, err
}
// Ensure the migration tracker table exists before running migrations.
// This table tracks migration progress and ensures compatibility with
// SQLC query generation. If the table is already created by an SQLC
// migration, this operation becomes a no-op.
migrationTrackerSQL := `
CREATE TABLE IF NOT EXISTS migration_tracker (
version INTEGER UNIQUE NOT NULL,
migration_time TIMESTAMP NOT NULL
);`
_, err = db.Exec(migrationTrackerSQL)
if err != nil {
return nil, fmt.Errorf("error creating migration tracker: %w",
err)
}
maxConns := defaultMaxConns
if cfg.MaxConnections > 0 {
maxConns = cfg.MaxConnections
}
rawDB.SetMaxOpenConns(maxConns)
rawDB.SetMaxIdleConns(maxConns)
rawDB.SetConnMaxLifetime(connIdleLifetime)
db.SetMaxOpenConns(maxConns)
db.SetMaxIdleConns(maxConns)
db.SetConnMaxLifetime(connIdleLifetime)
queries := sqlc.New(rawDB)
queries := sqlc.New(db)
s := &PostgresStore{
return &PostgresStore{
cfg: cfg,
BaseDB: &BaseDB{
DB: rawDB,
DB: db,
Queries: queries,
},
}
}, nil
}
// GetBaseDB returns the underlying BaseDB instance for the Postgres store.
// It is a trivial helper method to comply with the sqldb.DB interface.
func (s *PostgresStore) GetBaseDB() *BaseDB {
return s.BaseDB
}
// ApplyAllMigrations applies both the SQLC and custom in-code migrations to the
// Postgres database.
func (s *PostgresStore) ApplyAllMigrations(ctx context.Context,
migrations []MigrationConfig) error {
// Execute migrations unless configured to skip them.
if !cfg.SkipMigrations {
err := s.ExecuteMigrations(TargetLatest)
if err != nil {
return nil, fmt.Errorf("error executing migrations: %w",
err)
}
if s.cfg.SkipMigrations {
return nil
}
return s, nil
return ApplyMigrations(ctx, s.BaseDB, s, migrations)
}
// ExecuteMigrations runs migrations for the Postgres database, depending on the

View File

@ -59,7 +59,7 @@ func NewTestPgFixture(t *testing.T, expiry time.Duration) *TestPgFixture {
"postgres",
"-c", "log_statement=all",
"-c", "log_destination=stderr",
"-c", "max_connections=1000",
"-c", "max_connections=5000",
},
}, func(config *docker.HostConfig) {
// Set AutoRemove to true so that stopped container goes away
@ -151,6 +151,10 @@ func NewTestPostgresDB(t *testing.T, fixture *TestPgFixture) *PostgresStore {
store, err := NewPostgresStore(cfg)
require.NoError(t, err)
require.NoError(t, store.ApplyAllMigrations(
context.Background(), GetMigrations()),
)
return store
}

View File

@ -235,6 +235,35 @@ func (q *Queries) GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, err
return invoice_id, err
}
const insertAMPSubInvoice = `-- name: InsertAMPSubInvoice :exec
INSERT INTO amp_sub_invoices (
set_id, state, created_at, settled_at, settle_index, invoice_id
) VALUES (
$1, $2, $3, $4, $5, $6
)
`
type InsertAMPSubInvoiceParams struct {
SetID []byte
State int16
CreatedAt time.Time
SettledAt sql.NullTime
SettleIndex sql.NullInt64
InvoiceID int64
}
func (q *Queries) InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error {
_, err := q.db.ExecContext(ctx, insertAMPSubInvoice,
arg.SetID,
arg.State,
arg.CreatedAt,
arg.SettledAt,
arg.SettleIndex,
arg.InvoiceID,
)
return err
}
const insertAMPSubInvoiceHTLC = `-- name: InsertAMPSubInvoiceHTLC :exec
INSERT INTO amp_sub_invoice_htlcs (
invoice_id, set_id, htlc_id, root_share, child_index, hash, preimage

View File

@ -11,6 +11,15 @@ import (
"time"
)
const clearKVInvoiceHashIndex = `-- name: ClearKVInvoiceHashIndex :exec
DELETE FROM invoice_payment_hashes
`
func (q *Queries) ClearKVInvoiceHashIndex(ctx context.Context) error {
_, err := q.db.ExecContext(ctx, clearKVInvoiceHashIndex)
return err
}
const deleteCanceledInvoices = `-- name: DeleteCanceledInvoices :execresult
DELETE
FROM invoices
@ -182,11 +191,8 @@ WHERE (
i.hash = $3 OR
$3 IS NULL
) AND (
i.preimage = $4 OR
i.payment_addr = $4 OR
$4 IS NULL
) AND (
i.payment_addr = $5 OR
$5 IS NULL
)
GROUP BY i.id
LIMIT 2
@ -196,7 +202,6 @@ type GetInvoiceParams struct {
SetID []byte
AddIndex sql.NullInt64
Hash []byte
Preimage []byte
PaymentAddr []byte
}
@ -208,7 +213,6 @@ func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoi
arg.SetID,
arg.AddIndex,
arg.Hash,
arg.Preimage,
arg.PaymentAddr,
)
if err != nil {
@ -251,6 +255,38 @@ func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoi
return items, nil
}
const getInvoiceByHash = `-- name: GetInvoiceByHash :one
SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at
FROM invoices i
WHERE i.hash = $1
`
func (q *Queries) GetInvoiceByHash(ctx context.Context, hash []byte) (Invoice, error) {
row := q.db.QueryRowContext(ctx, getInvoiceByHash, hash)
var i Invoice
err := row.Scan(
&i.ID,
&i.Hash,
&i.Preimage,
&i.SettleIndex,
&i.SettledAt,
&i.Memo,
&i.AmountMsat,
&i.CltvDelta,
&i.Expiry,
&i.PaymentAddr,
&i.PaymentRequest,
&i.PaymentRequestHash,
&i.State,
&i.AmountPaidMsat,
&i.IsAmp,
&i.IsHodl,
&i.IsKeysend,
&i.CreatedAt,
)
return i, err
}
const getInvoiceBySetID = `-- name: GetInvoiceBySetID :many
SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at
FROM invoices i
@ -405,6 +441,19 @@ func (q *Queries) GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]Invoi
return items, nil
}
const getKVInvoicePaymentHashByAddIndex = `-- name: GetKVInvoicePaymentHashByAddIndex :one
SELECT hash
FROM invoice_payment_hashes
WHERE add_index = $1
`
func (q *Queries) GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) ([]byte, error) {
row := q.db.QueryRowContext(ctx, getKVInvoicePaymentHashByAddIndex, addIndex)
var hash []byte
err := row.Scan(&hash)
return hash, err
}
const insertInvoice = `-- name: InsertInvoice :one
INSERT INTO invoices (
hash, preimage, memo, amount_msat, cltv_delta, expiry, payment_addr,
@ -533,6 +582,79 @@ func (q *Queries) InsertInvoiceHTLCCustomRecord(ctx context.Context, arg InsertI
return err
}
const insertKVInvoiceKeyAndAddIndex = `-- name: InsertKVInvoiceKeyAndAddIndex :exec
INSERT INTO invoice_payment_hashes (
id, add_index
) VALUES (
$1, $2
)
`
type InsertKVInvoiceKeyAndAddIndexParams struct {
ID int64
AddIndex int64
}
func (q *Queries) InsertKVInvoiceKeyAndAddIndex(ctx context.Context, arg InsertKVInvoiceKeyAndAddIndexParams) error {
_, err := q.db.ExecContext(ctx, insertKVInvoiceKeyAndAddIndex, arg.ID, arg.AddIndex)
return err
}
const insertMigratedInvoice = `-- name: InsertMigratedInvoice :one
INSERT INTO invoices (
hash, preimage, settle_index, settled_at, memo, amount_msat, cltv_delta,
expiry, payment_addr, payment_request, payment_request_hash, state,
amount_paid_msat, is_amp, is_hodl, is_keysend, created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17
) RETURNING id
`
type InsertMigratedInvoiceParams struct {
Hash []byte
Preimage []byte
SettleIndex sql.NullInt64
SettledAt sql.NullTime
Memo sql.NullString
AmountMsat int64
CltvDelta sql.NullInt32
Expiry int32
PaymentAddr []byte
PaymentRequest sql.NullString
PaymentRequestHash []byte
State int16
AmountPaidMsat int64
IsAmp bool
IsHodl bool
IsKeysend bool
CreatedAt time.Time
}
func (q *Queries) InsertMigratedInvoice(ctx context.Context, arg InsertMigratedInvoiceParams) (int64, error) {
row := q.db.QueryRowContext(ctx, insertMigratedInvoice,
arg.Hash,
arg.Preimage,
arg.SettleIndex,
arg.SettledAt,
arg.Memo,
arg.AmountMsat,
arg.CltvDelta,
arg.Expiry,
arg.PaymentAddr,
arg.PaymentRequest,
arg.PaymentRequestHash,
arg.State,
arg.AmountPaidMsat,
arg.IsAmp,
arg.IsHodl,
arg.IsKeysend,
arg.CreatedAt,
)
var id int64
err := row.Scan(&id)
return id, err
}
const nextInvoiceSettleIndex = `-- name: NextInvoiceSettleIndex :one
UPDATE invoice_sequences SET current_value = current_value + 1
WHERE name = 'settle_index'
@ -546,6 +668,22 @@ func (q *Queries) NextInvoiceSettleIndex(ctx context.Context) (int64, error) {
return current_value, err
}
const setKVInvoicePaymentHash = `-- name: SetKVInvoicePaymentHash :exec
UPDATE invoice_payment_hashes
SET hash = $2
WHERE id = $1
`
type SetKVInvoicePaymentHashParams struct {
ID int64
Hash []byte
}
func (q *Queries) SetKVInvoicePaymentHash(ctx context.Context, arg SetKVInvoicePaymentHashParams) error {
_, err := q.db.ExecContext(ctx, setKVInvoicePaymentHash, arg.ID, arg.Hash)
return err
}
const updateInvoiceAmountPaid = `-- name: UpdateInvoiceAmountPaid :execresult
UPDATE invoices
SET amount_paid_msat = $2

View File

@ -0,0 +1,60 @@
// Code generated by sqlc. DO NOT EDIT.
// versions:
// sqlc v1.25.0
// source: migration.sql
package sqlc
import (
"context"
"time"
)
const getDatabaseVersion = `-- name: GetDatabaseVersion :one
SELECT
version
FROM
migration_tracker
ORDER BY
version DESC
LIMIT 1
`
func (q *Queries) GetDatabaseVersion(ctx context.Context) (int32, error) {
row := q.db.QueryRowContext(ctx, getDatabaseVersion)
var version int32
err := row.Scan(&version)
return version, err
}
const getMigration = `-- name: GetMigration :one
SELECT
migration_time
FROM
migration_tracker
WHERE
version = $1
`
func (q *Queries) GetMigration(ctx context.Context, version int32) (time.Time, error) {
row := q.db.QueryRowContext(ctx, getMigration, version)
var migration_time time.Time
err := row.Scan(&migration_time)
return migration_time, err
}
const setMigration = `-- name: SetMigration :exec
INSERT INTO
migration_tracker (version, migration_time)
VALUES ($1, $2)
`
type SetMigrationParams struct {
Version int32
MigrationTime time.Time
}
func (q *Queries) SetMigration(ctx context.Context, arg SetMigrationParams) error {
_, err := q.db.ExecContext(ctx, setMigration, arg.Version, arg.MigrationTime)
return err
}

View File

@ -11,7 +11,7 @@ INSERT INTO invoice_sequences(name, current_value) VALUES ('settle_index', 0);
-- invoices table contains all the information shared by all the invoice types.
CREATE TABLE IF NOT EXISTS invoices (
-- The id of the invoice. Translates to the AddIndex.
id BIGINT PRIMARY KEY,
id INTEGER PRIMARY KEY,
-- The hash for this invoice. The invoice hash will always identify that
-- invoice.
@ -102,8 +102,8 @@ CREATE INDEX IF NOT EXISTS invoice_feature_invoice_id_idx ON invoice_features(in
CREATE TABLE IF NOT EXISTS invoice_htlcs (
-- The id for this htlc. Used in foreign keys instead of the
-- htlc_id/chan_id combination.
id BIGINT PRIMARY KEY,
id INTEGER PRIMARY KEY,
-- Short chan id indicating the htlc's origin. uint64 stored as text.
chan_id TEXT NOT NULL,
@ -111,7 +111,7 @@ CREATE TABLE IF NOT EXISTS invoice_htlcs (
-- int64 in the database. The application layer must check that there is no
-- overflow when storing/loading this column.
htlc_id BIGINT NOT NULL,
-- The htlc's amount in millisatoshis.
amount_msat BIGINT NOT NULL,

View File

@ -29,7 +29,7 @@ VALUES
-- AMP sub invoices. This table can be used to create a historical view of what
-- happened to the node's invoices.
CREATE TABLE IF NOT EXISTS invoice_events (
id BIGINT PRIMARY KEY,
id INTEGER PRIMARY KEY,
-- added_at is the timestamp when this event was added.
added_at TIMESTAMP NOT NULL,

View File

@ -0,0 +1 @@
DROP TABLE IF EXISTS migration_tracker;

View File

@ -0,0 +1,17 @@
-- The migration_tracker table keeps track of migrations that have been applied
-- to the database. This table ensures that migrations are idempotent and are
-- only run once. It tracks a global database version that encompasses both
-- schema migrations handled by golang-migrate and custom in-code migrations
-- for more complex data conversions that cannot be expressed in pure SQL.
CREATE TABLE IF NOT EXISTS migration_tracker (
-- version is the global version of the migration. Note that we
-- intentionally don't set it as PRIMARY KEY as it'd auto increment on
-- SQLite and our sqlc workflow will replace it with an auto
-- incrementing SERIAL on Postgres too. UNIQUE achieves the same effect
-- without the auto increment.
version INTEGER UNIQUE NOT NULL,
-- migration_time is the timestamp at which the migration was run.
migration_time TIMESTAMP NOT NULL
);

View File

@ -0,0 +1 @@
DROP TABLE IF EXISTS invoice_payment_hashes;

View File

@ -0,0 +1,17 @@
-- invoice_payment_hashes table contains the hash of the invoices. This table
-- is used during KV to SQL invoice migration as in our KV representation we
-- don't have a mapping from hash to add index.
CREATE TABLE IF NOT EXISTS invoice_payment_hashes (
-- id represents is the key of the invoice in the KV store.
id INTEGER PRIMARY KEY,
-- add_index is the KV add index of the invoice.
add_index BIGINT NOT NULL,
-- hash is the payment hash for this invoice.
hash BLOB
);
-- Create an indexes on the add_index and hash columns to speed up lookups.
CREATE INDEX IF NOT EXISTS invoice_payment_hashes_add_index_idx ON invoice_payment_hashes(add_index);
CREATE INDEX IF NOT EXISTS invoice_payment_hashes_hash_idx ON invoice_payment_hashes(hash);

View File

@ -58,7 +58,7 @@ type InvoiceEvent struct {
}
type InvoiceEventType struct {
ID int32
ID int64
Description string
}
@ -87,7 +87,18 @@ type InvoiceHtlcCustomRecord struct {
HtlcID int64
}
type InvoicePaymentHash struct {
ID int64
AddIndex int64
Hash []byte
}
type InvoiceSequence struct {
Name string
CurrentValue int64
}
type MigrationTracker struct {
Version int32
MigrationTime time.Time
}

View File

@ -7,9 +7,11 @@ package sqlc
import (
"context"
"database/sql"
"time"
)
type Querier interface {
ClearKVInvoiceHashIndex(ctx context.Context) error
DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
DeleteInvoice(ctx context.Context, arg DeleteInvoiceParams) (sql.Result, error)
FetchAMPSubInvoiceHTLCs(ctx context.Context, arg FetchAMPSubInvoiceHTLCsParams) ([]FetchAMPSubInvoiceHTLCsRow, error)
@ -17,19 +19,26 @@ type Querier interface {
FetchSettledAMPSubInvoices(ctx context.Context, arg FetchSettledAMPSubInvoicesParams) ([]FetchSettledAMPSubInvoicesRow, error)
FilterInvoices(ctx context.Context, arg FilterInvoicesParams) ([]Invoice, error)
GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error)
GetDatabaseVersion(ctx context.Context) (int32, error)
// This method may return more than one invoice if filter using multiple fields
// from different invoices. It is the caller's responsibility to ensure that
// we bubble up an error in those cases.
GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error)
GetInvoiceByHash(ctx context.Context, hash []byte) (Invoice, error)
GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error)
GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]InvoiceFeature, error)
GetInvoiceHTLCCustomRecords(ctx context.Context, invoiceID int64) ([]GetInvoiceHTLCCustomRecordsRow, error)
GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]InvoiceHtlc, error)
GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) ([]byte, error)
GetMigration(ctx context.Context, version int32) (time.Time, error)
InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error
InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubInvoiceHTLCParams) error
InsertInvoice(ctx context.Context, arg InsertInvoiceParams) (int64, error)
InsertInvoiceFeature(ctx context.Context, arg InsertInvoiceFeatureParams) error
InsertInvoiceHTLC(ctx context.Context, arg InsertInvoiceHTLCParams) (int64, error)
InsertInvoiceHTLCCustomRecord(ctx context.Context, arg InsertInvoiceHTLCCustomRecordParams) error
InsertKVInvoiceKeyAndAddIndex(ctx context.Context, arg InsertKVInvoiceKeyAndAddIndexParams) error
InsertMigratedInvoice(ctx context.Context, arg InsertMigratedInvoiceParams) (int64, error)
NextInvoiceSettleIndex(ctx context.Context) (int64, error)
OnAMPSubInvoiceCanceled(ctx context.Context, arg OnAMPSubInvoiceCanceledParams) error
OnAMPSubInvoiceCreated(ctx context.Context, arg OnAMPSubInvoiceCreatedParams) error
@ -37,6 +46,8 @@ type Querier interface {
OnInvoiceCanceled(ctx context.Context, arg OnInvoiceCanceledParams) error
OnInvoiceCreated(ctx context.Context, arg OnInvoiceCreatedParams) error
OnInvoiceSettled(ctx context.Context, arg OnInvoiceSettledParams) error
SetKVInvoicePaymentHash(ctx context.Context, arg SetKVInvoicePaymentHashParams) error
SetMigration(ctx context.Context, arg SetMigrationParams) error
UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context, arg UpdateAMPSubInvoiceHTLCPreimageParams) (sql.Result, error)
UpdateAMPSubInvoiceState(ctx context.Context, arg UpdateAMPSubInvoiceStateParams) error
UpdateInvoiceAmountPaid(ctx context.Context, arg UpdateInvoiceAmountPaidParams) (sql.Result, error)

View File

@ -65,3 +65,11 @@ SET preimage = $5
WHERE a.invoice_id = $1 AND a.set_id = $2 AND a.htlc_id = (
SELECT id FROM invoice_htlcs AS i WHERE i.chan_id = $3 AND i.htlc_id = $4
);
-- name: InsertAMPSubInvoice :exec
INSERT INTO amp_sub_invoices (
set_id, state, created_at, settled_at, settle_index, invoice_id
) VALUES (
$1, $2, $3, $4, $5, $6
);

View File

@ -7,6 +7,16 @@ INSERT INTO invoices (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15
) RETURNING id;
-- name: InsertMigratedInvoice :one
INSERT INTO invoices (
hash, preimage, settle_index, settled_at, memo, amount_msat, cltv_delta,
expiry, payment_addr, payment_request, payment_request_hash, state,
amount_paid_msat, is_amp, is_hodl, is_keysend, created_at
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17
) RETURNING id;
-- name: InsertInvoiceFeature :exec
INSERT INTO invoice_features (
invoice_id, feature
@ -37,9 +47,6 @@ WHERE (
) AND (
i.hash = sqlc.narg('hash') OR
sqlc.narg('hash') IS NULL
) AND (
i.preimage = sqlc.narg('preimage') OR
sqlc.narg('preimage') IS NULL
) AND (
i.payment_addr = sqlc.narg('payment_addr') OR
sqlc.narg('payment_addr') IS NULL
@ -47,6 +54,11 @@ WHERE (
GROUP BY i.id
LIMIT 2;
-- name: GetInvoiceByHash :one
SELECT i.*
FROM invoices i
WHERE i.hash = $1;
-- name: GetInvoiceBySetID :many
SELECT i.*
FROM invoices i
@ -169,3 +181,23 @@ INSERT INTO invoice_htlc_custom_records (
SELECT ihcr.htlc_id, key, value
FROM invoice_htlcs ih JOIN invoice_htlc_custom_records ihcr ON ih.id=ihcr.htlc_id
WHERE ih.invoice_id = $1;
-- name: InsertKVInvoiceKeyAndAddIndex :exec
INSERT INTO invoice_payment_hashes (
id, add_index
) VALUES (
$1, $2
);
-- name: SetKVInvoicePaymentHash :exec
UPDATE invoice_payment_hashes
SET hash = $2
WHERE id = $1;
-- name: GetKVInvoicePaymentHashByAddIndex :one
SELECT hash
FROM invoice_payment_hashes
WHERE add_index = $1;
-- name: ClearKVInvoiceHashIndex :exec
DELETE FROM invoice_payment_hashes;

View File

@ -0,0 +1,21 @@
-- name: SetMigration :exec
INSERT INTO
migration_tracker (version, migration_time)
VALUES ($1, $2);
-- name: GetMigration :one
SELECT
migration_time
FROM
migration_tracker
WHERE
version = $1;
-- name: GetDatabaseVersion :one
SELECT
version
FROM
migration_tracker
ORDER BY
version DESC
LIMIT 1;

View File

@ -3,6 +3,7 @@
package sqldb
import (
"context"
"database/sql"
"fmt"
"net/url"
@ -27,13 +28,16 @@ const (
)
var (
// sqliteSchemaReplacements is a map of schema strings that need to be
// replaced for sqlite. This is needed because sqlite doesn't directly
// support the BIGINT type for primary keys, so we need to replace it
// with INTEGER.
sqliteSchemaReplacements = map[string]string{
"BIGINT PRIMARY KEY": "INTEGER PRIMARY KEY",
}
// sqliteSchemaReplacements maps schema strings to their SQLite
// compatible replacements. Currently, no replacements are needed as our
// SQL schema definition files are designed for SQLite compatibility.
sqliteSchemaReplacements = map[string]string{}
// Make sure SqliteStore implements the MigrationExecutor interface.
_ MigrationExecutor = (*SqliteStore)(nil)
// Make sure SqliteStore implements the DB interface.
_ DB = (*SqliteStore)(nil)
)
// SqliteStore is a database store implementation that uses a sqlite backend.
@ -102,6 +106,23 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
return nil, err
}
// Create the migration tracker table before starting migrations to
// ensure it can be used to track migration progress. Note that a
// corresponding SQLC migration also creates this table, making this
// operation a no-op in that context. Its purpose is to ensure
// compatibility with SQLC query generation.
migrationTrackerSQL := `
CREATE TABLE IF NOT EXISTS migration_tracker (
version INTEGER UNIQUE NOT NULL,
migration_time TIMESTAMP NOT NULL
);`
_, err = db.Exec(migrationTrackerSQL)
if err != nil {
return nil, fmt.Errorf("error creating migration tracker: %w",
err)
}
db.SetMaxOpenConns(defaultMaxConns)
db.SetMaxIdleConns(defaultMaxConns)
db.SetConnMaxLifetime(connIdleLifetime)
@ -115,16 +136,26 @@ func NewSqliteStore(cfg *SqliteConfig, dbPath string) (*SqliteStore, error) {
},
}
// Execute migrations unless configured to skip them.
if !cfg.SkipMigrations {
if err := s.ExecuteMigrations(TargetLatest); err != nil {
return nil, fmt.Errorf("error executing migrations: "+
"%w", err)
return s, nil
}
}
// GetBaseDB returns the underlying BaseDB instance for the SQLite store.
// It is a trivial helper method to comply with the sqldb.DB interface.
func (s *SqliteStore) GetBaseDB() *BaseDB {
return s.BaseDB
}
// ApplyAllMigrations applies both the SQLC and custom in-code migrations to the
// SQLite database.
func (s *SqliteStore) ApplyAllMigrations(ctx context.Context,
migrations []MigrationConfig) error {
// Execute migrations unless configured to skip them.
if s.cfg.SkipMigrations {
return nil
}
return s, nil
return ApplyMigrations(ctx, s.BaseDB, s, migrations)
}
// ExecuteMigrations runs migrations for the sqlite database, depending on the
@ -160,6 +191,10 @@ func NewTestSqliteDB(t *testing.T) *SqliteStore {
}, dbFileName)
require.NoError(t, err)
require.NoError(t, sqlDB.ApplyAllMigrations(
context.Background(), GetMigrations()),
)
t.Cleanup(func() {
require.NoError(t, sqlDB.DB.Close())
})