Merge pull request #9050 from bhandras/native-sql-invoice-fixes

invoices+sqldb: small fixes to address some inconsistencies between KV and native SQL invoice DB implementations
This commit is contained in:
Oliver Gugger
2024-09-04 01:30:52 -06:00
committed by GitHub
12 changed files with 197 additions and 41 deletions

View File

@@ -93,11 +93,13 @@ linters-settings:
- 'errors.Wrap' - 'errors.Wrap'
gomoddirectives: gomoddirectives:
replace-local: true
replace-allow-list: replace-allow-list:
# See go.mod for the explanation why these are needed. # See go.mod for the explanation why these are needed.
- github.com/ulikunitz/xz - github.com/ulikunitz/xz
- github.com/gogo/protobuf - github.com/gogo/protobuf
- google.golang.org/protobuf - google.golang.org/protobuf
- github.com/lightningnetwork/lnd/sqldb
linters: linters:

View File

@@ -269,7 +269,9 @@ func (d *DB) InvoicesAddedSince(_ context.Context, sinceAddIndex uint64) (
// For each key found, we'll look up the actual // For each key found, we'll look up the actual
// invoice, then accumulate it into our return value. // invoice, then accumulate it into our return value.
invoice, err := fetchInvoice(invoiceKey, invoices) invoice, err := fetchInvoice(
invoiceKey, invoices, nil, false,
)
if err != nil { if err != nil {
return err return err
} }
@@ -341,7 +343,9 @@ func (d *DB) LookupInvoice(_ context.Context, ref invpkg.InvoiceRef) (
// An invoice was found, retrieve the remainder of the invoice // An invoice was found, retrieve the remainder of the invoice
// body. // body.
i, err := fetchInvoice(invoiceNum, invoices, setID) i, err := fetchInvoice(
invoiceNum, invoices, []*invpkg.SetID{setID}, true,
)
if err != nil { if err != nil {
return err return err
} }
@@ -468,7 +472,7 @@ func (d *DB) FetchPendingInvoices(_ context.Context) (
return nil return nil
} }
invoice, err := fetchInvoice(v, invoices) invoice, err := fetchInvoice(v, invoices, nil, false)
if err != nil { if err != nil {
return err return err
} }
@@ -526,7 +530,9 @@ func (d *DB) QueryInvoices(_ context.Context, q invpkg.InvoiceQuery) (
// characteristics for our query and returns the number of items // characteristics for our query and returns the number of items
// we have added to our set of invoices. // we have added to our set of invoices.
accumulateInvoices := func(_, indexValue []byte) (bool, error) { accumulateInvoices := func(_, indexValue []byte) (bool, error) {
invoice, err := fetchInvoice(indexValue, invoices) invoice, err := fetchInvoice(
indexValue, invoices, nil, false,
)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -654,7 +660,9 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
if setIDHint != nil { if setIDHint != nil {
invSetID = *setIDHint invSetID = *setIDHint
} }
invoice, err := fetchInvoice(invoiceNum, invoices, &invSetID) invoice, err := fetchInvoice(
invoiceNum, invoices, []*invpkg.SetID{&invSetID}, false,
)
if err != nil { if err != nil {
return err return err
} }
@@ -676,8 +684,17 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
updatedInvoice, err = invpkg.UpdateInvoice( updatedInvoice, err = invpkg.UpdateInvoice(
payHash, updater.invoice, now, callback, updater, payHash, updater.invoice, now, callback, updater,
) )
if err != nil {
return err
}
return err // If this is an AMP update, then limit the returned AMP state
// to only the requested set ID.
if setIDHint != nil {
filterInvoiceAMPState(updatedInvoice, &invSetID)
}
return nil
}, func() { }, func() {
updatedInvoice = nil updatedInvoice = nil
}) })
@@ -685,6 +702,25 @@ func (d *DB) UpdateInvoice(_ context.Context, ref invpkg.InvoiceRef,
return updatedInvoice, err return updatedInvoice, err
} }
// filterInvoiceAMPState filters the AMP state of the invoice to only include
// state for the specified set IDs.
func filterInvoiceAMPState(invoice *invpkg.Invoice, setIDs ...*invpkg.SetID) {
filteredAMPState := make(invpkg.AMPInvoiceState)
for _, setID := range setIDs {
if setID == nil {
return
}
ampState, ok := invoice.AMPState[*setID]
if ok {
filteredAMPState[*setID] = ampState
}
}
invoice.AMPState = filteredAMPState
}
// ampHTLCsMap is a map of AMP HTLCs affected by an invoice update. // ampHTLCsMap is a map of AMP HTLCs affected by an invoice update.
type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC type ampHTLCsMap map[invpkg.SetID]map[models.CircuitKey]*invpkg.InvoiceHTLC
@@ -1056,7 +1092,8 @@ func (d *DB) InvoicesSettledSince(_ context.Context, sinceSettleIndex uint64) (
// For each key found, we'll look up the actual // For each key found, we'll look up the actual
// invoice, then accumulate it into our return value. // invoice, then accumulate it into our return value.
invoice, err := fetchInvoice( invoice, err := fetchInvoice(
invoiceKey[:], invoices, setID, invoiceKey[:], invoices, []*invpkg.SetID{setID},
true,
) )
if err != nil { if err != nil {
return err return err
@@ -1485,7 +1522,7 @@ func fetchAmpSubInvoices(invoiceBucket kvdb.RBucket, invoiceNum []byte,
// specified by the invoice number. If the setID fields are set, then only the // specified by the invoice number. If the setID fields are set, then only the
// HTLC information pertaining to those set IDs is returned. // HTLC information pertaining to those set IDs is returned.
func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket, func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
setIDs ...*invpkg.SetID) (invpkg.Invoice, error) { setIDs []*invpkg.SetID, filterAMPState bool) (invpkg.Invoice, error) {
invoiceBytes := invoices.Get(invoiceNum) invoiceBytes := invoices.Get(invoiceNum)
if invoiceBytes == nil { if invoiceBytes == nil {
@@ -1518,6 +1555,10 @@ func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket,
log.Errorf("unable to fetch amp htlcs for inv "+ log.Errorf("unable to fetch amp htlcs for inv "+
"%v and setIDs %v: %w", invoiceNum, setIDs, err) "%v and setIDs %v: %w", invoiceNum, setIDs, err)
} }
if filterAMPState {
filterInvoiceAMPState(&invoice, setIDs...)
}
} }
return invoice, nil return invoice, nil
@@ -2163,7 +2204,7 @@ func (d *DB) DeleteCanceledInvoices(_ context.Context) error {
return nil return nil
} }
invoice, err := fetchInvoice(v, invoices) invoice, err := fetchInvoice(v, invoices, nil, false)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -266,6 +266,11 @@ that validate `ChannelAnnouncement` messages.
our health checker to correctly shut down LND if network partitioning occurs our health checker to correctly shut down LND if network partitioning occurs
towards the etcd cluster. towards the etcd cluster.
* [Fix](https://github.com/lightningnetwork/lnd/pull/9050) some inconsistencies
to make the native SQL invoice DB compatible with the KV implementation.
Furthermore fix a native SQL invoice issue where AMP subinvoice HTLCs are
sometimes updated incorrectly on settlement.
## Code Health ## Code Health
* [Move graph building and * [Move graph building and
@@ -282,6 +287,7 @@ that validate `ChannelAnnouncement` messages.
# Contributors (Alphabetical Order) # Contributors (Alphabetical Order)
* Alex Akselrod
* Andras Banki-Horvath * Andras Banki-Horvath
* bitromortac * bitromortac
* Bufo * Bufo

3
go.mod
View File

@@ -204,6 +204,9 @@ replace github.com/gogo/protobuf => github.com/gogo/protobuf v1.3.2
// allows us to specify that as an option. // allows us to specify that as an option.
replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-display v1.30.0-hex-display replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-display v1.30.0-hex-display
// Temporary replace until the next version of sqldb is taged.
replace github.com/lightningnetwork/lnd/sqldb => ./sqldb
// If you change this please also update .github/pull_request_template.md, // If you change this please also update .github/pull_request_template.md,
// docs/INSTALL.md and GO_IMAGE in lnrpc/gen_protos_docker.sh. // docs/INSTALL.md and GO_IMAGE in lnrpc/gen_protos_docker.sh.
go 1.22.6 go 1.22.6

2
go.sum
View File

@@ -458,8 +458,6 @@ github.com/lightningnetwork/lnd/kvdb v1.4.10 h1:vK89IVv1oVH9ubQWU+EmoCQFeVRaC8kf
github.com/lightningnetwork/lnd/kvdb v1.4.10/go.mod h1:J2diNABOoII9UrMnxXS5w7vZwP7CA1CStrl8MnIrb3A= github.com/lightningnetwork/lnd/kvdb v1.4.10/go.mod h1:J2diNABOoII9UrMnxXS5w7vZwP7CA1CStrl8MnIrb3A=
github.com/lightningnetwork/lnd/queue v1.1.1 h1:99ovBlpM9B0FRCGYJo6RSFDlt8/vOkQQZznVb18iNMI= github.com/lightningnetwork/lnd/queue v1.1.1 h1:99ovBlpM9B0FRCGYJo6RSFDlt8/vOkQQZznVb18iNMI=
github.com/lightningnetwork/lnd/queue v1.1.1/go.mod h1:7A6nC1Qrm32FHuhx/mi1cieAiBZo5O6l8IBIoQxvkz4= github.com/lightningnetwork/lnd/queue v1.1.1/go.mod h1:7A6nC1Qrm32FHuhx/mi1cieAiBZo5O6l8IBIoQxvkz4=
github.com/lightningnetwork/lnd/sqldb v1.0.3 h1:zLfAwOvM+6+3+hahYO9Q3h8pVV0TghAR7iJ5YMLCd3I=
github.com/lightningnetwork/lnd/sqldb v1.0.3/go.mod h1:4cQOkdymlZ1znnjuRNvMoatQGJkRneTj2CoPSPaQhWo=
github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM= github.com/lightningnetwork/lnd/ticker v1.1.1 h1:J/b6N2hibFtC7JLV77ULQp++QLtCwT6ijJlbdiZFbSM=
github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA= github.com/lightningnetwork/lnd/ticker v1.1.1/go.mod h1:waPTRAAcwtu7Ji3+3k+u/xH5GHovTsCoSVpho0KDvdA=
github.com/lightningnetwork/lnd/tlv v1.2.6 h1:icvQG2yDr6k3ZuZzfRdG3EJp6pHurcuh3R6dg0gv/Mw= github.com/lightningnetwork/lnd/tlv v1.2.6 h1:icvQG2yDr6k3ZuZzfRdG3EJp6pHurcuh3R6dg0gv/Mw=

View File

@@ -10,6 +10,7 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
@@ -46,6 +47,9 @@ type SQLInvoiceQueries interface { //nolint:interfacebloat
GetInvoice(ctx context.Context, GetInvoice(ctx context.Context,
arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error) arg sqlc.GetInvoiceParams) ([]sqlc.Invoice, error)
GetInvoiceBySetID(ctx context.Context, setID []byte) ([]sqlc.Invoice,
error)
GetInvoiceFeatures(ctx context.Context, GetInvoiceFeatures(ctx context.Context,
invoiceID int64) ([]sqlc.InvoiceFeature, error) invoiceID int64) ([]sqlc.InvoiceFeature, error)
@@ -343,7 +347,22 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
params.SetID = ref.SetID()[:] params.SetID = ref.SetID()[:]
} }
rows, err := db.GetInvoice(ctx, params) var (
rows []sqlc.Invoice
err error
)
// We need to split the query based on how we intend to look up the
// invoice. If only the set ID is given then we want to have an exact
// match on the set ID. If other fields are given, we want to match on
// those fields and the set ID but with a less strict join condition.
if params.Hash == nil && params.PaymentAddr == nil &&
params.SetID != nil {
rows, err = db.GetInvoiceBySetID(ctx, params.SetID)
} else {
rows, err = db.GetInvoice(ctx, params)
}
switch { switch {
case len(rows) == 0: case len(rows) == 0:
return nil, ErrInvoiceNotFound return nil, ErrInvoiceNotFound
@@ -351,8 +370,8 @@ func (i *SQLStore) fetchInvoice(ctx context.Context,
case len(rows) > 1: case len(rows) > 1:
// In case the reference is ambiguous, meaning it matches more // In case the reference is ambiguous, meaning it matches more
// than one invoice, we'll return an error. // than one invoice, we'll return an error.
return nil, fmt.Errorf("ambiguous invoice ref: %s", return nil, fmt.Errorf("ambiguous invoice ref: %s: %s",
ref.String()) ref.String(), spew.Sdump(rows))
case err != nil: case err != nil:
return nil, fmt.Errorf("unable to fetch invoice: %w", err) return nil, fmt.Errorf("unable to fetch invoice: %w", err)
@@ -906,8 +925,10 @@ func (i *SQLStore) QueryInvoices(ctx context.Context,
} }
if q.CreationDateEnd != 0 { if q.CreationDateEnd != 0 {
// We need to add 1 to the end date as we're
// checking less than the end date in SQL.
params.CreatedBefore = sqldb.SQLTime( params.CreatedBefore = sqldb.SQLTime(
time.Unix(q.CreationDateEnd, 0).UTC(), time.Unix(q.CreationDateEnd+1, 0).UTC(),
) )
} }
@@ -1116,6 +1137,9 @@ func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
SetID: setID[:], SetID: setID[:],
HtlcID: int64(circuitKey.HtlcID), HtlcID: int64(circuitKey.HtlcID),
Preimage: preimage[:], Preimage: preimage[:],
ChanID: strconv.FormatUint(
circuitKey.ChanID.ToUint64(), 10,
),
}, },
) )
if err != nil { if err != nil {
@@ -1280,6 +1304,13 @@ func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
return err return err
} }
if settleIndex.Valid {
updatedState := s.invoice.AMPState[setID]
updatedState.SettleIndex = uint64(settleIndex.Int64)
updatedState.SettleDate = s.updateTime.UTC()
s.invoice.AMPState[setID] = updatedState
}
return nil return nil
} }
@@ -1298,13 +1329,24 @@ func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
// invoice and is therefore atomic. The fields to update are controlled by the // invoice and is therefore atomic. The fields to update are controlled by the
// supplied callback. // supplied callback.
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef, func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
_ *SetID, callback InvoiceUpdateCallback) ( setID *SetID, callback InvoiceUpdateCallback) (
*Invoice, error) { *Invoice, error) {
var updatedInvoice *Invoice var updatedInvoice *Invoice
txOpt := SQLInvoiceQueriesTxOptions{readOnly: false} txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error { txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
if setID != nil {
// Make sure to use the set ID if this is an AMP update.
var setIDBytes [32]byte
copy(setIDBytes[:], setID[:])
ref.setID = &setIDBytes
// If we're updating an AMP invoice, we'll also only
// need to fetch the HTLCs for the given set ID.
ref.refModifier = HtlcSetOnlyModifier
}
invoice, err := i.fetchInvoice(ctx, db, ref) invoice, err := i.fetchInvoice(ctx, db, ref)
if err != nil { if err != nil {
return err return err

View File

@@ -260,7 +260,8 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
invoiceNtfn := ht.ReceiveInvoiceUpdate(invSubscription) invoiceNtfn := ht.ReceiveInvoiceUpdate(invSubscription)
// The notification should signal that the invoice is now settled, and // The notification should signal that the invoice is now settled, and
// should also include the set ID, and show the proper amount paid. // should also include the set ID, show the proper amount paid, and have
// the correct settle index and time.
require.True(ht, invoiceNtfn.Settled) require.True(ht, invoiceNtfn.Settled)
require.Equal(ht, lnrpc.Invoice_SETTLED, invoiceNtfn.State) require.Equal(ht, lnrpc.Invoice_SETTLED, invoiceNtfn.State)
require.Equal(ht, paymentAmt, int(invoiceNtfn.AmtPaidSat)) require.Equal(ht, paymentAmt, int(invoiceNtfn.AmtPaidSat))
@@ -270,6 +271,9 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
firstSetID, _ = hex.DecodeString(setIDStr) firstSetID, _ = hex.DecodeString(setIDStr)
require.Equal(ht, lnrpc.InvoiceHTLCState_SETTLED, require.Equal(ht, lnrpc.InvoiceHTLCState_SETTLED,
ampState.State) ampState.State)
require.GreaterOrEqual(ht, ampState.SettleTime,
rpcInvoice.CreationDate)
require.Equal(ht, uint64(1), ampState.SettleIndex)
} }
// Pay the invoice again, we should get another notification that Dave // Pay the invoice again, we should get another notification that Dave
@@ -299,9 +303,9 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
// return the "projected" sub-invoice for a given setID. // return the "projected" sub-invoice for a given setID.
require.Equal(ht, 1, len(invoiceNtfn.Htlcs)) require.Equal(ht, 1, len(invoiceNtfn.Htlcs))
// However the AMP state index should show that there've been two // The AMP state should also be restricted to a single entry for the
// repeated payments to this invoice so far. // "projected" sub-invoice.
require.Equal(ht, 2, len(invoiceNtfn.AmpInvoiceState)) require.Equal(ht, 1, len(invoiceNtfn.AmpInvoiceState))
// Now we'll look up the invoice using the new LookupInvoice2 RPC call // Now we'll look up the invoice using the new LookupInvoice2 RPC call
// by the set ID of each of the invoices. // by the set ID of each of the invoices.
@@ -360,7 +364,7 @@ func testSendPaymentAMPInvoiceRepeat(ht *lntest.HarnessTest) {
// through. // through.
backlogInv := ht.ReceiveInvoiceUpdate(invSub2) backlogInv := ht.ReceiveInvoiceUpdate(invSub2)
require.Equal(ht, 1, len(backlogInv.Htlcs)) require.Equal(ht, 1, len(backlogInv.Htlcs))
require.Equal(ht, 2, len(backlogInv.AmpInvoiceState)) require.Equal(ht, 1, len(backlogInv.AmpInvoiceState))
require.True(ht, backlogInv.Settled) require.True(ht, backlogInv.Settled)
require.Equal(ht, paymentAmt*2, int(backlogInv.AmtPaidSat)) require.Equal(ht, paymentAmt*2, int(backlogInv.AmtPaidSat))
} }

View File

@@ -268,15 +268,16 @@ func (q *Queries) InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubI
const updateAMPSubInvoiceHTLCPreimage = `-- name: UpdateAMPSubInvoiceHTLCPreimage :execresult const updateAMPSubInvoiceHTLCPreimage = `-- name: UpdateAMPSubInvoiceHTLCPreimage :execresult
UPDATE amp_sub_invoice_htlcs AS a UPDATE amp_sub_invoice_htlcs AS a
SET preimage = $4 SET preimage = $5
WHERE a.invoice_id = $1 AND a.set_id = $2 AND a.htlc_id = ( WHERE a.invoice_id = $1 AND a.set_id = $2 AND a.htlc_id = (
SELECT id FROM invoice_htlcs AS i WHERE i.htlc_id = $3 SELECT id FROM invoice_htlcs AS i WHERE i.chan_id = $3 AND i.htlc_id = $4
) )
` `
type UpdateAMPSubInvoiceHTLCPreimageParams struct { type UpdateAMPSubInvoiceHTLCPreimageParams struct {
InvoiceID int64 InvoiceID int64
SetID []byte SetID []byte
ChanID string
HtlcID int64 HtlcID int64
Preimage []byte Preimage []byte
} }
@@ -285,6 +286,7 @@ func (q *Queries) UpdateAMPSubInvoiceHTLCPreimage(ctx context.Context, arg Updat
return q.db.ExecContext(ctx, updateAMPSubInvoiceHTLCPreimage, return q.db.ExecContext(ctx, updateAMPSubInvoiceHTLCPreimage,
arg.InvoiceID, arg.InvoiceID,
arg.SetID, arg.SetID,
arg.ChanID,
arg.HtlcID, arg.HtlcID,
arg.Preimage, arg.Preimage,
) )

View File

@@ -78,7 +78,7 @@ WHERE (
created_at >= $6 OR created_at >= $6 OR
$6 IS NULL $6 IS NULL
) AND ( ) AND (
created_at <= $7 OR created_at < $7 OR
$7 IS NULL $7 IS NULL
) AND ( ) AND (
CASE CASE
@@ -170,21 +170,22 @@ const getInvoice = `-- name: GetInvoice :many
SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at
FROM invoices i FROM invoices i
LEFT JOIN amp_sub_invoices a on i.id = a.invoice_id LEFT JOIN amp_sub_invoices a
ON i.id = a.invoice_id
AND (
a.set_id = $1 OR $1 IS NULL
)
WHERE ( WHERE (
i.id = $1 OR i.id = $2 OR
$1 IS NULL
) AND (
i.hash = $2 OR
$2 IS NULL $2 IS NULL
) AND ( ) AND (
i.preimage = $3 OR i.hash = $3 OR
$3 IS NULL $3 IS NULL
) AND ( ) AND (
i.payment_addr = $4 OR i.preimage = $4 OR
$4 IS NULL $4 IS NULL
) AND ( ) AND (
a.set_id = $5 OR i.payment_addr = $5 OR
$5 IS NULL $5 IS NULL
) )
GROUP BY i.id GROUP BY i.id
@@ -192,11 +193,11 @@ LIMIT 2
` `
type GetInvoiceParams struct { type GetInvoiceParams struct {
SetID []byte
AddIndex sql.NullInt64 AddIndex sql.NullInt64
Hash []byte Hash []byte
Preimage []byte Preimage []byte
PaymentAddr []byte PaymentAddr []byte
SetID []byte
} }
// This method may return more than one invoice if filter using multiple fields // This method may return more than one invoice if filter using multiple fields
@@ -204,11 +205,11 @@ type GetInvoiceParams struct {
// we bubble up an error in those cases. // we bubble up an error in those cases.
func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error) { func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error) {
rows, err := q.db.QueryContext(ctx, getInvoice, rows, err := q.db.QueryContext(ctx, getInvoice,
arg.SetID,
arg.AddIndex, arg.AddIndex,
arg.Hash, arg.Hash,
arg.Preimage, arg.Preimage,
arg.PaymentAddr, arg.PaymentAddr,
arg.SetID,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@@ -250,6 +251,55 @@ func (q *Queries) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoi
return items, nil return items, nil
} }
const getInvoiceBySetID = `-- name: GetInvoiceBySetID :many
SELECT i.id, i.hash, i.preimage, i.settle_index, i.settled_at, i.memo, i.amount_msat, i.cltv_delta, i.expiry, i.payment_addr, i.payment_request, i.payment_request_hash, i.state, i.amount_paid_msat, i.is_amp, i.is_hodl, i.is_keysend, i.created_at
FROM invoices i
INNER JOIN amp_sub_invoices a
ON i.id = a.invoice_id AND a.set_id = $1
`
func (q *Queries) GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error) {
rows, err := q.db.QueryContext(ctx, getInvoiceBySetID, setID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Invoice
for rows.Next() {
var i Invoice
if err := rows.Scan(
&i.ID,
&i.Hash,
&i.Preimage,
&i.SettleIndex,
&i.SettledAt,
&i.Memo,
&i.AmountMsat,
&i.CltvDelta,
&i.Expiry,
&i.PaymentAddr,
&i.PaymentRequest,
&i.PaymentRequestHash,
&i.State,
&i.AmountPaidMsat,
&i.IsAmp,
&i.IsHodl,
&i.IsKeysend,
&i.CreatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getInvoiceFeatures = `-- name: GetInvoiceFeatures :many const getInvoiceFeatures = `-- name: GetInvoiceFeatures :many
SELECT feature, invoice_id SELECT feature, invoice_id
FROM invoice_features FROM invoice_features

View File

@@ -21,6 +21,7 @@ type Querier interface {
// from different invoices. It is the caller's responsibility to ensure that // from different invoices. It is the caller's responsibility to ensure that
// we bubble up an error in those cases. // we bubble up an error in those cases.
GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error) GetInvoice(ctx context.Context, arg GetInvoiceParams) ([]Invoice, error)
GetInvoiceBySetID(ctx context.Context, setID []byte) ([]Invoice, error)
GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]InvoiceFeature, error) GetInvoiceFeatures(ctx context.Context, invoiceID int64) ([]InvoiceFeature, error)
GetInvoiceHTLCCustomRecords(ctx context.Context, invoiceID int64) ([]GetInvoiceHTLCCustomRecordsRow, error) GetInvoiceHTLCCustomRecords(ctx context.Context, invoiceID int64) ([]GetInvoiceHTLCCustomRecordsRow, error)
GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]InvoiceHtlc, error) GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]InvoiceHtlc, error)

View File

@@ -61,7 +61,7 @@ WHERE (
-- name: UpdateAMPSubInvoiceHTLCPreimage :execresult -- name: UpdateAMPSubInvoiceHTLCPreimage :execresult
UPDATE amp_sub_invoice_htlcs AS a UPDATE amp_sub_invoice_htlcs AS a
SET preimage = $4 SET preimage = $5
WHERE a.invoice_id = $1 AND a.set_id = $2 AND a.htlc_id = ( WHERE a.invoice_id = $1 AND a.set_id = $2 AND a.htlc_id = (
SELECT id FROM invoice_htlcs AS i WHERE i.htlc_id = $3 SELECT id FROM invoice_htlcs AS i WHERE i.chan_id = $3 AND i.htlc_id = $4
); );

View File

@@ -26,7 +26,11 @@ WHERE invoice_id = $1;
-- name: GetInvoice :many -- name: GetInvoice :many
SELECT i.* SELECT i.*
FROM invoices i FROM invoices i
LEFT JOIN amp_sub_invoices a on i.id = a.invoice_id LEFT JOIN amp_sub_invoices a
ON i.id = a.invoice_id
AND (
a.set_id = sqlc.narg('set_id') OR sqlc.narg('set_id') IS NULL
)
WHERE ( WHERE (
i.id = sqlc.narg('add_index') OR i.id = sqlc.narg('add_index') OR
sqlc.narg('add_index') IS NULL sqlc.narg('add_index') IS NULL
@@ -39,13 +43,16 @@ WHERE (
) AND ( ) AND (
i.payment_addr = sqlc.narg('payment_addr') OR i.payment_addr = sqlc.narg('payment_addr') OR
sqlc.narg('payment_addr') IS NULL sqlc.narg('payment_addr') IS NULL
) AND (
a.set_id = sqlc.narg('set_id') OR
sqlc.narg('set_id') IS NULL
) )
GROUP BY i.id GROUP BY i.id
LIMIT 2; LIMIT 2;
-- name: GetInvoiceBySetID :many
SELECT i.*
FROM invoices i
INNER JOIN amp_sub_invoices a
ON i.id = a.invoice_id AND a.set_id = $1;
-- name: FilterInvoices :many -- name: FilterInvoices :many
SELECT SELECT
invoices.* invoices.*
@@ -69,7 +76,7 @@ WHERE (
created_at >= sqlc.narg('created_after') OR created_at >= sqlc.narg('created_after') OR
sqlc.narg('created_after') IS NULL sqlc.narg('created_after') IS NULL
) AND ( ) AND (
created_at <= sqlc.narg('created_before') OR created_at < sqlc.narg('created_before') OR
sqlc.narg('created_before') IS NULL sqlc.narg('created_before') IS NULL
) AND ( ) AND (
CASE CASE