mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-07-01 19:10:59 +02:00
sqldb+invoices: move SQL invoice store impl to invoices package
This commit is contained in:
@ -1053,12 +1053,12 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
|
|||||||
|
|
||||||
executor := sqldb.NewTransactionExecutor(
|
executor := sqldb.NewTransactionExecutor(
|
||||||
dbs.NativeSQLStore,
|
dbs.NativeSQLStore,
|
||||||
func(tx *sql.Tx) sqldb.InvoiceQueries {
|
func(tx *sql.Tx) invoices.SQLInvoiceQueries {
|
||||||
return dbs.NativeSQLStore.WithTx(tx)
|
return dbs.NativeSQLStore.WithTx(tx)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
dbs.InvoiceDB = sqldb.NewInvoiceStore(
|
dbs.InvoiceDB = invoices.NewSQLStore(
|
||||||
executor, clock.NewDefaultClock(),
|
executor, clock.NewDefaultClock(),
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
|
@ -136,14 +136,14 @@ func TestInvoiceRegistry(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
executor := sqldb.NewTransactionExecutor(
|
executor := sqldb.NewTransactionExecutor(
|
||||||
db, func(tx *sql.Tx) sqldb.InvoiceQueries {
|
db, func(tx *sql.Tx) invpkg.SQLInvoiceQueries {
|
||||||
return db.WithTx(tx)
|
return db.WithTx(tx)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
testClock := clock.NewTestClock(testNow)
|
testClock := clock.NewTestClock(testNow)
|
||||||
|
|
||||||
return sqldb.NewInvoiceStore(executor, testClock), testClock
|
return invpkg.NewSQLStore(executor, testClock), testClock
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range testList {
|
for _, test := range testList {
|
||||||
|
@ -240,14 +240,14 @@ func TestInvoices(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
executor := sqldb.NewTransactionExecutor(
|
executor := sqldb.NewTransactionExecutor(
|
||||||
db, func(tx *sql.Tx) sqldb.InvoiceQueries {
|
db, func(tx *sql.Tx) invpkg.SQLInvoiceQueries {
|
||||||
return db.WithTx(tx)
|
return db.WithTx(tx)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
testClock := clock.NewTestClock(testNow)
|
testClock := clock.NewTestClock(testNow)
|
||||||
|
|
||||||
return sqldb.NewInvoiceStore(executor, testClock)
|
return invpkg.NewSQLStore(executor, testClock)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range testList {
|
for _, test := range testList {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package sqldb
|
package invoices
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -12,10 +12,10 @@ import (
|
|||||||
|
|
||||||
"github.com/lightningnetwork/lnd/channeldb/models"
|
"github.com/lightningnetwork/lnd/channeldb/models"
|
||||||
"github.com/lightningnetwork/lnd/clock"
|
"github.com/lightningnetwork/lnd/clock"
|
||||||
invpkg "github.com/lightningnetwork/lnd/invoices"
|
|
||||||
"github.com/lightningnetwork/lnd/lntypes"
|
"github.com/lightningnetwork/lnd/lntypes"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/record"
|
"github.com/lightningnetwork/lnd/record"
|
||||||
|
"github.com/lightningnetwork/lnd/sqldb"
|
||||||
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
"github.com/lightningnetwork/lnd/sqldb/sqlc"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -25,9 +25,9 @@ const (
|
|||||||
queryPaginationLimit = 100
|
queryPaginationLimit = 100
|
||||||
)
|
)
|
||||||
|
|
||||||
// InvoiceQueries is an interface that defines the set of operations that can be
|
// SQLInvoiceQueries is an interface that defines the set of operations that can
|
||||||
// executed against the invoice database.
|
// be executed against the invoice SQL database.
|
||||||
type InvoiceQueries interface { //nolint:interfacebloat
|
type SQLInvoiceQueries interface { //nolint:interfacebloat
|
||||||
InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
|
InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
|
||||||
error)
|
error)
|
||||||
|
|
||||||
@ -117,11 +117,11 @@ type InvoiceQueries interface { //nolint:interfacebloat
|
|||||||
arg sqlc.OnAMPSubInvoiceSettledParams) error
|
arg sqlc.OnAMPSubInvoiceSettledParams) error
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ invpkg.InvoiceDB = (*InvoiceStore)(nil)
|
var _ InvoiceDB = (*SQLStore)(nil)
|
||||||
|
|
||||||
// InvoiceQueriesTxOptions defines the set of db txn options the InvoiceQueries
|
// SQLInvoiceQueriesTxOptions defines the set of db txn options the
|
||||||
// understands.
|
// SQLInvoiceQueries understands.
|
||||||
type InvoiceQueriesTxOptions struct {
|
type SQLInvoiceQueriesTxOptions struct {
|
||||||
// readOnly governs if a read only transaction is needed or not.
|
// readOnly governs if a read only transaction is needed or not.
|
||||||
readOnly bool
|
readOnly bool
|
||||||
}
|
}
|
||||||
@ -129,37 +129,37 @@ type InvoiceQueriesTxOptions struct {
|
|||||||
// ReadOnly returns true if the transaction should be read only.
|
// ReadOnly returns true if the transaction should be read only.
|
||||||
//
|
//
|
||||||
// NOTE: This implements the TxOptions.
|
// NOTE: This implements the TxOptions.
|
||||||
func (a *InvoiceQueriesTxOptions) ReadOnly() bool {
|
func (a *SQLInvoiceQueriesTxOptions) ReadOnly() bool {
|
||||||
return a.readOnly
|
return a.readOnly
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewInvoiceQueryReadTx creates a new read transaction option set.
|
// NewSQLInvoiceQueryReadTx creates a new read transaction option set.
|
||||||
func NewInvoiceQueryReadTx() InvoiceQueriesTxOptions {
|
func NewSQLInvoiceQueryReadTx() SQLInvoiceQueriesTxOptions {
|
||||||
return InvoiceQueriesTxOptions{
|
return SQLInvoiceQueriesTxOptions{
|
||||||
readOnly: true,
|
readOnly: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BatchedInvoiceQueries is a version of the InvoiceQueries that's capable of
|
// BatchedSQLInvoiceQueries is a version of the SQLInvoiceQueries that's capable
|
||||||
// batched database operations.
|
// of batched database operations.
|
||||||
type BatchedInvoiceQueries interface {
|
type BatchedSQLInvoiceQueries interface {
|
||||||
InvoiceQueries
|
SQLInvoiceQueries
|
||||||
|
|
||||||
BatchedTx[InvoiceQueries]
|
sqldb.BatchedTx[SQLInvoiceQueries]
|
||||||
}
|
}
|
||||||
|
|
||||||
// InvoiceStore represents a storage backend.
|
// SQLStore represents a storage backend.
|
||||||
type InvoiceStore struct {
|
type SQLStore struct {
|
||||||
db BatchedInvoiceQueries
|
db BatchedSQLInvoiceQueries
|
||||||
clock clock.Clock
|
clock clock.Clock
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewInvoiceStore creates a new InvoiceStore instance given a open
|
// NewSQLStore creates a new SQLStore instance given a open
|
||||||
// BatchedInvoiceQueries storage backend.
|
// BatchedSQLInvoiceQueries storage backend.
|
||||||
func NewInvoiceStore(db BatchedInvoiceQueries,
|
func NewSQLStore(db BatchedSQLInvoiceQueries,
|
||||||
clock clock.Clock) *InvoiceStore {
|
clock clock.Clock) *SQLStore {
|
||||||
|
|
||||||
return &InvoiceStore{
|
return &SQLStore{
|
||||||
db: db,
|
db: db,
|
||||||
clock: clock,
|
clock: clock,
|
||||||
}
|
}
|
||||||
@ -171,17 +171,17 @@ func NewInvoiceStore(db BatchedInvoiceQueries,
|
|||||||
// duplicate payment hashes.
|
// duplicate payment hashes.
|
||||||
//
|
//
|
||||||
// NOTE: A side effect of this function is that it sets AddIndex on newInvoice.
|
// NOTE: A side effect of this function is that it sets AddIndex on newInvoice.
|
||||||
func (i *InvoiceStore) AddInvoice(ctx context.Context,
|
func (i *SQLStore) AddInvoice(ctx context.Context,
|
||||||
newInvoice *invpkg.Invoice, paymentHash lntypes.Hash) (uint64, error) {
|
newInvoice *Invoice, paymentHash lntypes.Hash) (uint64, error) {
|
||||||
|
|
||||||
// Make sure this is a valid invoice before trying to store it in our
|
// Make sure this is a valid invoice before trying to store it in our
|
||||||
// DB.
|
// DB.
|
||||||
if err := invpkg.ValidateInvoice(newInvoice, paymentHash); err != nil {
|
if err := ValidateInvoice(newInvoice, paymentHash); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
writeTxOpts InvoiceQueriesTxOptions
|
writeTxOpts SQLInvoiceQueriesTxOptions
|
||||||
invoiceID int64
|
invoiceID int64
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -193,16 +193,18 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context,
|
|||||||
paymentRequestHash = h.Sum(nil)
|
paymentRequestHash = h.Sum(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := i.db.ExecTx(ctx, &writeTxOpts, func(db InvoiceQueries) error {
|
err := i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
|
||||||
params := sqlc.InsertInvoiceParams{
|
params := sqlc.InsertInvoiceParams{
|
||||||
Hash: paymentHash[:],
|
Hash: paymentHash[:],
|
||||||
Memo: SQLStr(string(newInvoice.Memo)),
|
Memo: sqldb.SQLStr(string(newInvoice.Memo)),
|
||||||
AmountMsat: int64(newInvoice.Terms.Value),
|
AmountMsat: int64(newInvoice.Terms.Value),
|
||||||
// Note: BOLT12 invoices don't have a final cltv delta.
|
// Note: BOLT12 invoices don't have a final cltv delta.
|
||||||
CltvDelta: SQLInt32(newInvoice.Terms.FinalCltvDelta),
|
CltvDelta: sqldb.SQLInt32(
|
||||||
Expiry: int32(newInvoice.Terms.Expiry),
|
newInvoice.Terms.FinalCltvDelta,
|
||||||
|
),
|
||||||
|
Expiry: int32(newInvoice.Terms.Expiry),
|
||||||
// Note: keysend invoices don't have a payment request.
|
// Note: keysend invoices don't have a payment request.
|
||||||
PaymentRequest: SQLStr(string(
|
PaymentRequest: sqldb.SQLStr(string(
|
||||||
newInvoice.PaymentRequest),
|
newInvoice.PaymentRequest),
|
||||||
),
|
),
|
||||||
PaymentRequestHash: paymentRequestHash,
|
PaymentRequestHash: paymentRequestHash,
|
||||||
@ -218,7 +220,7 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context,
|
|||||||
// HODL invoices.
|
// HODL invoices.
|
||||||
if newInvoice.Terms.PaymentPreimage != nil {
|
if newInvoice.Terms.PaymentPreimage != nil {
|
||||||
preimage := *newInvoice.Terms.PaymentPreimage
|
preimage := *newInvoice.Terms.PaymentPreimage
|
||||||
if preimage == invpkg.UnknownPreimage {
|
if preimage == UnknownPreimage {
|
||||||
return errors.New("cannot use all-zeroes " +
|
return errors.New("cannot use all-zeroes " +
|
||||||
"preimage")
|
"preimage")
|
||||||
}
|
}
|
||||||
@ -226,7 +228,7 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Some non MPP payments may have the default (invalid) value.
|
// Some non MPP payments may have the default (invalid) value.
|
||||||
if newInvoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
|
if newInvoice.Terms.PaymentAddr != BlankPayAddr {
|
||||||
params.PaymentAddr = newInvoice.Terms.PaymentAddr[:]
|
params.PaymentAddr = newInvoice.Terms.PaymentAddr[:]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -259,11 +261,11 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context,
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mappedSQLErr := MapSQLError(err)
|
mappedSQLErr := sqldb.MapSQLError(err)
|
||||||
var uniqueConstraintErr *ErrSQLUniqueConstraintViolation
|
var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation
|
||||||
if errors.As(mappedSQLErr, &uniqueConstraintErr) {
|
if errors.As(mappedSQLErr, &uniqueConstraintErr) {
|
||||||
// Add context to unique constraint errors.
|
// Add context to unique constraint errors.
|
||||||
return 0, invpkg.ErrDuplicateInvoice
|
return 0, ErrDuplicateInvoice
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0, fmt.Errorf("unable to add invoice(%v): %w",
|
return 0, fmt.Errorf("unable to add invoice(%v): %w",
|
||||||
@ -277,15 +279,15 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context,
|
|||||||
|
|
||||||
// fetchInvoice fetches the common invoice data and the AMP state for the
|
// fetchInvoice fetches the common invoice data and the AMP state for the
|
||||||
// invoice with the given reference.
|
// invoice with the given reference.
|
||||||
func (i *InvoiceStore) fetchInvoice(ctx context.Context,
|
func (i *SQLStore) fetchInvoice(ctx context.Context,
|
||||||
db InvoiceQueries, ref invpkg.InvoiceRef) (*invpkg.Invoice, error) {
|
db SQLInvoiceQueries, ref InvoiceRef) (*Invoice, error) {
|
||||||
|
|
||||||
if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
|
if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
|
||||||
return nil, invpkg.ErrInvoiceNotFound
|
return nil, ErrInvoiceNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
invoice *invpkg.Invoice
|
invoice *Invoice
|
||||||
params sqlc.GetInvoiceParams
|
params sqlc.GetInvoiceParams
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -300,7 +302,7 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context,
|
|||||||
// all. Only allow lookups for payment address if it is not a blank
|
// all. Only allow lookups for payment address if it is not a blank
|
||||||
// payment address, which is a special-cased value for legacy keysend
|
// payment address, which is a special-cased value for legacy keysend
|
||||||
// invoices.
|
// invoices.
|
||||||
if ref.PayAddr() != nil && *ref.PayAddr() != invpkg.BlankPayAddr {
|
if ref.PayAddr() != nil && *ref.PayAddr() != BlankPayAddr {
|
||||||
params.PaymentAddr = ref.PayAddr()[:]
|
params.PaymentAddr = ref.PayAddr()[:]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -313,7 +315,7 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context,
|
|||||||
rows, err := db.GetInvoice(ctx, params)
|
rows, err := db.GetInvoice(ctx, params)
|
||||||
switch {
|
switch {
|
||||||
case len(rows) == 0:
|
case len(rows) == 0:
|
||||||
return nil, invpkg.ErrInvoiceNotFound
|
return nil, ErrInvoiceNotFound
|
||||||
|
|
||||||
case len(rows) > 1:
|
case len(rows) > 1:
|
||||||
// In case the reference is ambiguous, meaning it matches more
|
// In case the reference is ambiguous, meaning it matches more
|
||||||
@ -333,12 +335,12 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context,
|
|||||||
// Now that we got the invoice itself, fetch the HTLCs as requested by
|
// Now that we got the invoice itself, fetch the HTLCs as requested by
|
||||||
// the modifier.
|
// the modifier.
|
||||||
switch ref.Modifier() {
|
switch ref.Modifier() {
|
||||||
case invpkg.DefaultModifier:
|
case DefaultModifier:
|
||||||
// By default we'll fetch all AMP HTLCs.
|
// By default we'll fetch all AMP HTLCs.
|
||||||
setID = nil
|
setID = nil
|
||||||
fetchAmpHtlcs = true
|
fetchAmpHtlcs = true
|
||||||
|
|
||||||
case invpkg.HtlcSetOnlyModifier:
|
case HtlcSetOnlyModifier:
|
||||||
// In this case we'll fetch all AMP HTLCs for the
|
// In this case we'll fetch all AMP HTLCs for the
|
||||||
// specified set id.
|
// specified set id.
|
||||||
if ref.SetID() == nil {
|
if ref.SetID() == nil {
|
||||||
@ -349,7 +351,7 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context,
|
|||||||
setID = ref.SetID()
|
setID = ref.SetID()
|
||||||
fetchAmpHtlcs = true
|
fetchAmpHtlcs = true
|
||||||
|
|
||||||
case invpkg.HtlcSetBlankModifier:
|
case HtlcSetBlankModifier:
|
||||||
// No need to fetch any HTLCs.
|
// No need to fetch any HTLCs.
|
||||||
setID = nil
|
setID = nil
|
||||||
fetchAmpHtlcs = false
|
fetchAmpHtlcs = false
|
||||||
@ -377,9 +379,9 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context,
|
|||||||
// well.
|
// well.
|
||||||
//
|
//
|
||||||
//nolint:funlen
|
//nolint:funlen
|
||||||
func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
|
func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
|
||||||
setID *[32]byte, fetchHtlcs bool) (invpkg.AMPInvoiceState,
|
setID *[32]byte, fetchHtlcs bool) (AMPInvoiceState,
|
||||||
invpkg.HTLCSet, error) {
|
HTLCSet, error) {
|
||||||
|
|
||||||
var paramSetID []byte
|
var paramSetID []byte
|
||||||
if setID != nil {
|
if setID != nil {
|
||||||
@ -398,7 +400,7 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ampState := make(map[invpkg.SetID]invpkg.InvoiceStateAMP)
|
ampState := make(map[SetID]InvoiceStateAMP)
|
||||||
for _, row := range ampInvoiceRows {
|
for _, row := range ampInvoiceRows {
|
||||||
var rowSetID [32]byte
|
var rowSetID [32]byte
|
||||||
|
|
||||||
@ -413,8 +415,8 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
|
|||||||
}
|
}
|
||||||
|
|
||||||
copy(rowSetID[:], row.SetID)
|
copy(rowSetID[:], row.SetID)
|
||||||
ampState[rowSetID] = invpkg.InvoiceStateAMP{
|
ampState[rowSetID] = InvoiceStateAMP{
|
||||||
State: invpkg.HtlcState(row.State),
|
State: HtlcState(row.State),
|
||||||
SettleIndex: uint64(row.SettleIndex.Int64),
|
SettleIndex: uint64(row.SettleIndex.Int64),
|
||||||
SettleDate: settleDate,
|
SettleDate: settleDate,
|
||||||
InvoiceKeys: make(map[models.CircuitKey]struct{}),
|
InvoiceKeys: make(map[models.CircuitKey]struct{}),
|
||||||
@ -457,7 +459,7 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ampHtlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
|
ampHtlcs := make(map[models.CircuitKey]*InvoiceHTLC)
|
||||||
for _, row := range ampHtlcRows {
|
for _, row := range ampHtlcRows {
|
||||||
uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
|
uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -473,17 +475,17 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
|
|||||||
|
|
||||||
htlcID := uint64(row.HtlcID)
|
htlcID := uint64(row.HtlcID)
|
||||||
|
|
||||||
circuitKey := invpkg.CircuitKey{
|
circuitKey := CircuitKey{
|
||||||
ChanID: chanID,
|
ChanID: chanID,
|
||||||
HtlcID: htlcID,
|
HtlcID: htlcID,
|
||||||
}
|
}
|
||||||
|
|
||||||
htlc := &invpkg.InvoiceHTLC{
|
htlc := &InvoiceHTLC{
|
||||||
Amt: lnwire.MilliSatoshi(row.AmountMsat),
|
Amt: lnwire.MilliSatoshi(row.AmountMsat),
|
||||||
AcceptHeight: uint32(row.AcceptHeight),
|
AcceptHeight: uint32(row.AcceptHeight),
|
||||||
AcceptTime: row.AcceptTime.Local(),
|
AcceptTime: row.AcceptTime.Local(),
|
||||||
Expiry: uint32(row.ExpiryHeight),
|
Expiry: uint32(row.ExpiryHeight),
|
||||||
State: invpkg.HtlcState(row.State),
|
State: HtlcState(row.State),
|
||||||
}
|
}
|
||||||
|
|
||||||
if row.TotalMppMsat.Valid {
|
if row.TotalMppMsat.Valid {
|
||||||
@ -522,7 +524,7 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
|
|||||||
rootShare, setID, uint32(row.ChildIndex),
|
rootShare, setID, uint32(row.ChildIndex),
|
||||||
)
|
)
|
||||||
|
|
||||||
htlc.AMP = &invpkg.InvoiceHtlcAMPData{
|
htlc.AMP = &InvoiceHtlcAMPData{
|
||||||
Record: *ampRecord,
|
Record: *ampRecord,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -564,7 +566,7 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
|
|||||||
|
|
||||||
invoiceKeys[key] = struct{}{}
|
invoiceKeys[key] = struct{}{}
|
||||||
|
|
||||||
if htlc.State != invpkg.HtlcStateCanceled { //nolint: lll
|
if htlc.State != HtlcStateCanceled { //nolint: lll
|
||||||
amtPaid += htlc.Amt
|
amtPaid += htlc.Amt
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -584,22 +586,22 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
|
|||||||
// ID for an AMP sub invoice. If the invoice is found, we'll return the complete
|
// ID for an AMP sub invoice. If the invoice is found, we'll return the complete
|
||||||
// invoice. If the invoice is not found, then we'll return an ErrInvoiceNotFound
|
// invoice. If the invoice is not found, then we'll return an ErrInvoiceNotFound
|
||||||
// error.
|
// error.
|
||||||
func (i *InvoiceStore) LookupInvoice(ctx context.Context,
|
func (i *SQLStore) LookupInvoice(ctx context.Context,
|
||||||
ref invpkg.InvoiceRef) (invpkg.Invoice, error) {
|
ref InvoiceRef) (Invoice, error) {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
invoice *invpkg.Invoice
|
invoice *Invoice
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
readTxOpt := NewInvoiceQueryReadTx()
|
readTxOpt := NewSQLInvoiceQueryReadTx()
|
||||||
txErr := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error {
|
txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
|
||||||
invoice, err = i.fetchInvoice(ctx, db, ref)
|
invoice, err = i.fetchInvoice(ctx, db, ref)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
if txErr != nil {
|
if txErr != nil {
|
||||||
return invpkg.Invoice{}, txErr
|
return Invoice{}, txErr
|
||||||
}
|
}
|
||||||
|
|
||||||
return *invoice, nil
|
return *invoice, nil
|
||||||
@ -608,14 +610,14 @@ func (i *InvoiceStore) LookupInvoice(ctx context.Context,
|
|||||||
// FetchPendingInvoices returns all the invoices that are currently in a
|
// FetchPendingInvoices returns all the invoices that are currently in a
|
||||||
// "pending" state. An invoice is pending if it has been created but not yet
|
// "pending" state. An invoice is pending if it has been created but not yet
|
||||||
// settled or canceled.
|
// settled or canceled.
|
||||||
func (i *InvoiceStore) FetchPendingInvoices(ctx context.Context) (
|
func (i *SQLStore) FetchPendingInvoices(ctx context.Context) (
|
||||||
map[lntypes.Hash]invpkg.Invoice, error) {
|
map[lntypes.Hash]Invoice, error) {
|
||||||
|
|
||||||
var invoices map[lntypes.Hash]invpkg.Invoice
|
var invoices map[lntypes.Hash]Invoice
|
||||||
|
|
||||||
readTxOpt := NewInvoiceQueryReadTx()
|
readTxOpt := NewSQLInvoiceQueryReadTx()
|
||||||
err := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error {
|
err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
|
||||||
invoices = make(map[lntypes.Hash]invpkg.Invoice)
|
invoices = make(map[lntypes.Hash]Invoice)
|
||||||
limit := queryPaginationLimit
|
limit := queryPaginationLimit
|
||||||
|
|
||||||
return queryWithLimit(func(offset int) (int, error) {
|
return queryWithLimit(func(offset int) (int, error) {
|
||||||
@ -661,24 +663,24 @@ func (i *InvoiceStore) FetchPendingInvoices(ctx context.Context) (
|
|||||||
//
|
//
|
||||||
// NOTE: The index starts from 1. As a result we enforce that specifying a value
|
// NOTE: The index starts from 1. As a result we enforce that specifying a value
|
||||||
// below the starting index value is a noop.
|
// below the starting index value is a noop.
|
||||||
func (i *InvoiceStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
|
func (i *SQLStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
|
||||||
[]invpkg.Invoice, error) {
|
[]Invoice, error) {
|
||||||
|
|
||||||
var invoices []invpkg.Invoice
|
var invoices []Invoice
|
||||||
|
|
||||||
if idx == 0 {
|
if idx == 0 {
|
||||||
return invoices, nil
|
return invoices, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
readTxOpt := NewInvoiceQueryReadTx()
|
readTxOpt := NewSQLInvoiceQueryReadTx()
|
||||||
err := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error {
|
err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
|
||||||
invoices = nil
|
invoices = nil
|
||||||
settleIdx := idx
|
settleIdx := idx
|
||||||
limit := queryPaginationLimit
|
limit := queryPaginationLimit
|
||||||
|
|
||||||
err := queryWithLimit(func(offset int) (int, error) {
|
err := queryWithLimit(func(offset int) (int, error) {
|
||||||
params := sqlc.FilterInvoicesParams{
|
params := sqlc.FilterInvoicesParams{
|
||||||
SettleIndexGet: SQLInt64(settleIdx + 1),
|
SettleIndexGet: sqldb.SQLInt64(settleIdx + 1),
|
||||||
NumLimit: int32(limit),
|
NumLimit: int32(limit),
|
||||||
NumOffset: int32(offset),
|
NumOffset: int32(offset),
|
||||||
}
|
}
|
||||||
@ -715,7 +717,7 @@ func (i *InvoiceStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
|
|||||||
// the provided index.
|
// the provided index.
|
||||||
ampInvoices, err := i.db.FetchSettledAMPSubInvoices(
|
ampInvoices, err := i.db.FetchSettledAMPSubInvoices(
|
||||||
ctx, sqlc.FetchSettledAMPSubInvoicesParams{
|
ctx, sqlc.FetchSettledAMPSubInvoicesParams{
|
||||||
SettleIndexGet: SQLInt64(idx + 1),
|
SettleIndexGet: sqldb.SQLInt64(idx + 1),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -775,24 +777,24 @@ func (i *InvoiceStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
|
|||||||
//
|
//
|
||||||
// NOTE: The index starts from 1. As a result we enforce that specifying a value
|
// NOTE: The index starts from 1. As a result we enforce that specifying a value
|
||||||
// below the starting index value is a noop.
|
// below the starting index value is a noop.
|
||||||
func (i *InvoiceStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
|
func (i *SQLStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
|
||||||
[]invpkg.Invoice, error) {
|
[]Invoice, error) {
|
||||||
|
|
||||||
var result []invpkg.Invoice
|
var result []Invoice
|
||||||
|
|
||||||
if idx == 0 {
|
if idx == 0 {
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
readTxOpt := NewInvoiceQueryReadTx()
|
readTxOpt := NewSQLInvoiceQueryReadTx()
|
||||||
err := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error {
|
err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
|
||||||
result = nil
|
result = nil
|
||||||
addIdx := idx
|
addIdx := idx
|
||||||
limit := queryPaginationLimit
|
limit := queryPaginationLimit
|
||||||
|
|
||||||
return queryWithLimit(func(offset int) (int, error) {
|
return queryWithLimit(func(offset int) (int, error) {
|
||||||
params := sqlc.FilterInvoicesParams{
|
params := sqlc.FilterInvoicesParams{
|
||||||
AddIndexGet: SQLInt64(addIdx + 1),
|
AddIndexGet: sqldb.SQLInt64(addIdx + 1),
|
||||||
NumLimit: int32(limit),
|
NumLimit: int32(limit),
|
||||||
NumOffset: int32(offset),
|
NumOffset: int32(offset),
|
||||||
}
|
}
|
||||||
@ -831,18 +833,18 @@ func (i *InvoiceStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
|
|||||||
|
|
||||||
// QueryInvoices allows a caller to query the invoice database for invoices
|
// QueryInvoices allows a caller to query the invoice database for invoices
|
||||||
// within the specified add index range.
|
// within the specified add index range.
|
||||||
func (i *InvoiceStore) QueryInvoices(ctx context.Context,
|
func (i *SQLStore) QueryInvoices(ctx context.Context,
|
||||||
q invpkg.InvoiceQuery) (invpkg.InvoiceSlice, error) {
|
q InvoiceQuery) (InvoiceSlice, error) {
|
||||||
|
|
||||||
var invoices []invpkg.Invoice
|
var invoices []Invoice
|
||||||
|
|
||||||
if q.NumMaxInvoices == 0 {
|
if q.NumMaxInvoices == 0 {
|
||||||
return invpkg.InvoiceSlice{}, fmt.Errorf("max invoices must " +
|
return InvoiceSlice{}, fmt.Errorf("max invoices must " +
|
||||||
"be non-zero")
|
"be non-zero")
|
||||||
}
|
}
|
||||||
|
|
||||||
readTxOpt := NewInvoiceQueryReadTx()
|
readTxOpt := NewSQLInvoiceQueryReadTx()
|
||||||
err := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error {
|
err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
|
||||||
invoices = nil
|
invoices = nil
|
||||||
limit := queryPaginationLimit
|
limit := queryPaginationLimit
|
||||||
|
|
||||||
@ -856,7 +858,7 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
|
|||||||
if !q.Reversed {
|
if !q.Reversed {
|
||||||
// The invoice with index offset id must not be
|
// The invoice with index offset id must not be
|
||||||
// included in the results.
|
// included in the results.
|
||||||
params.AddIndexGet = SQLInt64(
|
params.AddIndexGet = sqldb.SQLInt64(
|
||||||
q.IndexOffset + uint64(offset) + 1,
|
q.IndexOffset + uint64(offset) + 1,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -867,13 +869,13 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
|
|||||||
// If the index offset was not set, we want to
|
// If the index offset was not set, we want to
|
||||||
// fetch from the lastest invoice.
|
// fetch from the lastest invoice.
|
||||||
if idx == 0 {
|
if idx == 0 {
|
||||||
params.AddIndexLet = SQLInt64(
|
params.AddIndexLet = sqldb.SQLInt64(
|
||||||
int64(math.MaxInt64),
|
int64(math.MaxInt64),
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
// The invoice with index offset id must
|
// The invoice with index offset id must
|
||||||
// not be included in the results.
|
// not be included in the results.
|
||||||
params.AddIndexLet = SQLInt64(
|
params.AddIndexLet = sqldb.SQLInt64(
|
||||||
idx - int32(offset) - 1,
|
idx - int32(offset) - 1,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -882,13 +884,13 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if q.CreationDateStart != 0 {
|
if q.CreationDateStart != 0 {
|
||||||
params.CreatedAfter = SQLTime(
|
params.CreatedAfter = sqldb.SQLTime(
|
||||||
time.Unix(q.CreationDateStart, 0).UTC(),
|
time.Unix(q.CreationDateStart, 0).UTC(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if q.CreationDateEnd != 0 {
|
if q.CreationDateEnd != 0 {
|
||||||
params.CreatedBefore = SQLTime(
|
params.CreatedBefore = sqldb.SQLTime(
|
||||||
time.Unix(q.CreationDateEnd, 0).UTC(),
|
time.Unix(q.CreationDateEnd, 0).UTC(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@ -919,12 +921,12 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
|
|||||||
}, limit)
|
}, limit)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return invpkg.InvoiceSlice{}, fmt.Errorf("unable to query "+
|
return InvoiceSlice{}, fmt.Errorf("unable to query "+
|
||||||
"invoices: %w", err)
|
"invoices: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(invoices) == 0 {
|
if len(invoices) == 0 {
|
||||||
return invpkg.InvoiceSlice{
|
return InvoiceSlice{
|
||||||
InvoiceQuery: q,
|
InvoiceQuery: q,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@ -941,7 +943,7 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
res := invpkg.InvoiceSlice{
|
res := InvoiceSlice{
|
||||||
InvoiceQuery: q,
|
InvoiceQuery: q,
|
||||||
Invoices: invoices,
|
Invoices: invoices,
|
||||||
FirstIndexOffset: invoices[0].AddIndex,
|
FirstIndexOffset: invoices[0].AddIndex,
|
||||||
@ -954,15 +956,15 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
|
|||||||
// sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using
|
// sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using
|
||||||
// a SQL database as the backend.
|
// a SQL database as the backend.
|
||||||
type sqlInvoiceUpdater struct {
|
type sqlInvoiceUpdater struct {
|
||||||
db InvoiceQueries
|
db SQLInvoiceQueries
|
||||||
ctx context.Context //nolint:containedctx
|
ctx context.Context //nolint:containedctx
|
||||||
invoice *invpkg.Invoice
|
invoice *Invoice
|
||||||
updateTime time.Time
|
updateTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddHtlc adds a new htlc to the invoice.
|
// AddHtlc adds a new htlc to the invoice.
|
||||||
func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
|
func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
|
||||||
newHtlc *invpkg.InvoiceHTLC) error {
|
newHtlc *InvoiceHTLC) error {
|
||||||
|
|
||||||
htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC(
|
htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC(
|
||||||
s.ctx, sqlc.InsertInvoiceHTLCParams{
|
s.ctx, sqlc.InsertInvoiceHTLCParams{
|
||||||
@ -1012,10 +1014,10 @@ func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mappedSQLErr := MapSQLError(err)
|
mappedSQLErr := sqldb.MapSQLError(err)
|
||||||
var uniqueConstraintErr *ErrSQLUniqueConstraintViolation
|
var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation //nolint:lll
|
||||||
if errors.As(mappedSQLErr, &uniqueConstraintErr) {
|
if errors.As(mappedSQLErr, &uniqueConstraintErr) {
|
||||||
return invpkg.ErrDuplicateSetID{
|
return ErrDuplicateSetID{
|
||||||
SetID: setID,
|
SetID: setID,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1072,7 +1074,7 @@ func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
|
|||||||
|
|
||||||
// ResolveHtlc marks an htlc as resolved with the given state.
|
// ResolveHtlc marks an htlc as resolved with the given state.
|
||||||
func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
|
func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
|
||||||
state invpkg.HtlcState, resolveTime time.Time) error {
|
state HtlcState, resolveTime time.Time) error {
|
||||||
|
|
||||||
return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{
|
return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{
|
||||||
HtlcID: int64(circuitKey.HtlcID),
|
HtlcID: int64(circuitKey.HtlcID),
|
||||||
@ -1081,7 +1083,7 @@ func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
|
|||||||
),
|
),
|
||||||
InvoiceID: int64(s.invoice.AddIndex),
|
InvoiceID: int64(s.invoice.AddIndex),
|
||||||
State: int16(state),
|
State: int16(state),
|
||||||
ResolveTime: SQLTime(resolveTime.UTC()),
|
ResolveTime: sqldb.SQLTime(resolveTime.UTC()),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1107,7 +1109,7 @@ func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if rowsAffected == 0 {
|
if rowsAffected == 0 {
|
||||||
return invpkg.ErrInvoiceNotFound
|
return ErrInvoiceNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -1115,7 +1117,7 @@ func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
|
|||||||
|
|
||||||
// UpdateInvoiceState updates the invoice state to the new state.
|
// UpdateInvoiceState updates the invoice state to the new state.
|
||||||
func (s *sqlInvoiceUpdater) UpdateInvoiceState(
|
func (s *sqlInvoiceUpdater) UpdateInvoiceState(
|
||||||
newState invpkg.ContractState, preimage *lntypes.Preimage) error {
|
newState ContractState, preimage *lntypes.Preimage) error {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
settleIndex sql.NullInt64
|
settleIndex sql.NullInt64
|
||||||
@ -1123,16 +1125,16 @@ func (s *sqlInvoiceUpdater) UpdateInvoiceState(
|
|||||||
)
|
)
|
||||||
|
|
||||||
switch newState {
|
switch newState {
|
||||||
case invpkg.ContractSettled:
|
case ContractSettled:
|
||||||
nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
|
nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
settleIndex = SQLInt64(nextSettleIndex)
|
settleIndex = sqldb.SQLInt64(nextSettleIndex)
|
||||||
|
|
||||||
// If the invoice is settled, we'll also update the settle time.
|
// If the invoice is settled, we'll also update the settle time.
|
||||||
settledAt = SQLTime(s.updateTime.UTC())
|
settledAt = sqldb.SQLTime(s.updateTime.UTC())
|
||||||
|
|
||||||
err = s.db.OnInvoiceSettled(
|
err = s.db.OnInvoiceSettled(
|
||||||
s.ctx, sqlc.OnInvoiceSettledParams{
|
s.ctx, sqlc.OnInvoiceSettledParams{
|
||||||
@ -1144,7 +1146,7 @@ func (s *sqlInvoiceUpdater) UpdateInvoiceState(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
case invpkg.ContractCanceled:
|
case ContractCanceled:
|
||||||
err := s.db.OnInvoiceCanceled(
|
err := s.db.OnInvoiceCanceled(
|
||||||
s.ctx, sqlc.OnInvoiceCanceledParams{
|
s.ctx, sqlc.OnInvoiceCanceledParams{
|
||||||
AddedAt: s.updateTime.UTC(),
|
AddedAt: s.updateTime.UTC(),
|
||||||
@ -1177,7 +1179,7 @@ func (s *sqlInvoiceUpdater) UpdateInvoiceState(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if rowsAffected == 0 {
|
if rowsAffected == 0 {
|
||||||
return invpkg.ErrInvoiceNotFound
|
return ErrInvoiceNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
if settleIndex.Valid {
|
if settleIndex.Valid {
|
||||||
@ -1205,7 +1207,7 @@ func (s *sqlInvoiceUpdater) UpdateInvoiceAmtPaid(
|
|||||||
// UpdateAmpState updates the state of the AMP sub invoice identified by the
|
// UpdateAmpState updates the state of the AMP sub invoice identified by the
|
||||||
// setID.
|
// setID.
|
||||||
func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
|
func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
|
||||||
newState invpkg.InvoiceStateAMP, _ models.CircuitKey) error {
|
newState InvoiceStateAMP, _ models.CircuitKey) error {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
settleIndex sql.NullInt64
|
settleIndex sql.NullInt64
|
||||||
@ -1213,16 +1215,16 @@ func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
|
|||||||
)
|
)
|
||||||
|
|
||||||
switch newState.State {
|
switch newState.State {
|
||||||
case invpkg.HtlcStateSettled:
|
case HtlcStateSettled:
|
||||||
nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
|
nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
settleIndex = SQLInt64(nextSettleIndex)
|
settleIndex = sqldb.SQLInt64(nextSettleIndex)
|
||||||
|
|
||||||
// If the invoice is settled, we'll also update the settle time.
|
// If the invoice is settled, we'll also update the settle time.
|
||||||
settledAt = SQLTime(s.updateTime.UTC())
|
settledAt = sqldb.SQLTime(s.updateTime.UTC())
|
||||||
|
|
||||||
err = s.db.OnAMPSubInvoiceSettled(
|
err = s.db.OnAMPSubInvoiceSettled(
|
||||||
s.ctx, sqlc.OnAMPSubInvoiceSettledParams{
|
s.ctx, sqlc.OnAMPSubInvoiceSettledParams{
|
||||||
@ -1235,7 +1237,7 @@ func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
case invpkg.HtlcStateCanceled:
|
case HtlcStateCanceled:
|
||||||
err := s.db.OnAMPSubInvoiceCanceled(
|
err := s.db.OnAMPSubInvoiceCanceled(
|
||||||
s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{
|
s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{
|
||||||
AddedAt: s.updateTime.UTC(),
|
AddedAt: s.updateTime.UTC(),
|
||||||
@ -1266,7 +1268,7 @@ func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
|
|||||||
// Finalize finalizes the update before it is written to the database. Note that
|
// Finalize finalizes the update before it is written to the database. Note that
|
||||||
// we don't use this directly in the SQL implementation, so the function is just
|
// we don't use this directly in the SQL implementation, so the function is just
|
||||||
// a stub.
|
// a stub.
|
||||||
func (s *sqlInvoiceUpdater) Finalize(_ invpkg.UpdateType) error {
|
func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1277,14 +1279,14 @@ func (s *sqlInvoiceUpdater) Finalize(_ invpkg.UpdateType) error {
|
|||||||
// The update is performed inside the same database transaction that fetches the
|
// The update is performed inside the same database transaction that fetches the
|
||||||
// invoice and is therefore atomic. The fields to update are controlled by the
|
// invoice and is therefore atomic. The fields to update are controlled by the
|
||||||
// supplied callback.
|
// supplied callback.
|
||||||
func (i *InvoiceStore) UpdateInvoice(ctx context.Context, ref invpkg.InvoiceRef,
|
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
|
||||||
_ *invpkg.SetID, callback invpkg.InvoiceUpdateCallback) (
|
_ *SetID, callback InvoiceUpdateCallback) (
|
||||||
*invpkg.Invoice, error) {
|
*Invoice, error) {
|
||||||
|
|
||||||
var updatedInvoice *invpkg.Invoice
|
var updatedInvoice *Invoice
|
||||||
|
|
||||||
txOpt := InvoiceQueriesTxOptions{readOnly: false}
|
txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
|
||||||
txErr := i.db.ExecTx(ctx, &txOpt, func(db InvoiceQueries) error {
|
txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
|
||||||
invoice, err := i.fetchInvoice(ctx, db, ref)
|
invoice, err := i.fetchInvoice(ctx, db, ref)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -1299,7 +1301,7 @@ func (i *InvoiceStore) UpdateInvoice(ctx context.Context, ref invpkg.InvoiceRef,
|
|||||||
}
|
}
|
||||||
|
|
||||||
payHash := ref.PayHash()
|
payHash := ref.PayHash()
|
||||||
updatedInvoice, err = invpkg.UpdateInvoice(
|
updatedInvoice, err = UpdateInvoice(
|
||||||
payHash, invoice, updateTime, callback, updater,
|
payHash, invoice, updateTime, callback, updater,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1308,7 +1310,7 @@ func (i *InvoiceStore) UpdateInvoice(ctx context.Context, ref invpkg.InvoiceRef,
|
|||||||
if txErr != nil {
|
if txErr != nil {
|
||||||
// If the invoice is already settled, we'll return the
|
// If the invoice is already settled, we'll return the
|
||||||
// (unchanged) invoice and the ErrInvoiceAlreadySettled error.
|
// (unchanged) invoice and the ErrInvoiceAlreadySettled error.
|
||||||
if errors.Is(txErr, invpkg.ErrInvoiceAlreadySettled) {
|
if errors.Is(txErr, ErrInvoiceAlreadySettled) {
|
||||||
return updatedInvoice, txErr
|
return updatedInvoice, txErr
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1320,8 +1322,8 @@ func (i *InvoiceStore) UpdateInvoice(ctx context.Context, ref invpkg.InvoiceRef,
|
|||||||
|
|
||||||
// DeleteInvoice attempts to delete the passed invoices and all their related
|
// DeleteInvoice attempts to delete the passed invoices and all their related
|
||||||
// data from the database in one transaction.
|
// data from the database in one transaction.
|
||||||
func (i *InvoiceStore) DeleteInvoice(ctx context.Context,
|
func (i *SQLStore) DeleteInvoice(ctx context.Context,
|
||||||
invoicesToDelete []invpkg.InvoiceDeleteRef) error {
|
invoicesToDelete []InvoiceDeleteRef) error {
|
||||||
|
|
||||||
// All the InvoiceDeleteRef instances include the add index of the
|
// All the InvoiceDeleteRef instances include the add index of the
|
||||||
// invoice. The rest was added to ensure that the invoices were deleted
|
// invoice. The rest was added to ensure that the invoices were deleted
|
||||||
@ -1334,15 +1336,17 @@ func (i *InvoiceStore) DeleteInvoice(ctx context.Context,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var writeTxOpt InvoiceQueriesTxOptions
|
var writeTxOpt SQLInvoiceQueriesTxOptions
|
||||||
err := i.db.ExecTx(ctx, &writeTxOpt, func(db InvoiceQueries) error {
|
err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
|
||||||
for _, ref := range invoicesToDelete {
|
for _, ref := range invoicesToDelete {
|
||||||
params := sqlc.DeleteInvoiceParams{
|
params := sqlc.DeleteInvoiceParams{
|
||||||
AddIndex: SQLInt64(ref.AddIndex),
|
AddIndex: sqldb.SQLInt64(ref.AddIndex),
|
||||||
}
|
}
|
||||||
|
|
||||||
if ref.SettleIndex != 0 {
|
if ref.SettleIndex != 0 {
|
||||||
params.SettleIndex = SQLInt64(ref.SettleIndex)
|
params.SettleIndex = sqldb.SQLInt64(
|
||||||
|
ref.SettleIndex,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ref.PayHash != lntypes.ZeroHash {
|
if ref.PayHash != lntypes.ZeroHash {
|
||||||
@ -1361,7 +1365,7 @@ func (i *InvoiceStore) DeleteInvoice(ctx context.Context,
|
|||||||
}
|
}
|
||||||
if rowsAffected == 0 {
|
if rowsAffected == 0 {
|
||||||
return fmt.Errorf("%w: %v",
|
return fmt.Errorf("%w: %v",
|
||||||
invpkg.ErrInvoiceNotFound, ref.AddIndex)
|
ErrInvoiceNotFound, ref.AddIndex)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1376,9 +1380,9 @@ func (i *InvoiceStore) DeleteInvoice(ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeleteCanceledInvoices removes all canceled invoices from the database.
|
// DeleteCanceledInvoices removes all canceled invoices from the database.
|
||||||
func (i *InvoiceStore) DeleteCanceledInvoices(ctx context.Context) error {
|
func (i *SQLStore) DeleteCanceledInvoices(ctx context.Context) error {
|
||||||
var writeTxOpt InvoiceQueriesTxOptions
|
var writeTxOpt SQLInvoiceQueriesTxOptions
|
||||||
err := i.db.ExecTx(ctx, &writeTxOpt, func(db InvoiceQueries) error {
|
err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
|
||||||
_, err := db.DeleteCanceledInvoices(ctx)
|
_, err := db.DeleteCanceledInvoices(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to delete canceled "+
|
return fmt.Errorf("unable to delete canceled "+
|
||||||
@ -1398,9 +1402,9 @@ func (i *InvoiceStore) DeleteCanceledInvoices(ctx context.Context) error {
|
|||||||
// invoice is AMP and the setID is not nil, then it will also fetch the AMP
|
// invoice is AMP and the setID is not nil, then it will also fetch the AMP
|
||||||
// state and HTLCs for the given setID, otherwise for all AMP sub invoices of
|
// state and HTLCs for the given setID, otherwise for all AMP sub invoices of
|
||||||
// the invoice. If fetchAmpHtlcs is true, it will also fetch the AMP HTLCs.
|
// the invoice. If fetchAmpHtlcs is true, it will also fetch the AMP HTLCs.
|
||||||
func fetchInvoiceData(ctx context.Context, db InvoiceQueries,
|
func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
|
||||||
row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash,
|
row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash,
|
||||||
*invpkg.Invoice, error) {
|
*Invoice, error) {
|
||||||
|
|
||||||
// Unmarshal the common data.
|
// Unmarshal the common data.
|
||||||
hash, invoice, err := unmarshalInvoice(row)
|
hash, invoice, err := unmarshalInvoice(row)
|
||||||
@ -1444,7 +1448,7 @@ func fetchInvoiceData(ctx context.Context, db InvoiceQueries,
|
|||||||
invoice.Htlcs = htlcs
|
invoice.Htlcs = htlcs
|
||||||
var amountPaid lnwire.MilliSatoshi
|
var amountPaid lnwire.MilliSatoshi
|
||||||
for _, htlc := range htlcs {
|
for _, htlc := range htlcs {
|
||||||
if htlc.State == invpkg.HtlcStateSettled {
|
if htlc.State == HtlcStateSettled {
|
||||||
amountPaid += htlc.Amt
|
amountPaid += htlc.Amt
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1455,7 +1459,7 @@ func fetchInvoiceData(ctx context.Context, db InvoiceQueries,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getInvoiceFeatures fetches the invoice features for the given invoice id.
|
// getInvoiceFeatures fetches the invoice features for the given invoice id.
|
||||||
func getInvoiceFeatures(ctx context.Context, db InvoiceQueries,
|
func getInvoiceFeatures(ctx context.Context, db SQLInvoiceQueries,
|
||||||
invoiceID int64) (*lnwire.FeatureVector, error) {
|
invoiceID int64) (*lnwire.FeatureVector, error) {
|
||||||
|
|
||||||
rows, err := db.GetInvoiceFeatures(ctx, invoiceID)
|
rows, err := db.GetInvoiceFeatures(ctx, invoiceID)
|
||||||
@ -1473,8 +1477,8 @@ func getInvoiceFeatures(ctx context.Context, db InvoiceQueries,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getInvoiceHtlcs fetches the invoice htlcs for the given invoice id.
|
// getInvoiceHtlcs fetches the invoice htlcs for the given invoice id.
|
||||||
func getInvoiceHtlcs(ctx context.Context, db InvoiceQueries,
|
func getInvoiceHtlcs(ctx context.Context, db SQLInvoiceQueries,
|
||||||
invoiceID int64) (map[invpkg.CircuitKey]*invpkg.InvoiceHTLC, error) {
|
invoiceID int64) (map[CircuitKey]*InvoiceHTLC, error) {
|
||||||
|
|
||||||
htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID)
|
htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1505,7 +1509,7 @@ func getInvoiceHtlcs(ctx context.Context, db InvoiceQueries,
|
|||||||
cr[row.HtlcID][uint64(row.Key)] = value
|
cr[row.HtlcID][uint64(row.Key)] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
htlcs := make(map[invpkg.CircuitKey]*invpkg.InvoiceHTLC, len(htlcRows))
|
htlcs := make(map[CircuitKey]*InvoiceHTLC, len(htlcRows))
|
||||||
|
|
||||||
for _, row := range htlcRows {
|
for _, row := range htlcRows {
|
||||||
circuiteKey, htlc, err := unmarshalInvoiceHTLC(row)
|
circuiteKey, htlc, err := unmarshalInvoiceHTLC(row)
|
||||||
@ -1527,7 +1531,7 @@ func getInvoiceHtlcs(ctx context.Context, db InvoiceQueries,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// unmarshalInvoice converts an InvoiceRow to an Invoice.
|
// unmarshalInvoice converts an InvoiceRow to an Invoice.
|
||||||
func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *invpkg.Invoice,
|
func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice,
|
||||||
error) {
|
error) {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -1576,13 +1580,13 @@ func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *invpkg.Invoice,
|
|||||||
cltvDelta = row.CltvDelta.Int32
|
cltvDelta = row.CltvDelta.Int32
|
||||||
}
|
}
|
||||||
|
|
||||||
invoice := &invpkg.Invoice{
|
invoice := &Invoice{
|
||||||
SettleIndex: uint64(settleIndex),
|
SettleIndex: uint64(settleIndex),
|
||||||
SettleDate: settledAt,
|
SettleDate: settledAt,
|
||||||
Memo: memo,
|
Memo: memo,
|
||||||
PaymentRequest: paymentRequest,
|
PaymentRequest: paymentRequest,
|
||||||
CreationDate: row.CreatedAt.Local(),
|
CreationDate: row.CreatedAt.Local(),
|
||||||
Terms: invpkg.ContractTerm{
|
Terms: ContractTerm{
|
||||||
FinalCltvDelta: cltvDelta,
|
FinalCltvDelta: cltvDelta,
|
||||||
Expiry: time.Duration(row.Expiry),
|
Expiry: time.Duration(row.Expiry),
|
||||||
PaymentPreimage: preimage,
|
PaymentPreimage: preimage,
|
||||||
@ -1590,10 +1594,10 @@ func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *invpkg.Invoice,
|
|||||||
PaymentAddr: paymentAddr,
|
PaymentAddr: paymentAddr,
|
||||||
},
|
},
|
||||||
AddIndex: uint64(row.ID),
|
AddIndex: uint64(row.ID),
|
||||||
State: invpkg.ContractState(row.State),
|
State: ContractState(row.State),
|
||||||
AmtPaid: lnwire.MilliSatoshi(row.AmountPaidMsat),
|
AmtPaid: lnwire.MilliSatoshi(row.AmountPaidMsat),
|
||||||
Htlcs: make(map[models.CircuitKey]*invpkg.InvoiceHTLC),
|
Htlcs: make(map[models.CircuitKey]*InvoiceHTLC),
|
||||||
AMPState: invpkg.AMPInvoiceState{},
|
AMPState: AMPInvoiceState{},
|
||||||
HodlInvoice: row.IsHodl,
|
HodlInvoice: row.IsHodl,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1601,34 +1605,34 @@ func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *invpkg.Invoice,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC.
|
// unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC.
|
||||||
func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (invpkg.CircuitKey,
|
func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (CircuitKey,
|
||||||
*invpkg.InvoiceHTLC, error) {
|
*InvoiceHTLC, error) {
|
||||||
|
|
||||||
uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
|
uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return invpkg.CircuitKey{}, nil, err
|
return CircuitKey{}, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
|
chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
|
||||||
|
|
||||||
if row.HtlcID < 0 {
|
if row.HtlcID < 0 {
|
||||||
return invpkg.CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+
|
return CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+
|
||||||
"value: %v", row.HtlcID)
|
"value: %v", row.HtlcID)
|
||||||
}
|
}
|
||||||
|
|
||||||
htlcID := uint64(row.HtlcID)
|
htlcID := uint64(row.HtlcID)
|
||||||
|
|
||||||
circuitKey := invpkg.CircuitKey{
|
circuitKey := CircuitKey{
|
||||||
ChanID: chanID,
|
ChanID: chanID,
|
||||||
HtlcID: htlcID,
|
HtlcID: htlcID,
|
||||||
}
|
}
|
||||||
|
|
||||||
htlc := &invpkg.InvoiceHTLC{
|
htlc := &InvoiceHTLC{
|
||||||
Amt: lnwire.MilliSatoshi(row.AmountMsat),
|
Amt: lnwire.MilliSatoshi(row.AmountMsat),
|
||||||
AcceptHeight: uint32(row.AcceptHeight),
|
AcceptHeight: uint32(row.AcceptHeight),
|
||||||
AcceptTime: row.AcceptTime.Local(),
|
AcceptTime: row.AcceptTime.Local(),
|
||||||
Expiry: uint32(row.ExpiryHeight),
|
Expiry: uint32(row.ExpiryHeight),
|
||||||
State: invpkg.HtlcState(row.State),
|
State: HtlcState(row.State),
|
||||||
}
|
}
|
||||||
|
|
||||||
if row.TotalMppMsat.Valid {
|
if row.TotalMppMsat.Valid {
|
Reference in New Issue
Block a user