sqldb+invoices: move SQL invoice store impl to invoices package

This commit is contained in:
Andras Banki-Horvath
2024-03-29 10:54:00 +01:00
parent 6d316ef56f
commit 7f5c8219ef
4 changed files with 169 additions and 165 deletions

View File

@ -1053,12 +1053,12 @@ func (d *DefaultDatabaseBuilder) BuildDatabase(
executor := sqldb.NewTransactionExecutor( executor := sqldb.NewTransactionExecutor(
dbs.NativeSQLStore, dbs.NativeSQLStore,
func(tx *sql.Tx) sqldb.InvoiceQueries { func(tx *sql.Tx) invoices.SQLInvoiceQueries {
return dbs.NativeSQLStore.WithTx(tx) return dbs.NativeSQLStore.WithTx(tx)
}, },
) )
dbs.InvoiceDB = sqldb.NewInvoiceStore( dbs.InvoiceDB = invoices.NewSQLStore(
executor, clock.NewDefaultClock(), executor, clock.NewDefaultClock(),
) )
} else { } else {

View File

@ -136,14 +136,14 @@ func TestInvoiceRegistry(t *testing.T) {
} }
executor := sqldb.NewTransactionExecutor( executor := sqldb.NewTransactionExecutor(
db, func(tx *sql.Tx) sqldb.InvoiceQueries { db, func(tx *sql.Tx) invpkg.SQLInvoiceQueries {
return db.WithTx(tx) return db.WithTx(tx)
}, },
) )
testClock := clock.NewTestClock(testNow) testClock := clock.NewTestClock(testNow)
return sqldb.NewInvoiceStore(executor, testClock), testClock return invpkg.NewSQLStore(executor, testClock), testClock
} }
for _, test := range testList { for _, test := range testList {

View File

@ -240,14 +240,14 @@ func TestInvoices(t *testing.T) {
} }
executor := sqldb.NewTransactionExecutor( executor := sqldb.NewTransactionExecutor(
db, func(tx *sql.Tx) sqldb.InvoiceQueries { db, func(tx *sql.Tx) invpkg.SQLInvoiceQueries {
return db.WithTx(tx) return db.WithTx(tx)
}, },
) )
testClock := clock.NewTestClock(testNow) testClock := clock.NewTestClock(testNow)
return sqldb.NewInvoiceStore(executor, testClock) return invpkg.NewSQLStore(executor, testClock)
} }
for _, test := range testList { for _, test := range testList {

View File

@ -1,4 +1,4 @@
package sqldb package invoices
import ( import (
"context" "context"
@ -12,10 +12,10 @@ import (
"github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/clock"
invpkg "github.com/lightningnetwork/lnd/invoices"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/sqldb"
"github.com/lightningnetwork/lnd/sqldb/sqlc" "github.com/lightningnetwork/lnd/sqldb/sqlc"
) )
@ -25,9 +25,9 @@ const (
queryPaginationLimit = 100 queryPaginationLimit = 100
) )
// InvoiceQueries is an interface that defines the set of operations that can be // SQLInvoiceQueries is an interface that defines the set of operations that can
// executed against the invoice database. // be executed against the invoice SQL database.
type InvoiceQueries interface { //nolint:interfacebloat type SQLInvoiceQueries interface { //nolint:interfacebloat
InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64, InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64,
error) error)
@ -117,11 +117,11 @@ type InvoiceQueries interface { //nolint:interfacebloat
arg sqlc.OnAMPSubInvoiceSettledParams) error arg sqlc.OnAMPSubInvoiceSettledParams) error
} }
var _ invpkg.InvoiceDB = (*InvoiceStore)(nil) var _ InvoiceDB = (*SQLStore)(nil)
// InvoiceQueriesTxOptions defines the set of db txn options the InvoiceQueries // SQLInvoiceQueriesTxOptions defines the set of db txn options the
// understands. // SQLInvoiceQueries understands.
type InvoiceQueriesTxOptions struct { type SQLInvoiceQueriesTxOptions struct {
// readOnly governs if a read only transaction is needed or not. // readOnly governs if a read only transaction is needed or not.
readOnly bool readOnly bool
} }
@ -129,37 +129,37 @@ type InvoiceQueriesTxOptions struct {
// ReadOnly returns true if the transaction should be read only. // ReadOnly returns true if the transaction should be read only.
// //
// NOTE: This implements the TxOptions. // NOTE: This implements the TxOptions.
func (a *InvoiceQueriesTxOptions) ReadOnly() bool { func (a *SQLInvoiceQueriesTxOptions) ReadOnly() bool {
return a.readOnly return a.readOnly
} }
// NewInvoiceQueryReadTx creates a new read transaction option set. // NewSQLInvoiceQueryReadTx creates a new read transaction option set.
func NewInvoiceQueryReadTx() InvoiceQueriesTxOptions { func NewSQLInvoiceQueryReadTx() SQLInvoiceQueriesTxOptions {
return InvoiceQueriesTxOptions{ return SQLInvoiceQueriesTxOptions{
readOnly: true, readOnly: true,
} }
} }
// BatchedInvoiceQueries is a version of the InvoiceQueries that's capable of // BatchedSQLInvoiceQueries is a version of the SQLInvoiceQueries that's capable
// batched database operations. // of batched database operations.
type BatchedInvoiceQueries interface { type BatchedSQLInvoiceQueries interface {
InvoiceQueries SQLInvoiceQueries
BatchedTx[InvoiceQueries] sqldb.BatchedTx[SQLInvoiceQueries]
} }
// InvoiceStore represents a storage backend. // SQLStore represents a storage backend.
type InvoiceStore struct { type SQLStore struct {
db BatchedInvoiceQueries db BatchedSQLInvoiceQueries
clock clock.Clock clock clock.Clock
} }
// NewInvoiceStore creates a new InvoiceStore instance given a open // NewSQLStore creates a new SQLStore instance given a open
// BatchedInvoiceQueries storage backend. // BatchedSQLInvoiceQueries storage backend.
func NewInvoiceStore(db BatchedInvoiceQueries, func NewSQLStore(db BatchedSQLInvoiceQueries,
clock clock.Clock) *InvoiceStore { clock clock.Clock) *SQLStore {
return &InvoiceStore{ return &SQLStore{
db: db, db: db,
clock: clock, clock: clock,
} }
@ -171,17 +171,17 @@ func NewInvoiceStore(db BatchedInvoiceQueries,
// duplicate payment hashes. // duplicate payment hashes.
// //
// NOTE: A side effect of this function is that it sets AddIndex on newInvoice. // NOTE: A side effect of this function is that it sets AddIndex on newInvoice.
func (i *InvoiceStore) AddInvoice(ctx context.Context, func (i *SQLStore) AddInvoice(ctx context.Context,
newInvoice *invpkg.Invoice, paymentHash lntypes.Hash) (uint64, error) { newInvoice *Invoice, paymentHash lntypes.Hash) (uint64, error) {
// Make sure this is a valid invoice before trying to store it in our // Make sure this is a valid invoice before trying to store it in our
// DB. // DB.
if err := invpkg.ValidateInvoice(newInvoice, paymentHash); err != nil { if err := ValidateInvoice(newInvoice, paymentHash); err != nil {
return 0, err return 0, err
} }
var ( var (
writeTxOpts InvoiceQueriesTxOptions writeTxOpts SQLInvoiceQueriesTxOptions
invoiceID int64 invoiceID int64
) )
@ -193,16 +193,18 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context,
paymentRequestHash = h.Sum(nil) paymentRequestHash = h.Sum(nil)
} }
err := i.db.ExecTx(ctx, &writeTxOpts, func(db InvoiceQueries) error { err := i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error {
params := sqlc.InsertInvoiceParams{ params := sqlc.InsertInvoiceParams{
Hash: paymentHash[:], Hash: paymentHash[:],
Memo: SQLStr(string(newInvoice.Memo)), Memo: sqldb.SQLStr(string(newInvoice.Memo)),
AmountMsat: int64(newInvoice.Terms.Value), AmountMsat: int64(newInvoice.Terms.Value),
// Note: BOLT12 invoices don't have a final cltv delta. // Note: BOLT12 invoices don't have a final cltv delta.
CltvDelta: SQLInt32(newInvoice.Terms.FinalCltvDelta), CltvDelta: sqldb.SQLInt32(
Expiry: int32(newInvoice.Terms.Expiry), newInvoice.Terms.FinalCltvDelta,
),
Expiry: int32(newInvoice.Terms.Expiry),
// Note: keysend invoices don't have a payment request. // Note: keysend invoices don't have a payment request.
PaymentRequest: SQLStr(string( PaymentRequest: sqldb.SQLStr(string(
newInvoice.PaymentRequest), newInvoice.PaymentRequest),
), ),
PaymentRequestHash: paymentRequestHash, PaymentRequestHash: paymentRequestHash,
@ -218,7 +220,7 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context,
// HODL invoices. // HODL invoices.
if newInvoice.Terms.PaymentPreimage != nil { if newInvoice.Terms.PaymentPreimage != nil {
preimage := *newInvoice.Terms.PaymentPreimage preimage := *newInvoice.Terms.PaymentPreimage
if preimage == invpkg.UnknownPreimage { if preimage == UnknownPreimage {
return errors.New("cannot use all-zeroes " + return errors.New("cannot use all-zeroes " +
"preimage") "preimage")
} }
@ -226,7 +228,7 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context,
} }
// Some non MPP payments may have the default (invalid) value. // Some non MPP payments may have the default (invalid) value.
if newInvoice.Terms.PaymentAddr != invpkg.BlankPayAddr { if newInvoice.Terms.PaymentAddr != BlankPayAddr {
params.PaymentAddr = newInvoice.Terms.PaymentAddr[:] params.PaymentAddr = newInvoice.Terms.PaymentAddr[:]
} }
@ -259,11 +261,11 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context,
}) })
}) })
if err != nil { if err != nil {
mappedSQLErr := MapSQLError(err) mappedSQLErr := sqldb.MapSQLError(err)
var uniqueConstraintErr *ErrSQLUniqueConstraintViolation var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation
if errors.As(mappedSQLErr, &uniqueConstraintErr) { if errors.As(mappedSQLErr, &uniqueConstraintErr) {
// Add context to unique constraint errors. // Add context to unique constraint errors.
return 0, invpkg.ErrDuplicateInvoice return 0, ErrDuplicateInvoice
} }
return 0, fmt.Errorf("unable to add invoice(%v): %w", return 0, fmt.Errorf("unable to add invoice(%v): %w",
@ -277,15 +279,15 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context,
// fetchInvoice fetches the common invoice data and the AMP state for the // fetchInvoice fetches the common invoice data and the AMP state for the
// invoice with the given reference. // invoice with the given reference.
func (i *InvoiceStore) fetchInvoice(ctx context.Context, func (i *SQLStore) fetchInvoice(ctx context.Context,
db InvoiceQueries, ref invpkg.InvoiceRef) (*invpkg.Invoice, error) { db SQLInvoiceQueries, ref InvoiceRef) (*Invoice, error) {
if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil { if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil {
return nil, invpkg.ErrInvoiceNotFound return nil, ErrInvoiceNotFound
} }
var ( var (
invoice *invpkg.Invoice invoice *Invoice
params sqlc.GetInvoiceParams params sqlc.GetInvoiceParams
) )
@ -300,7 +302,7 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context,
// all. Only allow lookups for payment address if it is not a blank // all. Only allow lookups for payment address if it is not a blank
// payment address, which is a special-cased value for legacy keysend // payment address, which is a special-cased value for legacy keysend
// invoices. // invoices.
if ref.PayAddr() != nil && *ref.PayAddr() != invpkg.BlankPayAddr { if ref.PayAddr() != nil && *ref.PayAddr() != BlankPayAddr {
params.PaymentAddr = ref.PayAddr()[:] params.PaymentAddr = ref.PayAddr()[:]
} }
@ -313,7 +315,7 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context,
rows, err := db.GetInvoice(ctx, params) rows, err := db.GetInvoice(ctx, params)
switch { switch {
case len(rows) == 0: case len(rows) == 0:
return nil, invpkg.ErrInvoiceNotFound return nil, ErrInvoiceNotFound
case len(rows) > 1: case len(rows) > 1:
// In case the reference is ambiguous, meaning it matches more // In case the reference is ambiguous, meaning it matches more
@ -333,12 +335,12 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context,
// Now that we got the invoice itself, fetch the HTLCs as requested by // Now that we got the invoice itself, fetch the HTLCs as requested by
// the modifier. // the modifier.
switch ref.Modifier() { switch ref.Modifier() {
case invpkg.DefaultModifier: case DefaultModifier:
// By default we'll fetch all AMP HTLCs. // By default we'll fetch all AMP HTLCs.
setID = nil setID = nil
fetchAmpHtlcs = true fetchAmpHtlcs = true
case invpkg.HtlcSetOnlyModifier: case HtlcSetOnlyModifier:
// In this case we'll fetch all AMP HTLCs for the // In this case we'll fetch all AMP HTLCs for the
// specified set id. // specified set id.
if ref.SetID() == nil { if ref.SetID() == nil {
@ -349,7 +351,7 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context,
setID = ref.SetID() setID = ref.SetID()
fetchAmpHtlcs = true fetchAmpHtlcs = true
case invpkg.HtlcSetBlankModifier: case HtlcSetBlankModifier:
// No need to fetch any HTLCs. // No need to fetch any HTLCs.
setID = nil setID = nil
fetchAmpHtlcs = false fetchAmpHtlcs = false
@ -377,9 +379,9 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context,
// well. // well.
// //
//nolint:funlen //nolint:funlen
func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64, func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64,
setID *[32]byte, fetchHtlcs bool) (invpkg.AMPInvoiceState, setID *[32]byte, fetchHtlcs bool) (AMPInvoiceState,
invpkg.HTLCSet, error) { HTLCSet, error) {
var paramSetID []byte var paramSetID []byte
if setID != nil { if setID != nil {
@ -398,7 +400,7 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
return nil, nil, err return nil, nil, err
} }
ampState := make(map[invpkg.SetID]invpkg.InvoiceStateAMP) ampState := make(map[SetID]InvoiceStateAMP)
for _, row := range ampInvoiceRows { for _, row := range ampInvoiceRows {
var rowSetID [32]byte var rowSetID [32]byte
@ -413,8 +415,8 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
} }
copy(rowSetID[:], row.SetID) copy(rowSetID[:], row.SetID)
ampState[rowSetID] = invpkg.InvoiceStateAMP{ ampState[rowSetID] = InvoiceStateAMP{
State: invpkg.HtlcState(row.State), State: HtlcState(row.State),
SettleIndex: uint64(row.SettleIndex.Int64), SettleIndex: uint64(row.SettleIndex.Int64),
SettleDate: settleDate, SettleDate: settleDate,
InvoiceKeys: make(map[models.CircuitKey]struct{}), InvoiceKeys: make(map[models.CircuitKey]struct{}),
@ -457,7 +459,7 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
return nil, nil, err return nil, nil, err
} }
ampHtlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC) ampHtlcs := make(map[models.CircuitKey]*InvoiceHTLC)
for _, row := range ampHtlcRows { for _, row := range ampHtlcRows {
uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64) uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
if err != nil { if err != nil {
@ -473,17 +475,17 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
htlcID := uint64(row.HtlcID) htlcID := uint64(row.HtlcID)
circuitKey := invpkg.CircuitKey{ circuitKey := CircuitKey{
ChanID: chanID, ChanID: chanID,
HtlcID: htlcID, HtlcID: htlcID,
} }
htlc := &invpkg.InvoiceHTLC{ htlc := &InvoiceHTLC{
Amt: lnwire.MilliSatoshi(row.AmountMsat), Amt: lnwire.MilliSatoshi(row.AmountMsat),
AcceptHeight: uint32(row.AcceptHeight), AcceptHeight: uint32(row.AcceptHeight),
AcceptTime: row.AcceptTime.Local(), AcceptTime: row.AcceptTime.Local(),
Expiry: uint32(row.ExpiryHeight), Expiry: uint32(row.ExpiryHeight),
State: invpkg.HtlcState(row.State), State: HtlcState(row.State),
} }
if row.TotalMppMsat.Valid { if row.TotalMppMsat.Valid {
@ -522,7 +524,7 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
rootShare, setID, uint32(row.ChildIndex), rootShare, setID, uint32(row.ChildIndex),
) )
htlc.AMP = &invpkg.InvoiceHtlcAMPData{ htlc.AMP = &InvoiceHtlcAMPData{
Record: *ampRecord, Record: *ampRecord,
} }
@ -564,7 +566,7 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
invoiceKeys[key] = struct{}{} invoiceKeys[key] = struct{}{}
if htlc.State != invpkg.HtlcStateCanceled { //nolint: lll if htlc.State != HtlcStateCanceled { //nolint: lll
amtPaid += htlc.Amt amtPaid += htlc.Amt
} }
} }
@ -584,22 +586,22 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64,
// ID for an AMP sub invoice. If the invoice is found, we'll return the complete // ID for an AMP sub invoice. If the invoice is found, we'll return the complete
// invoice. If the invoice is not found, then we'll return an ErrInvoiceNotFound // invoice. If the invoice is not found, then we'll return an ErrInvoiceNotFound
// error. // error.
func (i *InvoiceStore) LookupInvoice(ctx context.Context, func (i *SQLStore) LookupInvoice(ctx context.Context,
ref invpkg.InvoiceRef) (invpkg.Invoice, error) { ref InvoiceRef) (Invoice, error) {
var ( var (
invoice *invpkg.Invoice invoice *Invoice
err error err error
) )
readTxOpt := NewInvoiceQueryReadTx() readTxOpt := NewSQLInvoiceQueryReadTx()
txErr := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error { txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
invoice, err = i.fetchInvoice(ctx, db, ref) invoice, err = i.fetchInvoice(ctx, db, ref)
return err return err
}) })
if txErr != nil { if txErr != nil {
return invpkg.Invoice{}, txErr return Invoice{}, txErr
} }
return *invoice, nil return *invoice, nil
@ -608,14 +610,14 @@ func (i *InvoiceStore) LookupInvoice(ctx context.Context,
// FetchPendingInvoices returns all the invoices that are currently in a // FetchPendingInvoices returns all the invoices that are currently in a
// "pending" state. An invoice is pending if it has been created but not yet // "pending" state. An invoice is pending if it has been created but not yet
// settled or canceled. // settled or canceled.
func (i *InvoiceStore) FetchPendingInvoices(ctx context.Context) ( func (i *SQLStore) FetchPendingInvoices(ctx context.Context) (
map[lntypes.Hash]invpkg.Invoice, error) { map[lntypes.Hash]Invoice, error) {
var invoices map[lntypes.Hash]invpkg.Invoice var invoices map[lntypes.Hash]Invoice
readTxOpt := NewInvoiceQueryReadTx() readTxOpt := NewSQLInvoiceQueryReadTx()
err := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error { err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
invoices = make(map[lntypes.Hash]invpkg.Invoice) invoices = make(map[lntypes.Hash]Invoice)
limit := queryPaginationLimit limit := queryPaginationLimit
return queryWithLimit(func(offset int) (int, error) { return queryWithLimit(func(offset int) (int, error) {
@ -661,24 +663,24 @@ func (i *InvoiceStore) FetchPendingInvoices(ctx context.Context) (
// //
// NOTE: The index starts from 1. As a result we enforce that specifying a value // NOTE: The index starts from 1. As a result we enforce that specifying a value
// below the starting index value is a noop. // below the starting index value is a noop.
func (i *InvoiceStore) InvoicesSettledSince(ctx context.Context, idx uint64) ( func (i *SQLStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
[]invpkg.Invoice, error) { []Invoice, error) {
var invoices []invpkg.Invoice var invoices []Invoice
if idx == 0 { if idx == 0 {
return invoices, nil return invoices, nil
} }
readTxOpt := NewInvoiceQueryReadTx() readTxOpt := NewSQLInvoiceQueryReadTx()
err := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error { err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
invoices = nil invoices = nil
settleIdx := idx settleIdx := idx
limit := queryPaginationLimit limit := queryPaginationLimit
err := queryWithLimit(func(offset int) (int, error) { err := queryWithLimit(func(offset int) (int, error) {
params := sqlc.FilterInvoicesParams{ params := sqlc.FilterInvoicesParams{
SettleIndexGet: SQLInt64(settleIdx + 1), SettleIndexGet: sqldb.SQLInt64(settleIdx + 1),
NumLimit: int32(limit), NumLimit: int32(limit),
NumOffset: int32(offset), NumOffset: int32(offset),
} }
@ -715,7 +717,7 @@ func (i *InvoiceStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
// the provided index. // the provided index.
ampInvoices, err := i.db.FetchSettledAMPSubInvoices( ampInvoices, err := i.db.FetchSettledAMPSubInvoices(
ctx, sqlc.FetchSettledAMPSubInvoicesParams{ ctx, sqlc.FetchSettledAMPSubInvoicesParams{
SettleIndexGet: SQLInt64(idx + 1), SettleIndexGet: sqldb.SQLInt64(idx + 1),
}, },
) )
if err != nil { if err != nil {
@ -775,24 +777,24 @@ func (i *InvoiceStore) InvoicesSettledSince(ctx context.Context, idx uint64) (
// //
// NOTE: The index starts from 1. As a result we enforce that specifying a value // NOTE: The index starts from 1. As a result we enforce that specifying a value
// below the starting index value is a noop. // below the starting index value is a noop.
func (i *InvoiceStore) InvoicesAddedSince(ctx context.Context, idx uint64) ( func (i *SQLStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
[]invpkg.Invoice, error) { []Invoice, error) {
var result []invpkg.Invoice var result []Invoice
if idx == 0 { if idx == 0 {
return result, nil return result, nil
} }
readTxOpt := NewInvoiceQueryReadTx() readTxOpt := NewSQLInvoiceQueryReadTx()
err := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error { err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
result = nil result = nil
addIdx := idx addIdx := idx
limit := queryPaginationLimit limit := queryPaginationLimit
return queryWithLimit(func(offset int) (int, error) { return queryWithLimit(func(offset int) (int, error) {
params := sqlc.FilterInvoicesParams{ params := sqlc.FilterInvoicesParams{
AddIndexGet: SQLInt64(addIdx + 1), AddIndexGet: sqldb.SQLInt64(addIdx + 1),
NumLimit: int32(limit), NumLimit: int32(limit),
NumOffset: int32(offset), NumOffset: int32(offset),
} }
@ -831,18 +833,18 @@ func (i *InvoiceStore) InvoicesAddedSince(ctx context.Context, idx uint64) (
// QueryInvoices allows a caller to query the invoice database for invoices // QueryInvoices allows a caller to query the invoice database for invoices
// within the specified add index range. // within the specified add index range.
func (i *InvoiceStore) QueryInvoices(ctx context.Context, func (i *SQLStore) QueryInvoices(ctx context.Context,
q invpkg.InvoiceQuery) (invpkg.InvoiceSlice, error) { q InvoiceQuery) (InvoiceSlice, error) {
var invoices []invpkg.Invoice var invoices []Invoice
if q.NumMaxInvoices == 0 { if q.NumMaxInvoices == 0 {
return invpkg.InvoiceSlice{}, fmt.Errorf("max invoices must " + return InvoiceSlice{}, fmt.Errorf("max invoices must " +
"be non-zero") "be non-zero")
} }
readTxOpt := NewInvoiceQueryReadTx() readTxOpt := NewSQLInvoiceQueryReadTx()
err := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error { err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error {
invoices = nil invoices = nil
limit := queryPaginationLimit limit := queryPaginationLimit
@ -856,7 +858,7 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
if !q.Reversed { if !q.Reversed {
// The invoice with index offset id must not be // The invoice with index offset id must not be
// included in the results. // included in the results.
params.AddIndexGet = SQLInt64( params.AddIndexGet = sqldb.SQLInt64(
q.IndexOffset + uint64(offset) + 1, q.IndexOffset + uint64(offset) + 1,
) )
} }
@ -867,13 +869,13 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
// If the index offset was not set, we want to // If the index offset was not set, we want to
// fetch from the lastest invoice. // fetch from the lastest invoice.
if idx == 0 { if idx == 0 {
params.AddIndexLet = SQLInt64( params.AddIndexLet = sqldb.SQLInt64(
int64(math.MaxInt64), int64(math.MaxInt64),
) )
} else { } else {
// The invoice with index offset id must // The invoice with index offset id must
// not be included in the results. // not be included in the results.
params.AddIndexLet = SQLInt64( params.AddIndexLet = sqldb.SQLInt64(
idx - int32(offset) - 1, idx - int32(offset) - 1,
) )
} }
@ -882,13 +884,13 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
} }
if q.CreationDateStart != 0 { if q.CreationDateStart != 0 {
params.CreatedAfter = SQLTime( params.CreatedAfter = sqldb.SQLTime(
time.Unix(q.CreationDateStart, 0).UTC(), time.Unix(q.CreationDateStart, 0).UTC(),
) )
} }
if q.CreationDateEnd != 0 { if q.CreationDateEnd != 0 {
params.CreatedBefore = SQLTime( params.CreatedBefore = sqldb.SQLTime(
time.Unix(q.CreationDateEnd, 0).UTC(), time.Unix(q.CreationDateEnd, 0).UTC(),
) )
} }
@ -919,12 +921,12 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
}, limit) }, limit)
}) })
if err != nil { if err != nil {
return invpkg.InvoiceSlice{}, fmt.Errorf("unable to query "+ return InvoiceSlice{}, fmt.Errorf("unable to query "+
"invoices: %w", err) "invoices: %w", err)
} }
if len(invoices) == 0 { if len(invoices) == 0 {
return invpkg.InvoiceSlice{ return InvoiceSlice{
InvoiceQuery: q, InvoiceQuery: q,
}, nil }, nil
} }
@ -941,7 +943,7 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
} }
} }
res := invpkg.InvoiceSlice{ res := InvoiceSlice{
InvoiceQuery: q, InvoiceQuery: q,
Invoices: invoices, Invoices: invoices,
FirstIndexOffset: invoices[0].AddIndex, FirstIndexOffset: invoices[0].AddIndex,
@ -954,15 +956,15 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context,
// sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using // sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using
// a SQL database as the backend. // a SQL database as the backend.
type sqlInvoiceUpdater struct { type sqlInvoiceUpdater struct {
db InvoiceQueries db SQLInvoiceQueries
ctx context.Context //nolint:containedctx ctx context.Context //nolint:containedctx
invoice *invpkg.Invoice invoice *Invoice
updateTime time.Time updateTime time.Time
} }
// AddHtlc adds a new htlc to the invoice. // AddHtlc adds a new htlc to the invoice.
func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey, func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
newHtlc *invpkg.InvoiceHTLC) error { newHtlc *InvoiceHTLC) error {
htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC( htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC(
s.ctx, sqlc.InsertInvoiceHTLCParams{ s.ctx, sqlc.InsertInvoiceHTLCParams{
@ -1012,10 +1014,10 @@ func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
}, },
) )
if err != nil { if err != nil {
mappedSQLErr := MapSQLError(err) mappedSQLErr := sqldb.MapSQLError(err)
var uniqueConstraintErr *ErrSQLUniqueConstraintViolation var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation //nolint:lll
if errors.As(mappedSQLErr, &uniqueConstraintErr) { if errors.As(mappedSQLErr, &uniqueConstraintErr) {
return invpkg.ErrDuplicateSetID{ return ErrDuplicateSetID{
SetID: setID, SetID: setID,
} }
} }
@ -1072,7 +1074,7 @@ func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey,
// ResolveHtlc marks an htlc as resolved with the given state. // ResolveHtlc marks an htlc as resolved with the given state.
func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey, func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
state invpkg.HtlcState, resolveTime time.Time) error { state HtlcState, resolveTime time.Time) error {
return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{ return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{
HtlcID: int64(circuitKey.HtlcID), HtlcID: int64(circuitKey.HtlcID),
@ -1081,7 +1083,7 @@ func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey,
), ),
InvoiceID: int64(s.invoice.AddIndex), InvoiceID: int64(s.invoice.AddIndex),
State: int16(state), State: int16(state),
ResolveTime: SQLTime(resolveTime.UTC()), ResolveTime: sqldb.SQLTime(resolveTime.UTC()),
}) })
} }
@ -1107,7 +1109,7 @@ func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
return err return err
} }
if rowsAffected == 0 { if rowsAffected == 0 {
return invpkg.ErrInvoiceNotFound return ErrInvoiceNotFound
} }
return nil return nil
@ -1115,7 +1117,7 @@ func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte,
// UpdateInvoiceState updates the invoice state to the new state. // UpdateInvoiceState updates the invoice state to the new state.
func (s *sqlInvoiceUpdater) UpdateInvoiceState( func (s *sqlInvoiceUpdater) UpdateInvoiceState(
newState invpkg.ContractState, preimage *lntypes.Preimage) error { newState ContractState, preimage *lntypes.Preimage) error {
var ( var (
settleIndex sql.NullInt64 settleIndex sql.NullInt64
@ -1123,16 +1125,16 @@ func (s *sqlInvoiceUpdater) UpdateInvoiceState(
) )
switch newState { switch newState {
case invpkg.ContractSettled: case ContractSettled:
nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx) nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
if err != nil { if err != nil {
return err return err
} }
settleIndex = SQLInt64(nextSettleIndex) settleIndex = sqldb.SQLInt64(nextSettleIndex)
// If the invoice is settled, we'll also update the settle time. // If the invoice is settled, we'll also update the settle time.
settledAt = SQLTime(s.updateTime.UTC()) settledAt = sqldb.SQLTime(s.updateTime.UTC())
err = s.db.OnInvoiceSettled( err = s.db.OnInvoiceSettled(
s.ctx, sqlc.OnInvoiceSettledParams{ s.ctx, sqlc.OnInvoiceSettledParams{
@ -1144,7 +1146,7 @@ func (s *sqlInvoiceUpdater) UpdateInvoiceState(
return err return err
} }
case invpkg.ContractCanceled: case ContractCanceled:
err := s.db.OnInvoiceCanceled( err := s.db.OnInvoiceCanceled(
s.ctx, sqlc.OnInvoiceCanceledParams{ s.ctx, sqlc.OnInvoiceCanceledParams{
AddedAt: s.updateTime.UTC(), AddedAt: s.updateTime.UTC(),
@ -1177,7 +1179,7 @@ func (s *sqlInvoiceUpdater) UpdateInvoiceState(
} }
if rowsAffected == 0 { if rowsAffected == 0 {
return invpkg.ErrInvoiceNotFound return ErrInvoiceNotFound
} }
if settleIndex.Valid { if settleIndex.Valid {
@ -1205,7 +1207,7 @@ func (s *sqlInvoiceUpdater) UpdateInvoiceAmtPaid(
// UpdateAmpState updates the state of the AMP sub invoice identified by the // UpdateAmpState updates the state of the AMP sub invoice identified by the
// setID. // setID.
func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte, func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
newState invpkg.InvoiceStateAMP, _ models.CircuitKey) error { newState InvoiceStateAMP, _ models.CircuitKey) error {
var ( var (
settleIndex sql.NullInt64 settleIndex sql.NullInt64
@ -1213,16 +1215,16 @@ func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
) )
switch newState.State { switch newState.State {
case invpkg.HtlcStateSettled: case HtlcStateSettled:
nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx) nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx)
if err != nil { if err != nil {
return err return err
} }
settleIndex = SQLInt64(nextSettleIndex) settleIndex = sqldb.SQLInt64(nextSettleIndex)
// If the invoice is settled, we'll also update the settle time. // If the invoice is settled, we'll also update the settle time.
settledAt = SQLTime(s.updateTime.UTC()) settledAt = sqldb.SQLTime(s.updateTime.UTC())
err = s.db.OnAMPSubInvoiceSettled( err = s.db.OnAMPSubInvoiceSettled(
s.ctx, sqlc.OnAMPSubInvoiceSettledParams{ s.ctx, sqlc.OnAMPSubInvoiceSettledParams{
@ -1235,7 +1237,7 @@ func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
return err return err
} }
case invpkg.HtlcStateCanceled: case HtlcStateCanceled:
err := s.db.OnAMPSubInvoiceCanceled( err := s.db.OnAMPSubInvoiceCanceled(
s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{ s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{
AddedAt: s.updateTime.UTC(), AddedAt: s.updateTime.UTC(),
@ -1266,7 +1268,7 @@ func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte,
// Finalize finalizes the update before it is written to the database. Note that // Finalize finalizes the update before it is written to the database. Note that
// we don't use this directly in the SQL implementation, so the function is just // we don't use this directly in the SQL implementation, so the function is just
// a stub. // a stub.
func (s *sqlInvoiceUpdater) Finalize(_ invpkg.UpdateType) error { func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error {
return nil return nil
} }
@ -1277,14 +1279,14 @@ func (s *sqlInvoiceUpdater) Finalize(_ invpkg.UpdateType) error {
// The update is performed inside the same database transaction that fetches the // The update is performed inside the same database transaction that fetches the
// invoice and is therefore atomic. The fields to update are controlled by the // invoice and is therefore atomic. The fields to update are controlled by the
// supplied callback. // supplied callback.
func (i *InvoiceStore) UpdateInvoice(ctx context.Context, ref invpkg.InvoiceRef, func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef,
_ *invpkg.SetID, callback invpkg.InvoiceUpdateCallback) ( _ *SetID, callback InvoiceUpdateCallback) (
*invpkg.Invoice, error) { *Invoice, error) {
var updatedInvoice *invpkg.Invoice var updatedInvoice *Invoice
txOpt := InvoiceQueriesTxOptions{readOnly: false} txOpt := SQLInvoiceQueriesTxOptions{readOnly: false}
txErr := i.db.ExecTx(ctx, &txOpt, func(db InvoiceQueries) error { txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error {
invoice, err := i.fetchInvoice(ctx, db, ref) invoice, err := i.fetchInvoice(ctx, db, ref)
if err != nil { if err != nil {
return err return err
@ -1299,7 +1301,7 @@ func (i *InvoiceStore) UpdateInvoice(ctx context.Context, ref invpkg.InvoiceRef,
} }
payHash := ref.PayHash() payHash := ref.PayHash()
updatedInvoice, err = invpkg.UpdateInvoice( updatedInvoice, err = UpdateInvoice(
payHash, invoice, updateTime, callback, updater, payHash, invoice, updateTime, callback, updater,
) )
@ -1308,7 +1310,7 @@ func (i *InvoiceStore) UpdateInvoice(ctx context.Context, ref invpkg.InvoiceRef,
if txErr != nil { if txErr != nil {
// If the invoice is already settled, we'll return the // If the invoice is already settled, we'll return the
// (unchanged) invoice and the ErrInvoiceAlreadySettled error. // (unchanged) invoice and the ErrInvoiceAlreadySettled error.
if errors.Is(txErr, invpkg.ErrInvoiceAlreadySettled) { if errors.Is(txErr, ErrInvoiceAlreadySettled) {
return updatedInvoice, txErr return updatedInvoice, txErr
} }
@ -1320,8 +1322,8 @@ func (i *InvoiceStore) UpdateInvoice(ctx context.Context, ref invpkg.InvoiceRef,
// DeleteInvoice attempts to delete the passed invoices and all their related // DeleteInvoice attempts to delete the passed invoices and all their related
// data from the database in one transaction. // data from the database in one transaction.
func (i *InvoiceStore) DeleteInvoice(ctx context.Context, func (i *SQLStore) DeleteInvoice(ctx context.Context,
invoicesToDelete []invpkg.InvoiceDeleteRef) error { invoicesToDelete []InvoiceDeleteRef) error {
// All the InvoiceDeleteRef instances include the add index of the // All the InvoiceDeleteRef instances include the add index of the
// invoice. The rest was added to ensure that the invoices were deleted // invoice. The rest was added to ensure that the invoices were deleted
@ -1334,15 +1336,17 @@ func (i *InvoiceStore) DeleteInvoice(ctx context.Context,
} }
} }
var writeTxOpt InvoiceQueriesTxOptions var writeTxOpt SQLInvoiceQueriesTxOptions
err := i.db.ExecTx(ctx, &writeTxOpt, func(db InvoiceQueries) error { err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
for _, ref := range invoicesToDelete { for _, ref := range invoicesToDelete {
params := sqlc.DeleteInvoiceParams{ params := sqlc.DeleteInvoiceParams{
AddIndex: SQLInt64(ref.AddIndex), AddIndex: sqldb.SQLInt64(ref.AddIndex),
} }
if ref.SettleIndex != 0 { if ref.SettleIndex != 0 {
params.SettleIndex = SQLInt64(ref.SettleIndex) params.SettleIndex = sqldb.SQLInt64(
ref.SettleIndex,
)
} }
if ref.PayHash != lntypes.ZeroHash { if ref.PayHash != lntypes.ZeroHash {
@ -1361,7 +1365,7 @@ func (i *InvoiceStore) DeleteInvoice(ctx context.Context,
} }
if rowsAffected == 0 { if rowsAffected == 0 {
return fmt.Errorf("%w: %v", return fmt.Errorf("%w: %v",
invpkg.ErrInvoiceNotFound, ref.AddIndex) ErrInvoiceNotFound, ref.AddIndex)
} }
} }
@ -1376,9 +1380,9 @@ func (i *InvoiceStore) DeleteInvoice(ctx context.Context,
} }
// DeleteCanceledInvoices removes all canceled invoices from the database. // DeleteCanceledInvoices removes all canceled invoices from the database.
func (i *InvoiceStore) DeleteCanceledInvoices(ctx context.Context) error { func (i *SQLStore) DeleteCanceledInvoices(ctx context.Context) error {
var writeTxOpt InvoiceQueriesTxOptions var writeTxOpt SQLInvoiceQueriesTxOptions
err := i.db.ExecTx(ctx, &writeTxOpt, func(db InvoiceQueries) error { err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error {
_, err := db.DeleteCanceledInvoices(ctx) _, err := db.DeleteCanceledInvoices(ctx)
if err != nil { if err != nil {
return fmt.Errorf("unable to delete canceled "+ return fmt.Errorf("unable to delete canceled "+
@ -1398,9 +1402,9 @@ func (i *InvoiceStore) DeleteCanceledInvoices(ctx context.Context) error {
// invoice is AMP and the setID is not nil, then it will also fetch the AMP // invoice is AMP and the setID is not nil, then it will also fetch the AMP
// state and HTLCs for the given setID, otherwise for all AMP sub invoices of // state and HTLCs for the given setID, otherwise for all AMP sub invoices of
// the invoice. If fetchAmpHtlcs is true, it will also fetch the AMP HTLCs. // the invoice. If fetchAmpHtlcs is true, it will also fetch the AMP HTLCs.
func fetchInvoiceData(ctx context.Context, db InvoiceQueries, func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries,
row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash, row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash,
*invpkg.Invoice, error) { *Invoice, error) {
// Unmarshal the common data. // Unmarshal the common data.
hash, invoice, err := unmarshalInvoice(row) hash, invoice, err := unmarshalInvoice(row)
@ -1444,7 +1448,7 @@ func fetchInvoiceData(ctx context.Context, db InvoiceQueries,
invoice.Htlcs = htlcs invoice.Htlcs = htlcs
var amountPaid lnwire.MilliSatoshi var amountPaid lnwire.MilliSatoshi
for _, htlc := range htlcs { for _, htlc := range htlcs {
if htlc.State == invpkg.HtlcStateSettled { if htlc.State == HtlcStateSettled {
amountPaid += htlc.Amt amountPaid += htlc.Amt
} }
} }
@ -1455,7 +1459,7 @@ func fetchInvoiceData(ctx context.Context, db InvoiceQueries,
} }
// getInvoiceFeatures fetches the invoice features for the given invoice id. // getInvoiceFeatures fetches the invoice features for the given invoice id.
func getInvoiceFeatures(ctx context.Context, db InvoiceQueries, func getInvoiceFeatures(ctx context.Context, db SQLInvoiceQueries,
invoiceID int64) (*lnwire.FeatureVector, error) { invoiceID int64) (*lnwire.FeatureVector, error) {
rows, err := db.GetInvoiceFeatures(ctx, invoiceID) rows, err := db.GetInvoiceFeatures(ctx, invoiceID)
@ -1473,8 +1477,8 @@ func getInvoiceFeatures(ctx context.Context, db InvoiceQueries,
} }
// getInvoiceHtlcs fetches the invoice htlcs for the given invoice id. // getInvoiceHtlcs fetches the invoice htlcs for the given invoice id.
func getInvoiceHtlcs(ctx context.Context, db InvoiceQueries, func getInvoiceHtlcs(ctx context.Context, db SQLInvoiceQueries,
invoiceID int64) (map[invpkg.CircuitKey]*invpkg.InvoiceHTLC, error) { invoiceID int64) (map[CircuitKey]*InvoiceHTLC, error) {
htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID) htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID)
if err != nil { if err != nil {
@ -1505,7 +1509,7 @@ func getInvoiceHtlcs(ctx context.Context, db InvoiceQueries,
cr[row.HtlcID][uint64(row.Key)] = value cr[row.HtlcID][uint64(row.Key)] = value
} }
htlcs := make(map[invpkg.CircuitKey]*invpkg.InvoiceHTLC, len(htlcRows)) htlcs := make(map[CircuitKey]*InvoiceHTLC, len(htlcRows))
for _, row := range htlcRows { for _, row := range htlcRows {
circuiteKey, htlc, err := unmarshalInvoiceHTLC(row) circuiteKey, htlc, err := unmarshalInvoiceHTLC(row)
@ -1527,7 +1531,7 @@ func getInvoiceHtlcs(ctx context.Context, db InvoiceQueries,
} }
// unmarshalInvoice converts an InvoiceRow to an Invoice. // unmarshalInvoice converts an InvoiceRow to an Invoice.
func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *invpkg.Invoice, func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice,
error) { error) {
var ( var (
@ -1576,13 +1580,13 @@ func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *invpkg.Invoice,
cltvDelta = row.CltvDelta.Int32 cltvDelta = row.CltvDelta.Int32
} }
invoice := &invpkg.Invoice{ invoice := &Invoice{
SettleIndex: uint64(settleIndex), SettleIndex: uint64(settleIndex),
SettleDate: settledAt, SettleDate: settledAt,
Memo: memo, Memo: memo,
PaymentRequest: paymentRequest, PaymentRequest: paymentRequest,
CreationDate: row.CreatedAt.Local(), CreationDate: row.CreatedAt.Local(),
Terms: invpkg.ContractTerm{ Terms: ContractTerm{
FinalCltvDelta: cltvDelta, FinalCltvDelta: cltvDelta,
Expiry: time.Duration(row.Expiry), Expiry: time.Duration(row.Expiry),
PaymentPreimage: preimage, PaymentPreimage: preimage,
@ -1590,10 +1594,10 @@ func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *invpkg.Invoice,
PaymentAddr: paymentAddr, PaymentAddr: paymentAddr,
}, },
AddIndex: uint64(row.ID), AddIndex: uint64(row.ID),
State: invpkg.ContractState(row.State), State: ContractState(row.State),
AmtPaid: lnwire.MilliSatoshi(row.AmountPaidMsat), AmtPaid: lnwire.MilliSatoshi(row.AmountPaidMsat),
Htlcs: make(map[models.CircuitKey]*invpkg.InvoiceHTLC), Htlcs: make(map[models.CircuitKey]*InvoiceHTLC),
AMPState: invpkg.AMPInvoiceState{}, AMPState: AMPInvoiceState{},
HodlInvoice: row.IsHodl, HodlInvoice: row.IsHodl,
} }
@ -1601,34 +1605,34 @@ func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *invpkg.Invoice,
} }
// unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC. // unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC.
func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (invpkg.CircuitKey, func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (CircuitKey,
*invpkg.InvoiceHTLC, error) { *InvoiceHTLC, error) {
uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64) uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64)
if err != nil { if err != nil {
return invpkg.CircuitKey{}, nil, err return CircuitKey{}, nil, err
} }
chanID := lnwire.NewShortChanIDFromInt(uint64ChanID) chanID := lnwire.NewShortChanIDFromInt(uint64ChanID)
if row.HtlcID < 0 { if row.HtlcID < 0 {
return invpkg.CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+ return CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+
"value: %v", row.HtlcID) "value: %v", row.HtlcID)
} }
htlcID := uint64(row.HtlcID) htlcID := uint64(row.HtlcID)
circuitKey := invpkg.CircuitKey{ circuitKey := CircuitKey{
ChanID: chanID, ChanID: chanID,
HtlcID: htlcID, HtlcID: htlcID,
} }
htlc := &invpkg.InvoiceHTLC{ htlc := &InvoiceHTLC{
Amt: lnwire.MilliSatoshi(row.AmountMsat), Amt: lnwire.MilliSatoshi(row.AmountMsat),
AcceptHeight: uint32(row.AcceptHeight), AcceptHeight: uint32(row.AcceptHeight),
AcceptTime: row.AcceptTime.Local(), AcceptTime: row.AcceptTime.Local(),
Expiry: uint32(row.ExpiryHeight), Expiry: uint32(row.ExpiryHeight),
State: invpkg.HtlcState(row.State), State: HtlcState(row.State),
} }
if row.TotalMppMsat.Valid { if row.TotalMppMsat.Valid {