mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-06-01 18:50:09 +02:00
invoices: add migration code for a single invoice
This commit is contained in:
parent
43797d6be7
commit
708bed517d
@ -5,9 +5,13 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/lightningnetwork/lnd/graph/db/models"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/sqldb"
|
||||
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||
)
|
||||
|
||||
@ -126,3 +130,272 @@ func createInvoiceHashIndex(ctx context.Context, db kvdb.Backend,
|
||||
})
|
||||
}, func() {})
|
||||
}
|
||||
|
||||
// toInsertMigratedInvoiceParams creates the parameters for inserting a migrated
|
||||
// invoice into the SQL database. The parameters are derived from the original
|
||||
// invoice insert parameters.
|
||||
func toInsertMigratedInvoiceParams(
|
||||
params sqlc.InsertInvoiceParams) sqlc.InsertMigratedInvoiceParams {
|
||||
|
||||
return sqlc.InsertMigratedInvoiceParams{
|
||||
Hash: params.Hash,
|
||||
Preimage: params.Preimage,
|
||||
Memo: params.Memo,
|
||||
AmountMsat: params.AmountMsat,
|
||||
CltvDelta: params.CltvDelta,
|
||||
Expiry: params.Expiry,
|
||||
PaymentAddr: params.PaymentAddr,
|
||||
PaymentRequest: params.PaymentRequest,
|
||||
PaymentRequestHash: params.PaymentRequestHash,
|
||||
State: params.State,
|
||||
AmountPaidMsat: params.AmountPaidMsat,
|
||||
IsAmp: params.IsAmp,
|
||||
IsHodl: params.IsHodl,
|
||||
IsKeysend: params.IsKeysend,
|
||||
CreatedAt: params.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// MigrateSingleInvoice migrates a single invoice to the new SQL schema. Note
|
||||
// that perfect equality between the old and new schemas is not achievable, as
|
||||
// the invoice's add index cannot be mapped directly to its ID due to SQL’s
|
||||
// auto-incrementing primary key. The ID returned from the insert will instead
|
||||
// serve as the add index in the new schema.
|
||||
func MigrateSingleInvoice(ctx context.Context, tx SQLInvoiceQueries,
|
||||
invoice *Invoice, paymentHash lntypes.Hash) error {
|
||||
|
||||
insertInvoiceParams, err := makeInsertInvoiceParams(
|
||||
invoice, paymentHash,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Convert the insert invoice parameters to the migrated invoice insert
|
||||
// parameters.
|
||||
insertMigratedInvoiceParams := toInsertMigratedInvoiceParams(
|
||||
insertInvoiceParams,
|
||||
)
|
||||
|
||||
// If the invoice is settled, we'll also set the timestamp and the index
|
||||
// at which it was settled.
|
||||
if invoice.State == ContractSettled {
|
||||
if invoice.SettleIndex == 0 {
|
||||
return fmt.Errorf("settled invoice %s missing settle "+
|
||||
"index", paymentHash)
|
||||
}
|
||||
|
||||
if invoice.SettleDate.IsZero() {
|
||||
return fmt.Errorf("settled invoice %s missing settle "+
|
||||
"date", paymentHash)
|
||||
}
|
||||
|
||||
insertMigratedInvoiceParams.SettleIndex = sqldb.SQLInt64(
|
||||
invoice.SettleIndex,
|
||||
)
|
||||
insertMigratedInvoiceParams.SettledAt = sqldb.SQLTime(
|
||||
invoice.SettleDate.UTC(),
|
||||
)
|
||||
}
|
||||
|
||||
// First we need to insert the invoice itself so we can use the "add
|
||||
// index" which in this case is the auto incrementing primary key that
|
||||
// is returned from the insert.
|
||||
invoiceID, err := tx.InsertMigratedInvoice(
|
||||
ctx, insertMigratedInvoiceParams,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to insert invoice: %w", err)
|
||||
}
|
||||
|
||||
// Insert the invoice's features.
|
||||
for feature := range invoice.Terms.Features.Features() {
|
||||
params := sqlc.InsertInvoiceFeatureParams{
|
||||
InvoiceID: invoiceID,
|
||||
Feature: int32(feature),
|
||||
}
|
||||
|
||||
err := tx.InsertInvoiceFeature(ctx, params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to insert invoice "+
|
||||
"feature(%v): %w", feature, err)
|
||||
}
|
||||
}
|
||||
|
||||
sqlHtlcIDs := make(map[models.CircuitKey]int64)
|
||||
|
||||
// Now insert the HTLCs of the invoice. We'll also keep track of the SQL
|
||||
// ID of each HTLC so we can use it when inserting the AMP sub invoices.
|
||||
for circuitKey, htlc := range invoice.Htlcs {
|
||||
htlcParams := sqlc.InsertInvoiceHTLCParams{
|
||||
HtlcID: int64(circuitKey.HtlcID),
|
||||
ChanID: strconv.FormatUint(
|
||||
circuitKey.ChanID.ToUint64(), 10,
|
||||
),
|
||||
AmountMsat: int64(htlc.Amt),
|
||||
AcceptHeight: int32(htlc.AcceptHeight),
|
||||
AcceptTime: htlc.AcceptTime.UTC(),
|
||||
ExpiryHeight: int32(htlc.Expiry),
|
||||
State: int16(htlc.State),
|
||||
InvoiceID: invoiceID,
|
||||
}
|
||||
|
||||
// Leave the MPP amount as NULL if the MPP total amount is zero.
|
||||
if htlc.MppTotalAmt != 0 {
|
||||
htlcParams.TotalMppMsat = sqldb.SQLInt64(
|
||||
int64(htlc.MppTotalAmt),
|
||||
)
|
||||
}
|
||||
|
||||
// Leave the resolve time as NULL if the HTLC is not resolved.
|
||||
if !htlc.ResolveTime.IsZero() {
|
||||
htlcParams.ResolveTime = sqldb.SQLTime(
|
||||
htlc.ResolveTime.UTC(),
|
||||
)
|
||||
}
|
||||
|
||||
sqlID, err := tx.InsertInvoiceHTLC(ctx, htlcParams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to insert invoice htlc: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
sqlHtlcIDs[circuitKey] = sqlID
|
||||
|
||||
// Store custom records.
|
||||
for key, value := range htlc.CustomRecords {
|
||||
err = tx.InsertInvoiceHTLCCustomRecord(
|
||||
ctx, sqlc.InsertInvoiceHTLCCustomRecordParams{
|
||||
Key: int64(key),
|
||||
Value: value,
|
||||
HtlcID: sqlID,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !invoice.IsAMP() {
|
||||
return nil
|
||||
}
|
||||
|
||||
for setID, ampState := range invoice.AMPState {
|
||||
// Find the earliest HTLC of the AMP invoice, which will
|
||||
// be used as the creation date of this sub invoice.
|
||||
var createdAt time.Time
|
||||
for circuitKey := range ampState.InvoiceKeys {
|
||||
htlc := invoice.Htlcs[circuitKey]
|
||||
if createdAt.IsZero() {
|
||||
createdAt = htlc.AcceptTime.UTC()
|
||||
continue
|
||||
}
|
||||
|
||||
if createdAt.After(htlc.AcceptTime) {
|
||||
createdAt = htlc.AcceptTime.UTC()
|
||||
}
|
||||
}
|
||||
|
||||
params := sqlc.InsertAMPSubInvoiceParams{
|
||||
SetID: setID[:],
|
||||
State: int16(ampState.State),
|
||||
CreatedAt: createdAt,
|
||||
InvoiceID: invoiceID,
|
||||
}
|
||||
|
||||
if ampState.SettleIndex != 0 {
|
||||
if ampState.SettleDate.IsZero() {
|
||||
return fmt.Errorf("settled AMP sub invoice %x "+
|
||||
"missing settle date", setID)
|
||||
}
|
||||
|
||||
params.SettledAt = sqldb.SQLTime(
|
||||
ampState.SettleDate.UTC(),
|
||||
)
|
||||
|
||||
params.SettleIndex = sqldb.SQLInt64(
|
||||
ampState.SettleIndex,
|
||||
)
|
||||
}
|
||||
|
||||
err := tx.InsertAMPSubInvoice(ctx, params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to insert AMP sub invoice: "+
|
||||
"%w", err)
|
||||
}
|
||||
|
||||
// Now we can add the AMP HTLCs to the database.
|
||||
for circuitKey := range ampState.InvoiceKeys {
|
||||
htlc := invoice.Htlcs[circuitKey]
|
||||
rootShare := htlc.AMP.Record.RootShare()
|
||||
|
||||
sqlHtlcID, ok := sqlHtlcIDs[circuitKey]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing htlc for AMP htlc: "+
|
||||
"%v", circuitKey)
|
||||
}
|
||||
|
||||
params := sqlc.InsertAMPSubInvoiceHTLCParams{
|
||||
InvoiceID: invoiceID,
|
||||
SetID: setID[:],
|
||||
HtlcID: sqlHtlcID,
|
||||
RootShare: rootShare[:],
|
||||
ChildIndex: int64(htlc.AMP.Record.ChildIndex()),
|
||||
Hash: htlc.AMP.Hash[:],
|
||||
}
|
||||
|
||||
if htlc.AMP.Preimage != nil {
|
||||
params.Preimage = htlc.AMP.Preimage[:]
|
||||
}
|
||||
|
||||
err = tx.InsertAMPSubInvoiceHTLC(ctx, params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to insert AMP sub "+
|
||||
"invoice: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OverrideInvoiceTimeZone overrides the time zone of the invoice to the local
|
||||
// time zone and chops off the nanosecond part for comparison. This is needed
|
||||
// because KV database stores times as-is which as an unwanted side effect would
|
||||
// fail migration due to time comparison expecting both the original and
|
||||
// migrated invoices to be in the same local time zone and in microsecond
|
||||
// precision. Note that PostgreSQL stores times in microsecond precision while
|
||||
// SQLite can store times in nanosecond precision if using TEXT storage class.
|
||||
func OverrideInvoiceTimeZone(invoice *Invoice) {
|
||||
fixTime := func(t time.Time) time.Time {
|
||||
return t.In(time.Local).Truncate(time.Microsecond)
|
||||
}
|
||||
|
||||
invoice.CreationDate = fixTime(invoice.CreationDate)
|
||||
|
||||
if !invoice.SettleDate.IsZero() {
|
||||
invoice.SettleDate = fixTime(invoice.SettleDate)
|
||||
}
|
||||
|
||||
if invoice.IsAMP() {
|
||||
for setID, ampState := range invoice.AMPState {
|
||||
if ampState.SettleDate.IsZero() {
|
||||
continue
|
||||
}
|
||||
|
||||
ampState.SettleDate = fixTime(ampState.SettleDate)
|
||||
invoice.AMPState[setID] = ampState
|
||||
}
|
||||
}
|
||||
|
||||
for _, htlc := range invoice.Htlcs {
|
||||
if !htlc.AcceptTime.IsZero() {
|
||||
htlc.AcceptTime = fixTime(htlc.AcceptTime)
|
||||
}
|
||||
|
||||
if !htlc.ResolveTime.IsZero() {
|
||||
htlc.ResolveTime = fixTime(htlc.ResolveTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
421
invoices/sql_migration_test.go
Normal file
421
invoices/sql_migration_test.go
Normal file
@ -0,0 +1,421 @@
|
||||
package invoices
|
||||
|
||||
import (
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"database/sql"
|
||||
"math/rand"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lightningnetwork/lnd/clock"
|
||||
"github.com/lightningnetwork/lnd/graph/db/models"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/sqldb"
|
||||
"github.com/stretchr/testify/require"
|
||||
"pgregory.net/rapid"
|
||||
)
|
||||
|
||||
var (
|
||||
// testHtlcIDSequence is a global counter for generating unique HTLC
|
||||
// IDs.
|
||||
testHtlcIDSequence uint64
|
||||
)
|
||||
|
||||
// randomString generates a random string of a given length using rapid.
|
||||
func randomStringRapid(t *rapid.T, length int) string {
|
||||
// Define the character set for the string.
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" //nolint:ll
|
||||
|
||||
// Generate a string by selecting random characters from the charset.
|
||||
runes := make([]rune, length)
|
||||
for i := range runes {
|
||||
// Draw a random index and use it to select a character from the
|
||||
// charset.
|
||||
index := rapid.IntRange(0, len(charset)-1).Draw(t, "charIndex")
|
||||
runes[i] = rune(charset[index])
|
||||
}
|
||||
|
||||
return string(runes)
|
||||
}
|
||||
|
||||
// randTimeBetween generates a random time between min and max.
|
||||
func randTimeBetween(min, max time.Time) time.Time {
|
||||
var timeZones = []*time.Location{
|
||||
time.UTC,
|
||||
time.FixedZone("EST", -5*3600),
|
||||
time.FixedZone("MST", -7*3600),
|
||||
time.FixedZone("PST", -8*3600),
|
||||
time.FixedZone("CEST", 2*3600),
|
||||
}
|
||||
|
||||
// Ensure max is after min
|
||||
if max.Before(min) {
|
||||
min, max = max, min
|
||||
}
|
||||
|
||||
// Calculate the range in nanoseconds
|
||||
duration := max.Sub(min)
|
||||
randDuration := time.Duration(rand.Int63n(duration.Nanoseconds()))
|
||||
|
||||
// Generate the random time
|
||||
randomTime := min.Add(randDuration)
|
||||
|
||||
// Assign a random time zone
|
||||
randomTimeZone := timeZones[rand.Intn(len(timeZones))]
|
||||
|
||||
// Return the time in the random time zone
|
||||
return randomTime.In(randomTimeZone)
|
||||
}
|
||||
|
||||
// randTime generates a random time between 2009 and 2140.
|
||||
func randTime() time.Time {
|
||||
min := time.Date(2009, 1, 3, 0, 0, 0, 0, time.UTC)
|
||||
max := time.Date(2140, 1, 1, 0, 0, 0, 1000, time.UTC)
|
||||
|
||||
return randTimeBetween(min, max)
|
||||
}
|
||||
|
||||
func randInvoiceTime(invoice *Invoice) time.Time {
|
||||
return randTimeBetween(
|
||||
invoice.CreationDate,
|
||||
invoice.CreationDate.Add(invoice.Terms.Expiry),
|
||||
)
|
||||
}
|
||||
|
||||
// randHTLCRapid generates a random HTLC for an invoice using rapid to randomize
|
||||
// its parameters.
|
||||
func randHTLCRapid(t *rapid.T, invoice *Invoice, amt lnwire.MilliSatoshi) (
|
||||
models.CircuitKey, *InvoiceHTLC) {
|
||||
|
||||
htlc := &InvoiceHTLC{
|
||||
Amt: amt,
|
||||
AcceptHeight: rapid.Uint32Range(1, 999).Draw(t, "AcceptHeight"),
|
||||
AcceptTime: randInvoiceTime(invoice),
|
||||
Expiry: rapid.Uint32Range(1, 999).Draw(t, "Expiry"),
|
||||
}
|
||||
|
||||
// Set MPP total amount if MPP feature is enabled in the invoice.
|
||||
if invoice.Terms.Features.HasFeature(lnwire.MPPRequired) {
|
||||
htlc.MppTotalAmt = invoice.Terms.Value
|
||||
}
|
||||
|
||||
// Set the HTLC state and resolve time based on the invoice state.
|
||||
switch invoice.State {
|
||||
case ContractSettled:
|
||||
htlc.State = HtlcStateSettled
|
||||
htlc.ResolveTime = randInvoiceTime(invoice)
|
||||
|
||||
case ContractCanceled:
|
||||
htlc.State = HtlcStateCanceled
|
||||
htlc.ResolveTime = randInvoiceTime(invoice)
|
||||
|
||||
case ContractAccepted:
|
||||
htlc.State = HtlcStateAccepted
|
||||
}
|
||||
|
||||
// Add randomized custom records to the HTLC.
|
||||
htlc.CustomRecords = make(record.CustomSet)
|
||||
numRecords := rapid.IntRange(0, 5).Draw(t, "numRecords")
|
||||
for i := 0; i < numRecords; i++ {
|
||||
key := rapid.Uint64Range(
|
||||
record.CustomTypeStart, 1000+record.CustomTypeStart,
|
||||
).Draw(t, "customRecordKey")
|
||||
value := []byte(randomStringRapid(t, 10))
|
||||
htlc.CustomRecords[key] = value
|
||||
}
|
||||
|
||||
// Generate a unique HTLC ID and assign it to a channel ID.
|
||||
htlcID := atomic.AddUint64(&testHtlcIDSequence, 1)
|
||||
randChanID := lnwire.NewShortChanIDFromInt(htlcID % 5)
|
||||
|
||||
circuitKey := models.CircuitKey{
|
||||
ChanID: randChanID,
|
||||
HtlcID: htlcID,
|
||||
}
|
||||
|
||||
return circuitKey, htlc
|
||||
}
|
||||
|
||||
// generateInvoiceHTLCsRapid generates all HTLCs for an invoice, including AMP
|
||||
// HTLCs if applicable, using rapid for randomization of HTLC count and
|
||||
// distribution.
|
||||
func generateInvoiceHTLCsRapid(t *rapid.T, invoice *Invoice) {
|
||||
mpp := invoice.Terms.Features.HasFeature(lnwire.MPPRequired)
|
||||
|
||||
// Use rapid to determine the number of HTLCs based on invoice state and
|
||||
// MPP feature.
|
||||
numHTLCs := 1
|
||||
if invoice.State == ContractOpen {
|
||||
numHTLCs = 0
|
||||
} else if mpp {
|
||||
numHTLCs = rapid.IntRange(1, 10).Draw(t, "numHTLCs")
|
||||
}
|
||||
|
||||
total := invoice.Terms.Value
|
||||
|
||||
// Distribute the total amount across the HTLCs, adding any remainder to
|
||||
// the last HTLC.
|
||||
if numHTLCs > 0 {
|
||||
amt := total / lnwire.MilliSatoshi(numHTLCs)
|
||||
remainder := total - amt*lnwire.MilliSatoshi(numHTLCs)
|
||||
|
||||
for i := 0; i < numHTLCs; i++ {
|
||||
if i == numHTLCs-1 {
|
||||
// Add remainder to the last HTLC.
|
||||
amt += remainder
|
||||
}
|
||||
|
||||
// Generate an HTLC with a random circuit key and add it
|
||||
// to the invoice.
|
||||
circuitKey, htlc := randHTLCRapid(t, invoice, amt)
|
||||
invoice.Htlcs[circuitKey] = htlc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// generateAMPHtlcsRapid generates AMP HTLCs for an invoice using rapid to
|
||||
// randomize various parameters of the HTLCs in the AMP set.
|
||||
func generateAMPHtlcsRapid(t *rapid.T, invoice *Invoice) {
|
||||
// Randomly determine the number of AMP sets (1 to 5).
|
||||
numSetIDs := rapid.IntRange(1, 5).Draw(t, "numSetIDs")
|
||||
settledIdx := uint64(1)
|
||||
|
||||
for i := 0; i < numSetIDs; i++ {
|
||||
var setID SetID
|
||||
_, err := crand.Read(setID[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
// Determine the number of HTLCs in this set (1 to 5).
|
||||
numHTLCs := rapid.IntRange(1, 5).Draw(t, "numHTLCs")
|
||||
total := invoice.Terms.Value
|
||||
invoiceKeys := make(map[CircuitKey]struct{})
|
||||
|
||||
// Calculate the amount per HTLC and account for remainder in
|
||||
// the final HTLC.
|
||||
amt := total / lnwire.MilliSatoshi(numHTLCs)
|
||||
remainder := total - amt*lnwire.MilliSatoshi(numHTLCs)
|
||||
|
||||
var htlcState HtlcState
|
||||
for j := 0; j < numHTLCs; j++ {
|
||||
if j == numHTLCs-1 {
|
||||
amt += remainder
|
||||
}
|
||||
|
||||
// Generate HTLC with randomized parameters.
|
||||
circuitKey, htlc := randHTLCRapid(t, invoice, amt)
|
||||
htlcState = htlc.State
|
||||
|
||||
var (
|
||||
rootShare, hash [32]byte
|
||||
preimage lntypes.Preimage
|
||||
)
|
||||
|
||||
// Randomize AMP data fields.
|
||||
_, err := crand.Read(rootShare[:])
|
||||
require.NoError(t, err)
|
||||
_, err = crand.Read(hash[:])
|
||||
require.NoError(t, err)
|
||||
_, err = crand.Read(preimage[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
record := record.NewAMP(rootShare, setID, uint32(j))
|
||||
|
||||
htlc.AMP = &InvoiceHtlcAMPData{
|
||||
Record: *record,
|
||||
Hash: hash,
|
||||
Preimage: &preimage,
|
||||
}
|
||||
|
||||
invoice.Htlcs[circuitKey] = htlc
|
||||
invoiceKeys[circuitKey] = struct{}{}
|
||||
}
|
||||
|
||||
ampState := InvoiceStateAMP{
|
||||
State: htlcState,
|
||||
InvoiceKeys: invoiceKeys,
|
||||
}
|
||||
if htlcState == HtlcStateSettled {
|
||||
ampState.SettleIndex = settledIdx
|
||||
ampState.SettleDate = randInvoiceTime(invoice)
|
||||
settledIdx++
|
||||
}
|
||||
|
||||
// Set the total amount paid if the AMP set is not canceled.
|
||||
if htlcState != HtlcStateCanceled {
|
||||
ampState.AmtPaid = invoice.Terms.Value
|
||||
}
|
||||
|
||||
invoice.AMPState[setID] = ampState
|
||||
}
|
||||
}
|
||||
|
||||
// TestMigrateSingleInvoiceRapid tests the migration of single invoices with
|
||||
// random data variations using rapid. This test generates a random invoice
|
||||
// configuration and ensures successful migration.
|
||||
//
|
||||
// NOTE: This test may need to be changed if the Invoice or any of the related
|
||||
// types are modified.
|
||||
func TestMigrateSingleInvoiceRapid(t *testing.T) {
|
||||
// Create a shared Postgres instance for efficient testing.
|
||||
pgFixture := sqldb.NewTestPgFixture(
|
||||
t, sqldb.DefaultPostgresFixtureLifetime,
|
||||
)
|
||||
t.Cleanup(func() {
|
||||
pgFixture.TearDown(t)
|
||||
})
|
||||
|
||||
makeSQLDB := func(t *testing.T, sqlite bool) *SQLStore {
|
||||
var db *sqldb.BaseDB
|
||||
if sqlite {
|
||||
db = sqldb.NewTestSqliteDB(t).BaseDB
|
||||
} else {
|
||||
db = sqldb.NewTestPostgresDB(t, pgFixture).BaseDB
|
||||
}
|
||||
|
||||
executor := sqldb.NewTransactionExecutor(
|
||||
db, func(tx *sql.Tx) SQLInvoiceQueries {
|
||||
return db.WithTx(tx)
|
||||
},
|
||||
)
|
||||
|
||||
testClock := clock.NewTestClock(time.Unix(1, 0))
|
||||
|
||||
return NewSQLStore(executor, testClock)
|
||||
}
|
||||
|
||||
// Define property-based test using rapid.
|
||||
rapid.Check(t, func(rt *rapid.T) {
|
||||
// Randomized feature flags for MPP and AMP.
|
||||
mpp := rapid.Bool().Draw(rt, "mpp")
|
||||
amp := rapid.Bool().Draw(rt, "amp")
|
||||
|
||||
for _, sqlite := range []bool{true, false} {
|
||||
store := makeSQLDB(t, sqlite)
|
||||
testMigrateSingleInvoiceRapid(rt, store, mpp, amp)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// testMigrateSingleInvoiceRapid is the primary function for the migration of a
|
||||
// single invoice with random data in a rapid-based test setup.
|
||||
func testMigrateSingleInvoiceRapid(t *rapid.T, store *SQLStore, mpp bool,
|
||||
amp bool) {
|
||||
|
||||
ctxb := context.Background()
|
||||
invoices := make(map[lntypes.Hash]*Invoice)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
invoice := generateTestInvoiceRapid(t, mpp, amp)
|
||||
var hash lntypes.Hash
|
||||
_, err := crand.Read(hash[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
invoices[hash] = invoice
|
||||
}
|
||||
|
||||
var ops SQLInvoiceQueriesTxOptions
|
||||
err := store.db.ExecTx(ctxb, &ops, func(tx SQLInvoiceQueries) error {
|
||||
for hash, invoice := range invoices {
|
||||
err := MigrateSingleInvoice(ctxb, tx, invoice, hash)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, func() {})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Fetch and compare each migrated invoice from the store with the
|
||||
// original.
|
||||
for hash, invoice := range invoices {
|
||||
sqlInvoice, err := store.LookupInvoice(
|
||||
ctxb, InvoiceRefByHash(hash),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
invoice.AddIndex = sqlInvoice.AddIndex
|
||||
|
||||
OverrideInvoiceTimeZone(invoice)
|
||||
OverrideInvoiceTimeZone(&sqlInvoice)
|
||||
|
||||
require.Equal(t, *invoice, sqlInvoice)
|
||||
}
|
||||
}
|
||||
|
||||
// generateTestInvoiceRapid generates a random invoice with variations based on
|
||||
// mpp and amp flags.
|
||||
func generateTestInvoiceRapid(t *rapid.T, mpp bool, amp bool) *Invoice {
|
||||
var preimage lntypes.Preimage
|
||||
_, err := crand.Read(preimage[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
terms := ContractTerm{
|
||||
FinalCltvDelta: rapid.Int32Range(1, 1000).Draw(
|
||||
t, "FinalCltvDelta",
|
||||
),
|
||||
Expiry: time.Duration(
|
||||
rapid.IntRange(1, 4444).Draw(t, "Expiry"),
|
||||
) * time.Minute,
|
||||
PaymentPreimage: &preimage,
|
||||
Value: lnwire.MilliSatoshi(
|
||||
rapid.Int64Range(1, 9999999).Draw(t, "Value"),
|
||||
),
|
||||
PaymentAddr: [32]byte{},
|
||||
Features: lnwire.EmptyFeatureVector(),
|
||||
}
|
||||
|
||||
if amp {
|
||||
terms.Features.Set(lnwire.AMPRequired)
|
||||
} else if mpp {
|
||||
terms.Features.Set(lnwire.MPPRequired)
|
||||
}
|
||||
|
||||
created := randTime()
|
||||
|
||||
const maxContractState = 3
|
||||
state := ContractState(
|
||||
rapid.IntRange(0, maxContractState).Draw(t, "ContractState"),
|
||||
)
|
||||
var (
|
||||
settled time.Time
|
||||
settleIndex uint64
|
||||
)
|
||||
if state == ContractSettled {
|
||||
settled = randTimeBetween(created, created.Add(terms.Expiry))
|
||||
settleIndex = rapid.Uint64Range(1, 999).Draw(t, "SettleIndex")
|
||||
}
|
||||
|
||||
invoice := &Invoice{
|
||||
Memo: []byte(randomStringRapid(t, 10)),
|
||||
PaymentRequest: []byte(
|
||||
randomStringRapid(t, MaxPaymentRequestSize),
|
||||
),
|
||||
CreationDate: created,
|
||||
SettleDate: settled,
|
||||
Terms: terms,
|
||||
AddIndex: 0,
|
||||
SettleIndex: settleIndex,
|
||||
State: state,
|
||||
AMPState: make(map[SetID]InvoiceStateAMP),
|
||||
HodlInvoice: rapid.Bool().Draw(t, "HodlInvoice"),
|
||||
}
|
||||
|
||||
invoice.Htlcs = make(map[models.CircuitKey]*InvoiceHTLC)
|
||||
|
||||
if invoice.IsAMP() {
|
||||
generateAMPHtlcsRapid(t, invoice)
|
||||
} else {
|
||||
generateInvoiceHTLCsRapid(t, invoice)
|
||||
}
|
||||
|
||||
for _, htlc := range invoice.Htlcs {
|
||||
if htlc.State == HtlcStateSettled {
|
||||
invoice.AmtPaid += htlc.Amt
|
||||
}
|
||||
}
|
||||
|
||||
return invoice
|
||||
}
|
@ -32,6 +32,10 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
|
||||
InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
|
||||
error)
|
||||
|
||||
// TODO(bhandras): remove this once migrations have been separated out.
|
||||
InsertMigratedInvoice(ctx context.Context,
|
||||
arg sqlc.InsertMigratedInvoiceParams) (int64, error)
|
||||
|
||||
InsertInvoiceFeature(ctx context.Context,
|
||||
arg sqlc.InsertInvoiceFeatureParams) error
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user