sqldb+invoices: move SQL invoice store impl to invoices package

This commit is contained in:
Andras Banki-Horvath
2024-03-29 10:54:00 +01:00
parent 6d316ef56f
commit 7f5c8219ef
4 changed files with 169 additions and 165 deletions

View File

@ -1053,12 +1053,12 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
executor := sqldb.NewTransactionExecutor(
dbs.NativeSQLStore,
func(tx *sql.Tx) sqldb.InvoiceQueries {
func(tx *sql.Tx) invoices.SQLInvoiceQueries {
return dbs.NativeSQLStore.WithTx(tx)
},
)
dbs.InvoiceDB = sqldb.NewInvoiceStore(
dbs.InvoiceDB = invoices.NewSQLStore(
executor, clock.NewDefaultClock(),
)
} else {

View File

@ -136,14 +136,14 @@ func TestInvoiceRegistry(t *testing.T) {
}
executor := sqldb.NewTransactionExecutor(
db, func(tx *sql.Tx) sqldb.InvoiceQueries {
db, func(tx *sql.Tx) invpkg.SQLInvoiceQueries {
return db.WithTx(tx)
},
)
testClock := clock.NewTestClock(testNow)
return sqldb.NewInvoiceStore(executor, testClock), testClock
return invpkg.NewSQLStore(executor, testClock), testClock
}
for _, test := range testList {

View File

@ -240,14 +240,14 @@ func TestInvoices(t *testing.T) {
}
executor := sqldb.NewTransactionExecutor(
db, func(tx *sql.Tx) sqldb.InvoiceQueries {
db, func(tx *sql.Tx) invpkg.SQLInvoiceQueries {
return db.WithTx(tx)
},
)
testClock := clock.NewTestClock(testNow)
return sqldb.NewInvoiceStore(executor, testClock)
return invpkg.NewSQLStore(executor, testClock)
}
for _, test := range testList {

View File

@ -1,4 +1,4 @@
package sqldb
package invoices
import (
"context"
@ -12,10 +12,10 @@ import (
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/clock"
invpkg "github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/sqldb"
"github.com/lightningnetwork/lnd/sqldb/sqlc"
)
@ -25,9 +25,9 @@ const (
queryPaginationLimit = 100
)
// InvoiceQueries is an interface that defines the set of operations that can be
// executed against the invoice database.
type InvoiceQueries interface { //nolint:interfacebloat
// SQLInvoiceQueries is an interface that defines the set of operations that can
// be executed against the invoice SQL database.
type SQLInvoiceQueries interface { //nolint:interfacebloat
InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
error)
@ -117,11 +117,11 @@ type InvoiceQueries interface { //nolint:interfacebloat
arg sqlc.OnAMPSubInvoiceSettledParams) error
}
var _ invpkg.InvoiceDB = (*InvoiceStore)(nil)
var _ InvoiceDB = (*SQLStore)(nil)
// InvoiceQueriesTxOptions defines the set of db txn options the InvoiceQueries
// understands.
type InvoiceQueriesTxOptions struct {
// SQLInvoiceQueriesTxOptions defines the set of db txn options the
// SQLInvoiceQueries understands.
type SQLInvoiceQueriesTxOptions struct {
// readOnly governs if a read only transaction is needed or not.
readOnly bool
}
@ -129,37 +129,37 @@ type InvoiceQueriesTxOptions struct {
// ReadOnly returns true if the transaction should be read only.
//
// NOTE: This implements the TxOptions.
func (a *InvoiceQueriesTxOptions) ReadOnly() bool {
func (a *SQLInvoiceQueriesTxOptions) ReadOnly() bool {
return a.readOnly
}
// NewInvoiceQueryReadTx creates a new read transaction option set.
func NewInvoiceQueryReadTx() InvoiceQueriesTxOptions {
return InvoiceQueriesTxOptions{
// NewSQLInvoiceQueryReadTx creates a new read transaction option set.
func NewSQLInvoiceQueryReadTx() SQLInvoiceQueriesTxOptions {
return SQLInvoiceQueriesTxOptions{
readOnly: true,
}
}
// BatchedInvoiceQueries is a version of the InvoiceQueries that's capable of
// batched database operations.
type BatchedInvoiceQueries interface {
InvoiceQueries
// BatchedSQLInvoiceQueries is a version of the SQLInvoiceQueries that's capable
// of batched database operations.
type BatchedSQLInvoiceQueries interface {
SQLInvoiceQueries
BatchedTx[InvoiceQueries]
sqldb.BatchedTx[SQLInvoiceQueries]
}
// InvoiceStore represents a storage backend.
type InvoiceStore struct {
db BatchedInvoiceQueries
// SQLStore represents a storage backend.
type SQLStore struct {
db BatchedSQLInvoiceQueries
clock clock.Clock
}
// NewInvoiceStore creates a new InvoiceStore instance given a open
// BatchedInvoiceQueries storage backend.
func NewInvoiceStore(db BatchedInvoiceQueries,
clock clock.Clock) *InvoiceStore {
// NewSQLStore creates a new SQLStore instance given a open
// BatchedSQLInvoiceQueries storage backend.
func NewSQLStore(db BatchedSQLInvoiceQueries,
clock clock.Clock) *SQLStore {
return &InvoiceStore{
return &SQLStore{
db: db,
clock: clock,
}
@ -171,17 +171,17 @@ func NewInvoiceStore(db BatchedInvoiceQueries,
// duplicate payment hashes.
//
// NOTE: A side effect of this function is that it sets AddIndex on newInvoice.
func (i *InvoiceStore) AddInvoice(ctx context.Context,
newInvoice *invpkg.Invoice, paymentHash lntypes.Hash) (uint64, error) {
func (i *SQLStore) AddInvoice(ctx context.Context,
newInvoice *Invoice, paymentHash lntypes.Hash) (uint64, error) {
// Make sure this is a valid invoice before trying to store it in our
// DB.
if err := invpkg.ValidateInvoice(newInvoice, paymentHash); err != nil {
if err := ValidateInvoice(newInvoice, paymentHash); err != nil {
return 0, err
}
var (
writeTxOpts InvoiceQueriesTxOptions
writeTxOpts SQLInvoiceQueriesTxOptions
invoiceID int64
)
@ -193,16 +193,18 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context,
paymentRequestHash = h.Sum(nil)
}
err := i.db.ExecTx(ctx, &writeTxOpts, func(db InvoiceQueries) error {
err := i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
params := sqlc.InsertInvoiceParams{
Hash: paymentHash[:],
Memo: SQLStr(string(newInvoice.Memo)),
Memo: sqldb.SQLStr(string(newInvoice.Memo)),
AmountMsat: int64(newInvoice.Terms.Value),
// Note: BOLT12 invoices don't have a final cltv delta.
CltvDelta: SQLInt32(newInvoice.Terms.FinalCltvDelta),
Expiry: int32(newInvoice.Terms.Expiry),
CltvDelta: sqldb.SQLInt32(
newInvoice.Terms.FinalCltvDelta,
),
Expiry: int32(newInvoice.Terms.Expiry),
// Note: keysend invoices don't have a payment request.
PaymentRequest: SQLStr(string(
PaymentRequest: sqldb.SQLStr(string(
newInvoice.PaymentRequest),
),
PaymentRequestHash: paymentRequestHash,
@ -218,7 +220,7 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context,
// HODL invoices.
if newInvoice.Terms.PaymentPreimage != nil {
preimage := *newInvoice.Terms.PaymentPreimage
if preimage == invpkg.UnknownPreimage {
if preimage == UnknownPreimage {
return errors.New("cannot use all-zeroes " +
"preimage")
}
@ -226,7 +228,7 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context,
}
// Some non MPP payments may have the default (invalid) value.
if newInvoice.Terms.PaymentAddr != invpkg.BlankPayAddr {
if newInvoice.Terms.PaymentAddr != BlankPayAddr {
params.PaymentAddr = newInvoice.Terms.PaymentAddr[:]
}
@ -259,11 +261,11 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context,
})
})
if err != nil {
mappedSQLErr := MapSQLError(err)
var uniqueConstraintErr *ErrSQLUniqueConstraintViolation
mappedSQLErr := sqldb.MapSQLError(err)
var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation
if errors.As(mappedSQLErr, &uniqueConstraintErr) {
// Add context to unique constraint errors.
return 0, invpkg.ErrDuplicateInvoice
return 0, ErrDuplicateInvoice
}
return 0, fmt.Errorf("unable to add invoice(%v): %w",
@ -277,15 +279,15 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context,
// fetchInvoice fetches the common invoice data and the AMP state for the
// invoice with the given reference.
func (i *InvoiceStore) fetchInvoice(ctx context.Context,
db InvoiceQueries, ref invpkg.InvoiceRef) (*invpkg.Invoice, error) {
func (i *SQLStore) fetchInvoice(ctx context.Context,
db SQLInvoiceQueries, ref InvoiceRef) (*Invoice, error) {
if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
return nil, invpkg.ErrInvoiceNotFound
return nil, ErrInvoiceNotFound
}
var (
invoice *invpkg.Invoice
invoice *Invoice
params sqlc.GetInvoiceParams
)
@ -300,7 +302,7 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context,
// all. Only allow lookups for payment address if it is not a blank
// payment address, which is a special-cased value for legacy keysend
// invoices.
if ref.PayAddr() != nil && *ref.PayAddr() != invpkg.BlankPayAddr {
if ref.PayAddr() != nil && *ref.PayAddr() != BlankPayAddr {
params.PaymentAddr = ref.PayAddr()[:]
}
@ -313,7 +315,7 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context,
rows, err := db.GetInvoice(ctx, params)
switch {
case len(rows) == 0:
return nil, invpkg.ErrInvoiceNotFound
return nil, ErrInvoiceNotFound
case len(rows) > 1:
// In case the reference is ambiguous, meaning it matches more
@ -333,12 +335,12 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context,
// Now that we got the invoice itself, fetch the HTLCs as requested by
// the modifier.
switch ref.Modifier() {
case invpkg.DefaultModifier:
case DefaultModifier:
// By default we'll fetch all AMP HTLCs.
setID = nil
fetchAmpHtlcs = true
case invpkg.HtlcSetOnlyModifier:
case HtlcSetOnlyModifier:
// In this case we'll fetch all AMP HTLCs for the
// specified set id.
if ref.SetID() == nil {
@ -349,7 +351,7 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context,
setID = ref.SetID()
fetchAmpHtlcs = true
case invpkg.HtlcSetBlankModifier:
case HtlcSetBlankModifier:
// No need to fetch any HTLCs.
setID = nil
fetchAmpHtlcs = false
@ -377,9 +379,9 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context,
// well.
//
//nolint:funlen
func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
setID *[32]byte, fetchHtlcs bool) (invpkg.AMPInvoiceState,
invpkg.HTLCSet, error) {
func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
setID *[32]byte, fetchHtlcs bool) (AMPInvoiceState,
HTLCSet, error) {
var paramSetID []byte
if setID != nil {
@ -398,7 +400,7 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
return nil, nil, err
}
ampState := make(map[invpkg.SetID]invpkg.InvoiceStateAMP)
ampState := make(map[SetID]InvoiceStateAMP)
for _, row := range ampInvoiceRows {
var rowSetID [32]byte
@ -413,8 +415,8 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
}
copy(rowSetID[:], row.SetID)
ampState[rowSetID] = invpkg.InvoiceStateAMP{
State: invpkg.HtlcState(row.State),
ampState[rowSetID] = InvoiceStateAMP{
State: HtlcState(row.State),
SettleIndex: uint64(row.SettleIndex.Int64),
SettleDate: settleDate,
InvoiceKeys: make(map[models.CircuitKey]struct{}),
@ -457,7 +459,7 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
return nil, nil, err
}
ampHtlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC)
ampHtlcs := make(map[models.CircuitKey]*InvoiceHTLC)
for _, row := range ampHtlcRows {
uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
if err != nil {
@ -473,17 +475,17 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
htlcID := uint64(row.HtlcID)
circuitKey := invpkg.CircuitKey{
circuitKey := CircuitKey{
ChanID: chanID,
HtlcID: htlcID,
}
htlc := &invpkg.InvoiceHTLC{
htlc := &InvoiceHTLC{
Amt: lnwire.MilliSatoshi(row.AmountMsat),
AcceptHeight: uint32(row.AcceptHeight),
AcceptTime: row.AcceptTime.Local(),
Expiry: uint32(row.ExpiryHeight),
State: invpkg.HtlcState(row.State),
State: HtlcState(row.State),
}
if row.TotalMppMsat.Valid {
@ -522,7 +524,7 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
rootShare, setID, uint32(row.ChildIndex),
)
htlc.AMP = &invpkg.InvoiceHtlcAMPData{
htlc.AMP = &InvoiceHtlcAMPData{
Record: *ampRecord,
}
@ -564,7 +566,7 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
invoiceKeys[key] = struct{}{}
if htlc.State != invpkg.HtlcStateCanceled { //nolint: lll
if htlc.State != HtlcStateCanceled { //nolint: lll
amtPaid += htlc.Amt
}
}
@ -584,22 +586,22 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
// ID for an AMP sub invoice. If the invoice is found, we'll return the complete
// invoice. If the invoice is not found, then we'll return an ErrInvoiceNotFound
// error.
func (i *InvoiceStore) LookupInvoice(ctx context.Context,
ref invpkg.InvoiceRef) (invpkg.Invoice, error) {
func (i *SQLStore) LookupInvoice(ctx context.Context,
ref InvoiceRef) (Invoice, error) {
var (
invoice *invpkg.Invoice
invoice *Invoice
err error
)
readTxOpt := NewInvoiceQueryReadTx()
txErr := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error {
readTxOpt := NewSQLInvoiceQueryReadTx()
txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
invoice, err = i.fetchInvoice(ctx, db, ref)
return err
})
if txErr != nil {
return invpkg.Invoice{}, txErr
return Invoice{}, txErr
}
return *invoice, nil
@ -608,14 +610,14 @@ func (i *InvoiceStore) LookupInvoice(ctx context.Context,
// FetchPendingInvoices returns all the invoices that are currently in a
// "pending" state. An invoice is pending if it has been created but not yet
// settled or canceled.
func (i *InvoiceStore) FetchPendingInvoices(ctx context.Context) (
map[lntypes.Hash]invpkg.Invoice, error) {
func (i *SQLStore) FetchPendingInvoices(ctx context.Context) (
map[lntypes.Hash]Invoice, error) {
var invoices map[lntypes.Hash]invpkg.Invoice
var invoices map[lntypes.Hash]Invoice
readTxOpt := NewInvoiceQueryReadTx()
err := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error {
invoices = make(map[lntypes.Hash]invpkg.Invoice)
readTxOpt := NewSQLInvoiceQueryReadTx()
err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
invoices = make(map[lntypes.Hash]Invoice)
limit := queryPaginationLimit
return queryWithLimit(func(offset int) (int, error) {
@ -661,24 +663,24 @@ func (i *InvoiceStore) FetchPendingInvoices(ctx context.Context) (
//
// NOTE: The index starts from 1. As a result we enforce that specifying a value
// below the starting index value is a noop.
func (i *InvoiceStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
[]invpkg.Invoice, error) {
func (i *SQLStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
[]Invoice, error) {
var invoices []invpkg.Invoice
var invoices []Invoice
if idx == 0 {
return invoices, nil
}
readTxOpt := NewInvoiceQueryReadTx()
err := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error {
readTxOpt := NewSQLInvoiceQueryReadTx()
err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
invoices = nil
settleIdx := idx
limit := queryPaginationLimit
err := queryWithLimit(func(offset int) (int, error) {
params := sqlc.FilterInvoicesParams{
SettleIndexGet: SQLInt64(settleIdx + 1),
SettleIndexGet: sqldb.SQLInt64(settleIdx + 1),
NumLimit: int32(limit),
NumOffset: int32(offset),
}
@ -715,7 +717,7 @@ func (i *InvoiceStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
// the provided index.
ampInvoices, err := i.db.FetchSettledAMPSubInvoices(
ctx, sqlc.FetchSettledAMPSubInvoicesParams{
SettleIndexGet: SQLInt64(idx + 1),
SettleIndexGet: sqldb.SQLInt64(idx + 1),
},
)
if err != nil {
@ -775,24 +777,24 @@ func (i *InvoiceStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
//
// NOTE: The index starts from 1. As a result we enforce that specifying a value
// below the starting index value is a noop.
func (i *InvoiceStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
[]invpkg.Invoice, error) {
func (i *SQLStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
[]Invoice, error) {
var result []invpkg.Invoice
var result []Invoice
if idx == 0 {
return result, nil
}
readTxOpt := NewInvoiceQueryReadTx()
err := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error {
readTxOpt := NewSQLInvoiceQueryReadTx()
err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
result = nil
addIdx := idx
limit := queryPaginationLimit
return queryWithLimit(func(offset int) (int, error) {
params := sqlc.FilterInvoicesParams{
AddIndexGet: SQLInt64(addIdx + 1),
AddIndexGet: sqldb.SQLInt64(addIdx + 1),
NumLimit: int32(limit),
NumOffset: int32(offset),
}
@ -831,18 +833,18 @@ func (i *InvoiceStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
// QueryInvoices allows a caller to query the invoice database for invoices
// within the specified add index range.
func (i *InvoiceStore) QueryInvoices(ctx context.Context,
q invpkg.InvoiceQuery) (invpkg.InvoiceSlice, error) {
func (i *SQLStore) QueryInvoices(ctx context.Context,
q InvoiceQuery) (InvoiceSlice, error) {
var invoices []invpkg.Invoice
var invoices []Invoice
if q.NumMaxInvoices == 0 {
return invpkg.InvoiceSlice{}, fmt.Errorf("max invoices must " +
return InvoiceSlice{}, fmt.Errorf("max invoices must " +
"be non-zero")
}
readTxOpt := NewInvoiceQueryReadTx()
err := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error {
readTxOpt := NewSQLInvoiceQueryReadTx()
err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
invoices = nil
limit := queryPaginationLimit
@ -856,7 +858,7 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
if !q.Reversed {
// The invoice with index offset id must not be
// included in the results.
params.AddIndexGet = SQLInt64(
params.AddIndexGet = sqldb.SQLInt64(
q.IndexOffset + uint64(offset) + 1,
)
}
@ -867,13 +869,13 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
// If the index offset was not set, we want to
// fetch from the lastest invoice.
if idx == 0 {
params.AddIndexLet = SQLInt64(
params.AddIndexLet = sqldb.SQLInt64(
int64(math.MaxInt64),
)
} else {
// The invoice with index offset id must
// not be included in the results.
params.AddIndexLet = SQLInt64(
params.AddIndexLet = sqldb.SQLInt64(
idx - int32(offset) - 1,
)
}
@ -882,13 +884,13 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
}
if q.CreationDateStart != 0 {
params.CreatedAfter = SQLTime(
params.CreatedAfter = sqldb.SQLTime(
time.Unix(q.CreationDateStart, 0).UTC(),
)
}
if q.CreationDateEnd != 0 {
params.CreatedBefore = SQLTime(
params.CreatedBefore = sqldb.SQLTime(
time.Unix(q.CreationDateEnd, 0).UTC(),
)
}
@ -919,12 +921,12 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
}, limit)
})
if err != nil {
return invpkg.InvoiceSlice{}, fmt.Errorf("unable to query "+
return InvoiceSlice{}, fmt.Errorf("unable to query "+
"invoices: %w", err)
}
if len(invoices) == 0 {
return invpkg.InvoiceSlice{
return InvoiceSlice{
InvoiceQuery: q,
}, nil
}
@ -941,7 +943,7 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
}
}
res := invpkg.InvoiceSlice{
res := InvoiceSlice{
InvoiceQuery: q,
Invoices: invoices,
FirstIndexOffset: invoices[0].AddIndex,
@ -954,15 +956,15 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
// sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using
// a SQL database as the backend.
type sqlInvoiceUpdater struct {
db InvoiceQueries
db SQLInvoiceQueries
ctx context.Context //nolint:containedctx
invoice *invpkg.Invoice
invoice *Invoice
updateTime time.Time
}
// AddHtlc adds a new htlc to the invoice.
func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
newHtlc *invpkg.InvoiceHTLC) error {
newHtlc *InvoiceHTLC) error {
htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC(
s.ctx, sqlc.InsertInvoiceHTLCParams{
@ -1012,10 +1014,10 @@ func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
},
)
if err != nil {
mappedSQLErr := MapSQLError(err)
var uniqueConstraintErr *ErrSQLUniqueConstraintViolation
mappedSQLErr := sqldb.MapSQLError(err)
var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation //nolint:lll
if errors.As(mappedSQLErr, &uniqueConstraintErr) {
return invpkg.ErrDuplicateSetID{
return ErrDuplicateSetID{
SetID: setID,
}
}
@ -1072,7 +1074,7 @@ func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
// ResolveHtlc marks an htlc as resolved with the given state.
func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
state invpkg.HtlcState, resolveTime time.Time) error {
state HtlcState, resolveTime time.Time) error {
return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{
HtlcID: int64(circuitKey.HtlcID),
@ -1081,7 +1083,7 @@ func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
),
InvoiceID: int64(s.invoice.AddIndex),
State: int16(state),
ResolveTime: SQLTime(resolveTime.UTC()),
ResolveTime: sqldb.SQLTime(resolveTime.UTC()),
})
}
@ -1107,7 +1109,7 @@ func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
return err
}
if rowsAffected == 0 {
return invpkg.ErrInvoiceNotFound
return ErrInvoiceNotFound
}
return nil
@ -1115,7 +1117,7 @@ func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
// UpdateInvoiceState updates the invoice state to the new state.
func (s *sqlInvoiceUpdater) UpdateInvoiceState(
newState invpkg.ContractState, preimage *lntypes.Preimage) error {
newState ContractState, preimage *lntypes.Preimage) error {
var (
settleIndex sql.NullInt64
@ -1123,16 +1125,16 @@ func (s *sqlInvoiceUpdater) UpdateInvoiceState(
)
switch newState {
case invpkg.ContractSettled:
case ContractSettled:
nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
if err != nil {
return err
}
settleIndex = SQLInt64(nextSettleIndex)
settleIndex = sqldb.SQLInt64(nextSettleIndex)
// If the invoice is settled, we'll also update the settle time.
settledAt = SQLTime(s.updateTime.UTC())
settledAt = sqldb.SQLTime(s.updateTime.UTC())
err = s.db.OnInvoiceSettled(
s.ctx, sqlc.OnInvoiceSettledParams{
@ -1144,7 +1146,7 @@ func (s *sqlInvoiceUpdater) UpdateInvoiceState(
return err
}
case invpkg.ContractCanceled:
case ContractCanceled:
err := s.db.OnInvoiceCanceled(
s.ctx, sqlc.OnInvoiceCanceledParams{
AddedAt: s.updateTime.UTC(),
@ -1177,7 +1179,7 @@ func (s *sqlInvoiceUpdater) UpdateInvoiceState(
}
if rowsAffected == 0 {
return invpkg.ErrInvoiceNotFound
return ErrInvoiceNotFound
}
if settleIndex.Valid {
@ -1205,7 +1207,7 @@ func (s *sqlInvoiceUpdater) UpdateInvoiceAmtPaid(
// UpdateAmpState updates the state of the AMP sub invoice identified by the
// setID.
func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
newState invpkg.InvoiceStateAMP, _ models.CircuitKey) error {
newState InvoiceStateAMP, _ models.CircuitKey) error {
var (
settleIndex sql.NullInt64
@ -1213,16 +1215,16 @@ func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
)
switch newState.State {
case invpkg.HtlcStateSettled:
case HtlcStateSettled:
nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
if err != nil {
return err
}
settleIndex = SQLInt64(nextSettleIndex)
settleIndex = sqldb.SQLInt64(nextSettleIndex)
// If the invoice is settled, we'll also update the settle time.
settledAt = SQLTime(s.updateTime.UTC())
settledAt = sqldb.SQLTime(s.updateTime.UTC())
err = s.db.OnAMPSubInvoiceSettled(
s.ctx, sqlc.OnAMPSubInvoiceSettledParams{
@ -1235,7 +1237,7 @@ func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
return err
}
case invpkg.HtlcStateCanceled:
case HtlcStateCanceled:
err := s.db.OnAMPSubInvoiceCanceled(
s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{
AddedAt: s.updateTime.UTC(),
@ -1266,7 +1268,7 @@ func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
// Finalize finalizes the update before it is written to the database. Note that
// we don't use this directly in the SQL implementation, so the function is just
// a stub.
func (s *sqlInvoiceUpdater) Finalize(_ invpkg.UpdateType) error {
func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
return nil
}
@ -1277,14 +1279,14 @@ func (s *sqlInvoiceUpdater) Finalize(_ invpkg.UpdateType) error {
// The update is performed inside the same database transaction that fetches the
// invoice and is therefore atomic. The fields to update are controlled by the
// supplied callback.
func (i *InvoiceStore) UpdateInvoice(ctx context.Context, ref invpkg.InvoiceRef,
_ *invpkg.SetID, callback invpkg.InvoiceUpdateCallback) (
*invpkg.Invoice, error) {
func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
_ *SetID, callback InvoiceUpdateCallback) (
*Invoice, error) {
var updatedInvoice *invpkg.Invoice
var updatedInvoice *Invoice
txOpt := InvoiceQueriesTxOptions{readOnly: false}
txErr := i.db.ExecTx(ctx, &txOpt, func(db InvoiceQueries) error {
txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
invoice, err := i.fetchInvoice(ctx, db, ref)
if err != nil {
return err
@ -1299,7 +1301,7 @@ func (i *InvoiceStore) UpdateInvoice(ctx context.Context, ref invpkg.InvoiceRef,
}
payHash := ref.PayHash()
updatedInvoice, err = invpkg.UpdateInvoice(
updatedInvoice, err = UpdateInvoice(
payHash, invoice, updateTime, callback, updater,
)
@ -1308,7 +1310,7 @@ func (i *InvoiceStore) UpdateInvoice(ctx context.Context, ref invpkg.InvoiceRef,
if txErr != nil {
// If the invoice is already settled, we'll return the
// (unchanged) invoice and the ErrInvoiceAlreadySettled error.
if errors.Is(txErr, invpkg.ErrInvoiceAlreadySettled) {
if errors.Is(txErr, ErrInvoiceAlreadySettled) {
return updatedInvoice, txErr
}
@ -1320,8 +1322,8 @@ func (i *InvoiceStore) UpdateInvoice(ctx context.Context, ref invpkg.InvoiceRef,
// DeleteInvoice attempts to delete the passed invoices and all their related
// data from the database in one transaction.
func (i *InvoiceStore) DeleteInvoice(ctx context.Context,
invoicesToDelete []invpkg.InvoiceDeleteRef) error {
func (i *SQLStore) DeleteInvoice(ctx context.Context,
invoicesToDelete []InvoiceDeleteRef) error {
// All the InvoiceDeleteRef instances include the add index of the
// invoice. The rest was added to ensure that the invoices were deleted
@ -1334,15 +1336,17 @@ func (i *InvoiceStore) DeleteInvoice(ctx context.Context,
}
}
var writeTxOpt InvoiceQueriesTxOptions
err := i.db.ExecTx(ctx, &writeTxOpt, func(db InvoiceQueries) error {
var writeTxOpt SQLInvoiceQueriesTxOptions
err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
for _, ref := range invoicesToDelete {
params := sqlc.DeleteInvoiceParams{
AddIndex: SQLInt64(ref.AddIndex),
AddIndex: sqldb.SQLInt64(ref.AddIndex),
}
if ref.SettleIndex != 0 {
params.SettleIndex = SQLInt64(ref.SettleIndex)
params.SettleIndex = sqldb.SQLInt64(
ref.SettleIndex,
)
}
if ref.PayHash != lntypes.ZeroHash {
@ -1361,7 +1365,7 @@ func (i *InvoiceStore) DeleteInvoice(ctx context.Context,
}
if rowsAffected == 0 {
return fmt.Errorf("%w: %v",
invpkg.ErrInvoiceNotFound, ref.AddIndex)
ErrInvoiceNotFound, ref.AddIndex)
}
}
@ -1376,9 +1380,9 @@ func (i *InvoiceStore) DeleteInvoice(ctx context.Context,
}
// DeleteCanceledInvoices removes all canceled invoices from the database.
func (i *InvoiceStore) DeleteCanceledInvoices(ctx context.Context) error {
var writeTxOpt InvoiceQueriesTxOptions
err := i.db.ExecTx(ctx, &writeTxOpt, func(db InvoiceQueries) error {
func (i *SQLStore) DeleteCanceledInvoices(ctx context.Context) error {
var writeTxOpt SQLInvoiceQueriesTxOptions
err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
_, err := db.DeleteCanceledInvoices(ctx)
if err != nil {
return fmt.Errorf("unable to delete canceled "+
@ -1398,9 +1402,9 @@ func (i *InvoiceStore) DeleteCanceledInvoices(ctx context.Context) error {
// invoice is AMP and the setID is not nil, then it will also fetch the AMP
// state and HTLCs for the given setID, otherwise for all AMP sub invoices of
// the invoice. If fetchAmpHtlcs is true, it will also fetch the AMP HTLCs.
func fetchInvoiceData(ctx context.Context, db InvoiceQueries,
func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash,
*invpkg.Invoice, error) {
*Invoice, error) {
// Unmarshal the common data.
hash, invoice, err := unmarshalInvoice(row)
@ -1444,7 +1448,7 @@ func fetchInvoiceData(ctx context.Context, db InvoiceQueries,
invoice.Htlcs = htlcs
var amountPaid lnwire.MilliSatoshi
for _, htlc := range htlcs {
if htlc.State == invpkg.HtlcStateSettled {
if htlc.State == HtlcStateSettled {
amountPaid += htlc.Amt
}
}
@ -1455,7 +1459,7 @@ func fetchInvoiceData(ctx context.Context, db InvoiceQueries,
}
// getInvoiceFeatures fetches the invoice features for the given invoice id.
func getInvoiceFeatures(ctx context.Context, db InvoiceQueries,
func getInvoiceFeatures(ctx context.Context, db SQLInvoiceQueries,
invoiceID int64) (*lnwire.FeatureVector, error) {
rows, err := db.GetInvoiceFeatures(ctx, invoiceID)
@ -1473,8 +1477,8 @@ func getInvoiceFeatures(ctx context.Context, db InvoiceQueries,
}
// getInvoiceHtlcs fetches the invoice htlcs for the given invoice id.
func getInvoiceHtlcs(ctx context.Context, db InvoiceQueries,
invoiceID int64) (map[invpkg.CircuitKey]*invpkg.InvoiceHTLC, error) {
func getInvoiceHtlcs(ctx context.Context, db SQLInvoiceQueries,
invoiceID int64) (map[CircuitKey]*InvoiceHTLC, error) {
htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID)
if err != nil {
@ -1505,7 +1509,7 @@ func getInvoiceHtlcs(ctx context.Context, db InvoiceQueries,
cr[row.HtlcID][uint64(row.Key)] = value
}
htlcs := make(map[invpkg.CircuitKey]*invpkg.InvoiceHTLC, len(htlcRows))
htlcs := make(map[CircuitKey]*InvoiceHTLC, len(htlcRows))
for _, row := range htlcRows {
circuiteKey, htlc, err := unmarshalInvoiceHTLC(row)
@ -1527,7 +1531,7 @@ func getInvoiceHtlcs(ctx context.Context, db InvoiceQueries,
}
// unmarshalInvoice converts an InvoiceRow to an Invoice.
func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *invpkg.Invoice,
func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice,
error) {
var (
@ -1576,13 +1580,13 @@ func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *invpkg.Invoice,
cltvDelta = row.CltvDelta.Int32
}
invoice := &invpkg.Invoice{
invoice := &Invoice{
SettleIndex: uint64(settleIndex),
SettleDate: settledAt,
Memo: memo,
PaymentRequest: paymentRequest,
CreationDate: row.CreatedAt.Local(),
Terms: invpkg.ContractTerm{
Terms: ContractTerm{
FinalCltvDelta: cltvDelta,
Expiry: time.Duration(row.Expiry),
PaymentPreimage: preimage,
@ -1590,10 +1594,10 @@ func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *invpkg.Invoice,
PaymentAddr: paymentAddr,
},
AddIndex: uint64(row.ID),
State: invpkg.ContractState(row.State),
State: ContractState(row.State),
AmtPaid: lnwire.MilliSatoshi(row.AmountPaidMsat),
Htlcs: make(map[models.CircuitKey]*invpkg.InvoiceHTLC),
AMPState: invpkg.AMPInvoiceState{},
Htlcs: make(map[models.CircuitKey]*InvoiceHTLC),
AMPState: AMPInvoiceState{},
HodlInvoice: row.IsHodl,
}
@ -1601,34 +1605,34 @@ func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *invpkg.Invoice,
}
// unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC.
func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (invpkg.CircuitKey,
*invpkg.InvoiceHTLC, error) {
func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (CircuitKey,
*InvoiceHTLC, error) {
uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
if err != nil {
return invpkg.CircuitKey{}, nil, err
return CircuitKey{}, nil, err
}
chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
if row.HtlcID < 0 {
return invpkg.CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+
return CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+
"value: %v", row.HtlcID)
}
htlcID := uint64(row.HtlcID)
circuitKey := invpkg.CircuitKey{
circuitKey := CircuitKey{
ChanID: chanID,
HtlcID: htlcID,
}
htlc := &invpkg.InvoiceHTLC{
htlc := &InvoiceHTLC{
Amt: lnwire.MilliSatoshi(row.AmountMsat),
AcceptHeight: uint32(row.AcceptHeight),
AcceptTime: row.AcceptTime.Local(),
Expiry: uint32(row.ExpiryHeight),
State: invpkg.HtlcState(row.State),
State: HtlcState(row.State),
}
if row.TotalMppMsat.Valid {