mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-05-03 08:20:30 +02:00
sqldb+invoices: synchronize SQL invoice updater behavior with KV version
Previously SQL invoice updater ignored the set ID hint when updating an AMP invoice resulting in update subscriptions returning all of the AMP state as well as all AMP HTLCs. This commit synchornizes behavior with the KV implementation such that we now only return relevant AMP state and HTLCs when updating an AMP invoice.
This commit is contained in:
parent
c8de7a1699
commit
b57910ee3a
@ -10,6 +10,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/davecgh/go-spew/spew"
|
||||||
"github.com/lightningnetwork/lnd/channeldb/models"
|
"github.com/lightningnetwork/lnd/channeldb/models"
|
||||||
"github.com/lightningnetwork/lnd/clock"
|
"github.com/lightningnetwork/lnd/clock"
|
||||||
"github.com/lightningnetwork/lnd/lntypes"
|
"github.com/lightningnetwork/lnd/lntypes"
|
||||||
@ -46,6 +47,9 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
|
|||||||
GetInvoice(ctx context.Context,
|
GetInvoice(ctx context.Context,
|
||||||
arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
|
arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
|
||||||
|
|
||||||
|
GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
|
||||||
|
error)
|
||||||
|
|
||||||
GetInvoiceFeatures(ctx context.Context,
|
GetInvoiceFeatures(ctx context.Context,
|
||||||
invoiceID int64) ([]sqlc.InvoiceFeature, error)
|
invoiceID int64) ([]sqlc.InvoiceFeature, error)
|
||||||
|
|
||||||
@ -343,7 +347,22 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
|
|||||||
params.SetID = ref.SetID()[:]
|
params.SetID = ref.SetID()[:]
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := db.GetInvoice(ctx, params)
|
var (
|
||||||
|
rows []sqlc.Invoice
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
// We need to split the query based on how we intend to look up the
|
||||||
|
// invoice. If only the set ID is given then we want to have an exact
|
||||||
|
// match on the set ID. If other fields are given, we want to match on
|
||||||
|
// those fields and the set ID but with a less strict join condition.
|
||||||
|
if params.Hash == nil && params.PaymentAddr == nil &&
|
||||||
|
params.SetID != nil {
|
||||||
|
|
||||||
|
rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
|
||||||
|
} else {
|
||||||
|
rows, err = db.GetInvoice(ctx, params)
|
||||||
|
}
|
||||||
switch {
|
switch {
|
||||||
case len(rows) == 0:
|
case len(rows) == 0:
|
||||||
return nil, ErrInvoiceNotFound
|
return nil, ErrInvoiceNotFound
|
||||||
@ -351,8 +370,8 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
|
|||||||
case len(rows) > 1:
|
case len(rows) > 1:
|
||||||
// In case the reference is ambiguous, meaning it matches more
|
// In case the reference is ambiguous, meaning it matches more
|
||||||
// than one invoice, we'll return an error.
|
// than one invoice, we'll return an error.
|
||||||
return nil, fmt.Errorf("ambiguous invoice ref: %s",
|
return nil, fmt.Errorf("ambiguous invoice ref: %s: %s",
|
||||||
ref.String())
|
ref.String(), spew.Sdump(rows))
|
||||||
|
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return nil, fmt.Errorf("unable to fetch invoice: %w", err)
|
return nil, fmt.Errorf("unable to fetch invoice: %w", err)
|
||||||
@ -1308,13 +1327,24 @@ func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
|
|||||||
// invoice and is therefore atomic. The fields to update are controlled by the
|
// invoice and is therefore atomic. The fields to update are controlled by the
|
||||||
// supplied callback.
|
// supplied callback.
|
||||||
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
|
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
|
||||||
_ *SetID, callback InvoiceUpdateCallback) (
|
setID *SetID, callback InvoiceUpdateCallback) (
|
||||||
*Invoice, error) {
|
*Invoice, error) {
|
||||||
|
|
||||||
var updatedInvoice *Invoice
|
var updatedInvoice *Invoice
|
||||||
|
|
||||||
txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
|
txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
|
||||||
txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
|
txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
|
||||||
|
if setID != nil {
|
||||||
|
// Make sure to use the set ID if this is an AMP update.
|
||||||
|
var setIDBytes [32]byte
|
||||||
|
copy(setIDBytes[:], setID[:])
|
||||||
|
ref.setID = &setIDBytes
|
||||||
|
|
||||||
|
// If we're updating an AMP invoice, we'll also only
|
||||||
|
// need to fetch the HTLCs for the given set ID.
|
||||||
|
ref.refModifier = HtlcSetOnlyModifier
|
||||||
|
}
|
||||||
|
|
||||||
invoice, err := i.fetchInvoice(ctx, db, ref)
|
invoice, err := i.fetchInvoice(ctx, db, ref)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -170,21 +170,22 @@ const getInvoice = `-- name: GetInvoice :many
|
|||||||
|
|
||||||
SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at
|
SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at
|
||||||
FROM invoices i
|
FROM invoices i
|
||||||
LEFT JOIN amp_sub_invoices a on i.id = a.invoice_id
|
LEFT JOIN amp_sub_invoices a
|
||||||
|
ON i.id = a.invoice_id
|
||||||
|
AND (
|
||||||
|
a.set_id = $1 OR $1 IS NULL
|
||||||
|
)
|
||||||
WHERE (
|
WHERE (
|
||||||
i.id = $1 OR
|
i.id = $2 OR
|
||||||
$1 IS NULL
|
|
||||||
) AND (
|
|
||||||
i.hash = $2 OR
|
|
||||||
$2 IS NULL
|
$2 IS NULL
|
||||||
) AND (
|
) AND (
|
||||||
i.preimage = $3 OR
|
i.hash = $3 OR
|
||||||
$3 IS NULL
|
$3 IS NULL
|
||||||
) AND (
|
) AND (
|
||||||
i.payment_addr = $4 OR
|
i.preimage = $4 OR
|
||||||
$4 IS NULL
|
$4 IS NULL
|
||||||
) AND (
|
) AND (
|
||||||
a.set_id = $5 OR
|
i.payment_addr = $5 OR
|
||||||
$5 IS NULL
|
$5 IS NULL
|
||||||
)
|
)
|
||||||
GROUP BY i.id
|
GROUP BY i.id
|
||||||
@ -192,11 +193,11 @@ LIMIT 2
|
|||||||
`
|
`
|
||||||
|
|
||||||
type GetInvoiceParams struct {
|
type GetInvoiceParams struct {
|
||||||
|
SetID []byte
|
||||||
AddIndex sql.NullInt64
|
AddIndex sql.NullInt64
|
||||||
Hash []byte
|
Hash []byte
|
||||||
Preimage []byte
|
Preimage []byte
|
||||||
PaymentAddr []byte
|
PaymentAddr []byte
|
||||||
SetID []byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// This method may return more than one invoice if filter using multiple fields
|
// This method may return more than one invoice if filter using multiple fields
|
||||||
@ -204,11 +205,11 @@ type GetInvoiceParams struct {
|
|||||||
// we bubble up an error in those cases.
|
// we bubble up an error in those cases.
|
||||||
func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error) {
|
func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error) {
|
||||||
rows, err := q.db.QueryContext(ctx, getInvoice,
|
rows, err := q.db.QueryContext(ctx, getInvoice,
|
||||||
|
arg.SetID,
|
||||||
arg.AddIndex,
|
arg.AddIndex,
|
||||||
arg.Hash,
|
arg.Hash,
|
||||||
arg.Preimage,
|
arg.Preimage,
|
||||||
arg.PaymentAddr,
|
arg.PaymentAddr,
|
||||||
arg.SetID,
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -250,6 +251,55 @@ func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoi
|
|||||||
return items, nil
|
return items, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const getInvoiceBySetID = `-- name: GetInvoiceBySetID :many
|
||||||
|
SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at
|
||||||
|
FROM invoices i
|
||||||
|
INNER JOIN amp_sub_invoices a
|
||||||
|
ON i.id = a.invoice_id AND a.set_id = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error) {
|
||||||
|
rows, err := q.db.QueryContext(ctx, getInvoiceBySetID, setID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []Invoice
|
||||||
|
for rows.Next() {
|
||||||
|
var i Invoice
|
||||||
|
if err := rows.Scan(
|
||||||
|
&i.ID,
|
||||||
|
&i.Hash,
|
||||||
|
&i.Preimage,
|
||||||
|
&i.SettleIndex,
|
||||||
|
&i.SettledAt,
|
||||||
|
&i.Memo,
|
||||||
|
&i.AmountMsat,
|
||||||
|
&i.CltvDelta,
|
||||||
|
&i.Expiry,
|
||||||
|
&i.PaymentAddr,
|
||||||
|
&i.PaymentRequest,
|
||||||
|
&i.PaymentRequestHash,
|
||||||
|
&i.State,
|
||||||
|
&i.AmountPaidMsat,
|
||||||
|
&i.IsAmp,
|
||||||
|
&i.IsHodl,
|
||||||
|
&i.IsKeysend,
|
||||||
|
&i.CreatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, i)
|
||||||
|
}
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
const getInvoiceFeatures = `-- name: GetInvoiceFeatures :many
|
const getInvoiceFeatures = `-- name: GetInvoiceFeatures :many
|
||||||
SELECT feature, invoice_id
|
SELECT feature, invoice_id
|
||||||
FROM invoice_features
|
FROM invoice_features
|
||||||
|
@ -21,6 +21,7 @@ type Querier interface {
|
|||||||
// from different invoices. It is the caller's responsibility to ensure that
|
// from different invoices. It is the caller's responsibility to ensure that
|
||||||
// we bubble up an error in those cases.
|
// we bubble up an error in those cases.
|
||||||
GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error)
|
GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error)
|
||||||
|
GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error)
|
||||||
GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]InvoiceFeature, error)
|
GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]InvoiceFeature, error)
|
||||||
GetInvoiceHTLCCustomRecords(ctx context.Context, invoiceID int64) ([]GetInvoiceHTLCCustomRecordsRow, error)
|
GetInvoiceHTLCCustomRecords(ctx context.Context, invoiceID int64) ([]GetInvoiceHTLCCustomRecordsRow, error)
|
||||||
GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]InvoiceHtlc, error)
|
GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]InvoiceHtlc, error)
|
||||||
|
@ -26,7 +26,11 @@ WHERE invoice_id = $1;
|
|||||||
-- name: GetInvoice :many
|
-- name: GetInvoice :many
|
||||||
SELECT i.*
|
SELECT i.*
|
||||||
FROM invoices i
|
FROM invoices i
|
||||||
LEFT JOIN amp_sub_invoices a on i.id = a.invoice_id
|
LEFT JOIN amp_sub_invoices a
|
||||||
|
ON i.id = a.invoice_id
|
||||||
|
AND (
|
||||||
|
a.set_id = sqlc.narg('set_id') OR sqlc.narg('set_id') IS NULL
|
||||||
|
)
|
||||||
WHERE (
|
WHERE (
|
||||||
i.id = sqlc.narg('add_index') OR
|
i.id = sqlc.narg('add_index') OR
|
||||||
sqlc.narg('add_index') IS NULL
|
sqlc.narg('add_index') IS NULL
|
||||||
@ -39,13 +43,16 @@ WHERE (
|
|||||||
) AND (
|
) AND (
|
||||||
i.payment_addr = sqlc.narg('payment_addr') OR
|
i.payment_addr = sqlc.narg('payment_addr') OR
|
||||||
sqlc.narg('payment_addr') IS NULL
|
sqlc.narg('payment_addr') IS NULL
|
||||||
) AND (
|
|
||||||
a.set_id = sqlc.narg('set_id') OR
|
|
||||||
sqlc.narg('set_id') IS NULL
|
|
||||||
)
|
)
|
||||||
GROUP BY i.id
|
GROUP BY i.id
|
||||||
LIMIT 2;
|
LIMIT 2;
|
||||||
|
|
||||||
|
-- name: GetInvoiceBySetID :many
|
||||||
|
SELECT i.*
|
||||||
|
FROM invoices i
|
||||||
|
INNER JOIN amp_sub_invoices a
|
||||||
|
ON i.id = a.invoice_id AND a.set_id = $1;
|
||||||
|
|
||||||
-- name: FilterInvoices :many
|
-- name: FilterInvoices :many
|
||||||
SELECT
|
SELECT
|
||||||
invoices.*
|
invoices.*
|
||||||
|
Loading…
x
Reference in New Issue
Block a user