diff --git a/invoices/invoices.go b/invoices/invoices.go index c48629c58..32164cbe1 100644 --- a/invoices/invoices.go +++ b/invoices/invoices.go @@ -187,6 +187,11 @@ func (r InvoiceRef) Modifier() RefModifier { return r.refModifier } +// IsHashOnly returns true if the invoice ref only contains a payment hash. +func (r InvoiceRef) IsHashOnly() bool { + return r.payHash != nil && r.payAddr == nil && r.setID == nil +} + // String returns a human-readable representation of an InvoiceRef. func (r InvoiceRef) String() string { var ids []string diff --git a/invoices/sql_store.go b/invoices/sql_store.go index 8a819e5ba..f7ca02637 100644 --- a/invoices/sql_store.go +++ b/invoices/sql_store.go @@ -51,6 +51,9 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat GetInvoice(ctx context.Context, arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error) + GetInvoiceByHash(ctx context.Context, hash []byte) (sqlc.Invoice, + error) + GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice, error) @@ -354,22 +357,31 @@ func (i *SQLStore) AddInvoice(ctx context.Context, return newInvoice.AddIndex, nil } -// fetchInvoice fetches the common invoice data and the AMP state for the -// invoice with the given reference. -func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, - ref InvoiceRef) (*Invoice, error) { +// getInvoiceByRef fetches the invoice with the given reference. The reference +// may be a payment hash, a payment address, or a set ID for an AMP sub invoice. +func getInvoiceByRef(ctx context.Context, + db SQLInvoiceQueries, ref InvoiceRef) (sqlc.Invoice, error) { + // If the reference is empty, we can't look up the invoice. if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil { - return nil, ErrInvoiceNotFound + return sqlc.Invoice{}, ErrInvoiceNotFound } - var ( - invoice *Invoice - params sqlc.GetInvoiceParams - ) + // If the reference is a hash only, we can look up the invoice directly + // by the payment hash which is faster. + if ref.IsHashOnly() { + invoice, err := db.GetInvoiceByHash(ctx, ref.PayHash()[:]) + if errors.Is(err, sql.ErrNoRows) { + return sqlc.Invoice{}, ErrInvoiceNotFound + } + + return invoice, err + } + + // Otherwise the reference may include more fields, so we'll need to + // assemble the query parameters based on the fields that are set. + var params sqlc.GetInvoiceParams - // Given all invoices are uniquely identified by their payment hash, - // we can use it to query a specific invoice. if ref.PayHash() != nil { params.Hash = ref.PayHash()[:] } @@ -405,18 +417,34 @@ func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, } else { rows, err = db.GetInvoice(ctx, params) } + switch { case len(rows) == 0: - return nil, ErrInvoiceNotFound + return sqlc.Invoice{}, ErrInvoiceNotFound case len(rows) > 1: // In case the reference is ambiguous, meaning it matches more // than one invoice, we'll return an error. - return nil, fmt.Errorf("ambiguous invoice ref: %s: %s", - ref.String(), spew.Sdump(rows)) + return sqlc.Invoice{}, fmt.Errorf("ambiguous invoice ref: "+ + "%s: %s", ref.String(), spew.Sdump(rows)) case err != nil: - return nil, fmt.Errorf("unable to fetch invoice: %w", err) + return sqlc.Invoice{}, fmt.Errorf("unable to fetch invoice: %w", + err) + } + + return rows[0], nil +} + +// fetchInvoice fetches the common invoice data and the AMP state for the +// invoice with the given reference. +func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, ref InvoiceRef) ( + *Invoice, error) { + + // Fetch the invoice from the database. + sqlInvoice, err := getInvoiceByRef(ctx, db, ref) + if err != nil { + return nil, err } var ( @@ -433,8 +461,8 @@ func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, fetchAmpHtlcs = true case HtlcSetOnlyModifier: - // In this case we'll fetch all AMP HTLCs for the - // specified set id. + // In this case we'll fetch all AMP HTLCs for the specified set + // id. if ref.SetID() == nil { return nil, fmt.Errorf("set ID is required to use " + "the HTLC set only modifier") @@ -454,8 +482,8 @@ func fetchInvoice(ctx context.Context, db SQLInvoiceQueries, } // Fetch the rest of the invoice data and fill the invoice struct. - _, invoice, err = fetchInvoiceData( - ctx, db, rows[0], setID, fetchAmpHtlcs, + _, invoice, err := fetchInvoiceData( + ctx, db, sqlInvoice, setID, fetchAmpHtlcs, ) if err != nil { return nil, err @@ -658,7 +686,7 @@ func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64, invoiceKeys[key] = struct{}{} - if htlc.State != HtlcStateCanceled { //nolint: ll + if htlc.State != HtlcStateCanceled { amtPaid += htlc.Amt } } diff --git a/sqldb/sqlc/invoices.sql.go b/sqldb/sqlc/invoices.sql.go index b98a78c69..13ac8094a 100644 --- a/sqldb/sqlc/invoices.sql.go +++ b/sqldb/sqlc/invoices.sql.go @@ -255,6 +255,38 @@ func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoi return items, nil } +const getInvoiceByHash = `-- name: GetInvoiceByHash :one +SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at +FROM invoices i +WHERE i.hash = $1 +` + +func (q *Queries) GetInvoiceByHash(ctx context.Context, hash []byte) (Invoice, error) { + row := q.db.QueryRowContext(ctx, getInvoiceByHash, hash) + var i Invoice + err := row.Scan( + &i.ID, + &i.Hash, + &i.Preimage, + &i.SettleIndex, + &i.SettledAt, + &i.Memo, + &i.AmountMsat, + &i.CltvDelta, + &i.Expiry, + &i.PaymentAddr, + &i.PaymentRequest, + &i.PaymentRequestHash, + &i.State, + &i.AmountPaidMsat, + &i.IsAmp, + &i.IsHodl, + &i.IsKeysend, + &i.CreatedAt, + ) + return i, err +} + const getInvoiceBySetID = `-- name: GetInvoiceBySetID :many SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at FROM invoices i diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index 6f05b32c2..c63f7fadb 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -24,6 +24,7 @@ type Querier interface { // from different invoices. It is the caller's responsibility to ensure that // we bubble up an error in those cases. GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error) + GetInvoiceByHash(ctx context.Context, hash []byte) (Invoice, error) GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error) GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]InvoiceFeature, error) GetInvoiceHTLCCustomRecords(ctx context.Context, invoiceID int64) ([]GetInvoiceHTLCCustomRecordsRow, error) diff --git a/sqldb/sqlc/queries/invoices.sql b/sqldb/sqlc/queries/invoices.sql index f57c9ab76..db1f46e61 100644 --- a/sqldb/sqlc/queries/invoices.sql +++ b/sqldb/sqlc/queries/invoices.sql @@ -54,6 +54,11 @@ WHERE ( GROUP BY i.id LIMIT 2; +-- name: GetInvoiceByHash :one +SELECT i.* +FROM invoices i +WHERE i.hash = $1; + -- name: GetInvoiceBySetID :many SELECT i.* FROM invoices i