diff --git a/config_builder.go b/config_builder.go index 7c4599fb8..002ae288a 100644 --- a/config_builder.go +++ b/config_builder.go @@ -1053,12 +1053,12 @@ func (d *DefaultDatabaseBuilder) BuildDatabase( executor := sqldb.NewTransactionExecutor( dbs.NativeSQLStore, - func(tx *sql.Tx) sqldb.InvoiceQueries { + func(tx *sql.Tx) invoices.SQLInvoiceQueries { return dbs.NativeSQLStore.WithTx(tx) }, ) - dbs.InvoiceDB = sqldb.NewInvoiceStore( + dbs.InvoiceDB = invoices.NewSQLStore( executor, clock.NewDefaultClock(), ) } else { diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index 99ad6cddd..b0c019522 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -136,14 +136,14 @@ func TestInvoiceRegistry(t *testing.T) { } executor := sqldb.NewTransactionExecutor( - db, func(tx *sql.Tx) sqldb.InvoiceQueries { + db, func(tx *sql.Tx) invpkg.SQLInvoiceQueries { return db.WithTx(tx) }, ) testClock := clock.NewTestClock(testNow) - return sqldb.NewInvoiceStore(executor, testClock), testClock + return invpkg.NewSQLStore(executor, testClock), testClock } for _, test := range testList { diff --git a/invoices/invoices_test.go b/invoices/invoices_test.go index 7e1baadd3..6f3cbd8f2 100644 --- a/invoices/invoices_test.go +++ b/invoices/invoices_test.go @@ -240,14 +240,14 @@ func TestInvoices(t *testing.T) { } executor := sqldb.NewTransactionExecutor( - db, func(tx *sql.Tx) sqldb.InvoiceQueries { + db, func(tx *sql.Tx) invpkg.SQLInvoiceQueries { return db.WithTx(tx) }, ) testClock := clock.NewTestClock(testNow) - return sqldb.NewInvoiceStore(executor, testClock) + return invpkg.NewSQLStore(executor, testClock) } for _, test := range testList { diff --git a/sqldb/invoices.go b/invoices/sql_store.go similarity index 82% rename from sqldb/invoices.go rename to invoices/sql_store.go index bc1974cdc..94575ede0 100644 --- a/sqldb/invoices.go +++ b/invoices/sql_store.go @@ -1,4 +1,4 @@ -package sqldb +package invoices import ( "context" @@ -12,10 +12,10 @@ import ( "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" - invpkg "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" + "github.com/lightningnetwork/lnd/sqldb" "github.com/lightningnetwork/lnd/sqldb/sqlc" ) @@ -25,9 +25,9 @@ const ( queryPaginationLimit = 100 ) -// InvoiceQueries is an interface that defines the set of operations that can be -// executed against the invoice database. -type InvoiceQueries interface { //nolint:interfacebloat +// SQLInvoiceQueries is an interface that defines the set of operations that can +// be executed against the invoice SQL database. +type SQLInvoiceQueries interface { //nolint:interfacebloat InsertInvoice(ctx context.Context, arg sqlc.InsertInvoiceParams) (int64, error) @@ -117,11 +117,11 @@ type InvoiceQueries interface { //nolint:interfacebloat arg sqlc.OnAMPSubInvoiceSettledParams) error } -var _ invpkg.InvoiceDB = (*InvoiceStore)(nil) +var _ InvoiceDB = (*SQLStore)(nil) -// InvoiceQueriesTxOptions defines the set of db txn options the InvoiceQueries -// understands. -type InvoiceQueriesTxOptions struct { +// SQLInvoiceQueriesTxOptions defines the set of db txn options the +// SQLInvoiceQueries understands. +type SQLInvoiceQueriesTxOptions struct { // readOnly governs if a read only transaction is needed or not. readOnly bool } @@ -129,37 +129,37 @@ type InvoiceQueriesTxOptions struct { // ReadOnly returns true if the transaction should be read only. // // NOTE: This implements the TxOptions. -func (a *InvoiceQueriesTxOptions) ReadOnly() bool { +func (a *SQLInvoiceQueriesTxOptions) ReadOnly() bool { return a.readOnly } -// NewInvoiceQueryReadTx creates a new read transaction option set. -func NewInvoiceQueryReadTx() InvoiceQueriesTxOptions { - return InvoiceQueriesTxOptions{ +// NewSQLInvoiceQueryReadTx creates a new read transaction option set. +func NewSQLInvoiceQueryReadTx() SQLInvoiceQueriesTxOptions { + return SQLInvoiceQueriesTxOptions{ readOnly: true, } } -// BatchedInvoiceQueries is a version of the InvoiceQueries that's capable of -// batched database operations. -type BatchedInvoiceQueries interface { - InvoiceQueries +// BatchedSQLInvoiceQueries is a version of the SQLInvoiceQueries that's capable +// of batched database operations. +type BatchedSQLInvoiceQueries interface { + SQLInvoiceQueries - BatchedTx[InvoiceQueries] + sqldb.BatchedTx[SQLInvoiceQueries] } -// InvoiceStore represents a storage backend. -type InvoiceStore struct { - db BatchedInvoiceQueries +// SQLStore represents a storage backend. +type SQLStore struct { + db BatchedSQLInvoiceQueries clock clock.Clock } -// NewInvoiceStore creates a new InvoiceStore instance given a open -// BatchedInvoiceQueries storage backend. -func NewInvoiceStore(db BatchedInvoiceQueries, - clock clock.Clock) *InvoiceStore { +// NewSQLStore creates a new SQLStore instance given a open +// BatchedSQLInvoiceQueries storage backend. +func NewSQLStore(db BatchedSQLInvoiceQueries, + clock clock.Clock) *SQLStore { - return &InvoiceStore{ + return &SQLStore{ db: db, clock: clock, } @@ -171,17 +171,17 @@ func NewInvoiceStore(db BatchedInvoiceQueries, // duplicate payment hashes. // // NOTE: A side effect of this function is that it sets AddIndex on newInvoice. -func (i *InvoiceStore) AddInvoice(ctx context.Context, - newInvoice *invpkg.Invoice, paymentHash lntypes.Hash) (uint64, error) { +func (i *SQLStore) AddInvoice(ctx context.Context, + newInvoice *Invoice, paymentHash lntypes.Hash) (uint64, error) { // Make sure this is a valid invoice before trying to store it in our // DB. - if err := invpkg.ValidateInvoice(newInvoice, paymentHash); err != nil { + if err := ValidateInvoice(newInvoice, paymentHash); err != nil { return 0, err } var ( - writeTxOpts InvoiceQueriesTxOptions + writeTxOpts SQLInvoiceQueriesTxOptions invoiceID int64 ) @@ -193,16 +193,18 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context, paymentRequestHash = h.Sum(nil) } - err := i.db.ExecTx(ctx, &writeTxOpts, func(db InvoiceQueries) error { + err := i.db.ExecTx(ctx, &writeTxOpts, func(db SQLInvoiceQueries) error { params := sqlc.InsertInvoiceParams{ Hash: paymentHash[:], - Memo: SQLStr(string(newInvoice.Memo)), + Memo: sqldb.SQLStr(string(newInvoice.Memo)), AmountMsat: int64(newInvoice.Terms.Value), // Note: BOLT12 invoices don't have a final cltv delta. - CltvDelta: SQLInt32(newInvoice.Terms.FinalCltvDelta), - Expiry: int32(newInvoice.Terms.Expiry), + CltvDelta: sqldb.SQLInt32( + newInvoice.Terms.FinalCltvDelta, + ), + Expiry: int32(newInvoice.Terms.Expiry), // Note: keysend invoices don't have a payment request. - PaymentRequest: SQLStr(string( + PaymentRequest: sqldb.SQLStr(string( newInvoice.PaymentRequest), ), PaymentRequestHash: paymentRequestHash, @@ -218,7 +220,7 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context, // HODL invoices. if newInvoice.Terms.PaymentPreimage != nil { preimage := *newInvoice.Terms.PaymentPreimage - if preimage == invpkg.UnknownPreimage { + if preimage == UnknownPreimage { return errors.New("cannot use all-zeroes " + "preimage") } @@ -226,7 +228,7 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context, } // Some non MPP payments may have the default (invalid) value. - if newInvoice.Terms.PaymentAddr != invpkg.BlankPayAddr { + if newInvoice.Terms.PaymentAddr != BlankPayAddr { params.PaymentAddr = newInvoice.Terms.PaymentAddr[:] } @@ -259,11 +261,11 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context, }) }) if err != nil { - mappedSQLErr := MapSQLError(err) - var uniqueConstraintErr *ErrSQLUniqueConstraintViolation + mappedSQLErr := sqldb.MapSQLError(err) + var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation if errors.As(mappedSQLErr, &uniqueConstraintErr) { // Add context to unique constraint errors. - return 0, invpkg.ErrDuplicateInvoice + return 0, ErrDuplicateInvoice } return 0, fmt.Errorf("unable to add invoice(%v): %w", @@ -277,15 +279,15 @@ func (i *InvoiceStore) AddInvoice(ctx context.Context, // fetchInvoice fetches the common invoice data and the AMP state for the // invoice with the given reference. -func (i *InvoiceStore) fetchInvoice(ctx context.Context, - db InvoiceQueries, ref invpkg.InvoiceRef) (*invpkg.Invoice, error) { +func (i *SQLStore) fetchInvoice(ctx context.Context, + db SQLInvoiceQueries, ref InvoiceRef) (*Invoice, error) { if ref.PayHash() == nil && ref.PayAddr() == nil && ref.SetID() == nil { - return nil, invpkg.ErrInvoiceNotFound + return nil, ErrInvoiceNotFound } var ( - invoice *invpkg.Invoice + invoice *Invoice params sqlc.GetInvoiceParams ) @@ -300,7 +302,7 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context, // all. Only allow lookups for payment address if it is not a blank // payment address, which is a special-cased value for legacy keysend // invoices. - if ref.PayAddr() != nil && *ref.PayAddr() != invpkg.BlankPayAddr { + if ref.PayAddr() != nil && *ref.PayAddr() != BlankPayAddr { params.PaymentAddr = ref.PayAddr()[:] } @@ -313,7 +315,7 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context, rows, err := db.GetInvoice(ctx, params) switch { case len(rows) == 0: - return nil, invpkg.ErrInvoiceNotFound + return nil, ErrInvoiceNotFound case len(rows) > 1: // In case the reference is ambiguous, meaning it matches more @@ -333,12 +335,12 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context, // Now that we got the invoice itself, fetch the HTLCs as requested by // the modifier. switch ref.Modifier() { - case invpkg.DefaultModifier: + case DefaultModifier: // By default we'll fetch all AMP HTLCs. setID = nil fetchAmpHtlcs = true - case invpkg.HtlcSetOnlyModifier: + case HtlcSetOnlyModifier: // In this case we'll fetch all AMP HTLCs for the // specified set id. if ref.SetID() == nil { @@ -349,7 +351,7 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context, setID = ref.SetID() fetchAmpHtlcs = true - case invpkg.HtlcSetBlankModifier: + case HtlcSetBlankModifier: // No need to fetch any HTLCs. setID = nil fetchAmpHtlcs = false @@ -377,9 +379,9 @@ func (i *InvoiceStore) fetchInvoice(ctx context.Context, // well. // //nolint:funlen -func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64, - setID *[32]byte, fetchHtlcs bool) (invpkg.AMPInvoiceState, - invpkg.HTLCSet, error) { +func fetchAmpState(ctx context.Context, db SQLInvoiceQueries, invoiceID int64, + setID *[32]byte, fetchHtlcs bool) (AMPInvoiceState, + HTLCSet, error) { var paramSetID []byte if setID != nil { @@ -398,7 +400,7 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64, return nil, nil, err } - ampState := make(map[invpkg.SetID]invpkg.InvoiceStateAMP) + ampState := make(map[SetID]InvoiceStateAMP) for _, row := range ampInvoiceRows { var rowSetID [32]byte @@ -413,8 +415,8 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64, } copy(rowSetID[:], row.SetID) - ampState[rowSetID] = invpkg.InvoiceStateAMP{ - State: invpkg.HtlcState(row.State), + ampState[rowSetID] = InvoiceStateAMP{ + State: HtlcState(row.State), SettleIndex: uint64(row.SettleIndex.Int64), SettleDate: settleDate, InvoiceKeys: make(map[models.CircuitKey]struct{}), @@ -457,7 +459,7 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64, return nil, nil, err } - ampHtlcs := make(map[models.CircuitKey]*invpkg.InvoiceHTLC) + ampHtlcs := make(map[models.CircuitKey]*InvoiceHTLC) for _, row := range ampHtlcRows { uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64) if err != nil { @@ -473,17 +475,17 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64, htlcID := uint64(row.HtlcID) - circuitKey := invpkg.CircuitKey{ + circuitKey := CircuitKey{ ChanID: chanID, HtlcID: htlcID, } - htlc := &invpkg.InvoiceHTLC{ + htlc := &InvoiceHTLC{ Amt: lnwire.MilliSatoshi(row.AmountMsat), AcceptHeight: uint32(row.AcceptHeight), AcceptTime: row.AcceptTime.Local(), Expiry: uint32(row.ExpiryHeight), - State: invpkg.HtlcState(row.State), + State: HtlcState(row.State), } if row.TotalMppMsat.Valid { @@ -522,7 +524,7 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64, rootShare, setID, uint32(row.ChildIndex), ) - htlc.AMP = &invpkg.InvoiceHtlcAMPData{ + htlc.AMP = &InvoiceHtlcAMPData{ Record: *ampRecord, } @@ -564,7 +566,7 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64, invoiceKeys[key] = struct{}{} - if htlc.State != invpkg.HtlcStateCanceled { //nolint: lll + if htlc.State != HtlcStateCanceled { //nolint: lll amtPaid += htlc.Amt } } @@ -584,22 +586,22 @@ func fetchAmpState(ctx context.Context, db InvoiceQueries, invoiceID int64, // ID for an AMP sub invoice. If the invoice is found, we'll return the complete // invoice. If the invoice is not found, then we'll return an ErrInvoiceNotFound // error. -func (i *InvoiceStore) LookupInvoice(ctx context.Context, - ref invpkg.InvoiceRef) (invpkg.Invoice, error) { +func (i *SQLStore) LookupInvoice(ctx context.Context, + ref InvoiceRef) (Invoice, error) { var ( - invoice *invpkg.Invoice + invoice *Invoice err error ) - readTxOpt := NewInvoiceQueryReadTx() - txErr := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error { + readTxOpt := NewSQLInvoiceQueryReadTx() + txErr := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error { invoice, err = i.fetchInvoice(ctx, db, ref) return err }) if txErr != nil { - return invpkg.Invoice{}, txErr + return Invoice{}, txErr } return *invoice, nil @@ -608,14 +610,14 @@ func (i *InvoiceStore) LookupInvoice(ctx context.Context, // FetchPendingInvoices returns all the invoices that are currently in a // "pending" state. An invoice is pending if it has been created but not yet // settled or canceled. -func (i *InvoiceStore) FetchPendingInvoices(ctx context.Context) ( - map[lntypes.Hash]invpkg.Invoice, error) { +func (i *SQLStore) FetchPendingInvoices(ctx context.Context) ( + map[lntypes.Hash]Invoice, error) { - var invoices map[lntypes.Hash]invpkg.Invoice + var invoices map[lntypes.Hash]Invoice - readTxOpt := NewInvoiceQueryReadTx() - err := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error { - invoices = make(map[lntypes.Hash]invpkg.Invoice) + readTxOpt := NewSQLInvoiceQueryReadTx() + err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error { + invoices = make(map[lntypes.Hash]Invoice) limit := queryPaginationLimit return queryWithLimit(func(offset int) (int, error) { @@ -661,24 +663,24 @@ func (i *InvoiceStore) FetchPendingInvoices(ctx context.Context) ( // // NOTE: The index starts from 1. As a result we enforce that specifying a value // below the starting index value is a noop. -func (i *InvoiceStore) InvoicesSettledSince(ctx context.Context, idx uint64) ( - []invpkg.Invoice, error) { +func (i *SQLStore) InvoicesSettledSince(ctx context.Context, idx uint64) ( + []Invoice, error) { - var invoices []invpkg.Invoice + var invoices []Invoice if idx == 0 { return invoices, nil } - readTxOpt := NewInvoiceQueryReadTx() - err := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error { + readTxOpt := NewSQLInvoiceQueryReadTx() + err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error { invoices = nil settleIdx := idx limit := queryPaginationLimit err := queryWithLimit(func(offset int) (int, error) { params := sqlc.FilterInvoicesParams{ - SettleIndexGet: SQLInt64(settleIdx + 1), + SettleIndexGet: sqldb.SQLInt64(settleIdx + 1), NumLimit: int32(limit), NumOffset: int32(offset), } @@ -715,7 +717,7 @@ func (i *InvoiceStore) InvoicesSettledSince(ctx context.Context, idx uint64) ( // the provided index. ampInvoices, err := i.db.FetchSettledAMPSubInvoices( ctx, sqlc.FetchSettledAMPSubInvoicesParams{ - SettleIndexGet: SQLInt64(idx + 1), + SettleIndexGet: sqldb.SQLInt64(idx + 1), }, ) if err != nil { @@ -775,24 +777,24 @@ func (i *InvoiceStore) InvoicesSettledSince(ctx context.Context, idx uint64) ( // // NOTE: The index starts from 1. As a result we enforce that specifying a value // below the starting index value is a noop. -func (i *InvoiceStore) InvoicesAddedSince(ctx context.Context, idx uint64) ( - []invpkg.Invoice, error) { +func (i *SQLStore) InvoicesAddedSince(ctx context.Context, idx uint64) ( + []Invoice, error) { - var result []invpkg.Invoice + var result []Invoice if idx == 0 { return result, nil } - readTxOpt := NewInvoiceQueryReadTx() - err := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error { + readTxOpt := NewSQLInvoiceQueryReadTx() + err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error { result = nil addIdx := idx limit := queryPaginationLimit return queryWithLimit(func(offset int) (int, error) { params := sqlc.FilterInvoicesParams{ - AddIndexGet: SQLInt64(addIdx + 1), + AddIndexGet: sqldb.SQLInt64(addIdx + 1), NumLimit: int32(limit), NumOffset: int32(offset), } @@ -831,18 +833,18 @@ func (i *InvoiceStore) InvoicesAddedSince(ctx context.Context, idx uint64) ( // QueryInvoices allows a caller to query the invoice database for invoices // within the specified add index range. -func (i *InvoiceStore) QueryInvoices(ctx context.Context, - q invpkg.InvoiceQuery) (invpkg.InvoiceSlice, error) { +func (i *SQLStore) QueryInvoices(ctx context.Context, + q InvoiceQuery) (InvoiceSlice, error) { - var invoices []invpkg.Invoice + var invoices []Invoice if q.NumMaxInvoices == 0 { - return invpkg.InvoiceSlice{}, fmt.Errorf("max invoices must " + + return InvoiceSlice{}, fmt.Errorf("max invoices must " + "be non-zero") } - readTxOpt := NewInvoiceQueryReadTx() - err := i.db.ExecTx(ctx, &readTxOpt, func(db InvoiceQueries) error { + readTxOpt := NewSQLInvoiceQueryReadTx() + err := i.db.ExecTx(ctx, &readTxOpt, func(db SQLInvoiceQueries) error { invoices = nil limit := queryPaginationLimit @@ -856,7 +858,7 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context, if !q.Reversed { // The invoice with index offset id must not be // included in the results. - params.AddIndexGet = SQLInt64( + params.AddIndexGet = sqldb.SQLInt64( q.IndexOffset + uint64(offset) + 1, ) } @@ -867,13 +869,13 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context, // If the index offset was not set, we want to // fetch from the lastest invoice. if idx == 0 { - params.AddIndexLet = SQLInt64( + params.AddIndexLet = sqldb.SQLInt64( int64(math.MaxInt64), ) } else { // The invoice with index offset id must // not be included in the results. - params.AddIndexLet = SQLInt64( + params.AddIndexLet = sqldb.SQLInt64( idx - int32(offset) - 1, ) } @@ -882,13 +884,13 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context, } if q.CreationDateStart != 0 { - params.CreatedAfter = SQLTime( + params.CreatedAfter = sqldb.SQLTime( time.Unix(q.CreationDateStart, 0).UTC(), ) } if q.CreationDateEnd != 0 { - params.CreatedBefore = SQLTime( + params.CreatedBefore = sqldb.SQLTime( time.Unix(q.CreationDateEnd, 0).UTC(), ) } @@ -919,12 +921,12 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context, }, limit) }) if err != nil { - return invpkg.InvoiceSlice{}, fmt.Errorf("unable to query "+ + return InvoiceSlice{}, fmt.Errorf("unable to query "+ "invoices: %w", err) } if len(invoices) == 0 { - return invpkg.InvoiceSlice{ + return InvoiceSlice{ InvoiceQuery: q, }, nil } @@ -941,7 +943,7 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context, } } - res := invpkg.InvoiceSlice{ + res := InvoiceSlice{ InvoiceQuery: q, Invoices: invoices, FirstIndexOffset: invoices[0].AddIndex, @@ -954,15 +956,15 @@ func (i *InvoiceStore) QueryInvoices(ctx context.Context, // sqlInvoiceUpdater is the implementation of the InvoiceUpdater interface using // a SQL database as the backend. type sqlInvoiceUpdater struct { - db InvoiceQueries + db SQLInvoiceQueries ctx context.Context //nolint:containedctx - invoice *invpkg.Invoice + invoice *Invoice updateTime time.Time } // AddHtlc adds a new htlc to the invoice. func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey, - newHtlc *invpkg.InvoiceHTLC) error { + newHtlc *InvoiceHTLC) error { htlcPrimaryKeyID, err := s.db.InsertInvoiceHTLC( s.ctx, sqlc.InsertInvoiceHTLCParams{ @@ -1012,10 +1014,10 @@ func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey, }, ) if err != nil { - mappedSQLErr := MapSQLError(err) - var uniqueConstraintErr *ErrSQLUniqueConstraintViolation + mappedSQLErr := sqldb.MapSQLError(err) + var uniqueConstraintErr *sqldb.ErrSQLUniqueConstraintViolation //nolint:lll if errors.As(mappedSQLErr, &uniqueConstraintErr) { - return invpkg.ErrDuplicateSetID{ + return ErrDuplicateSetID{ SetID: setID, } } @@ -1072,7 +1074,7 @@ func (s *sqlInvoiceUpdater) AddHtlc(circuitKey models.CircuitKey, // ResolveHtlc marks an htlc as resolved with the given state. func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey, - state invpkg.HtlcState, resolveTime time.Time) error { + state HtlcState, resolveTime time.Time) error { return s.db.UpdateInvoiceHTLC(s.ctx, sqlc.UpdateInvoiceHTLCParams{ HtlcID: int64(circuitKey.HtlcID), @@ -1081,7 +1083,7 @@ func (s *sqlInvoiceUpdater) ResolveHtlc(circuitKey models.CircuitKey, ), InvoiceID: int64(s.invoice.AddIndex), State: int16(state), - ResolveTime: SQLTime(resolveTime.UTC()), + ResolveTime: sqldb.SQLTime(resolveTime.UTC()), }) } @@ -1107,7 +1109,7 @@ func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte, return err } if rowsAffected == 0 { - return invpkg.ErrInvoiceNotFound + return ErrInvoiceNotFound } return nil @@ -1115,7 +1117,7 @@ func (s *sqlInvoiceUpdater) AddAmpHtlcPreimage(setID [32]byte, // UpdateInvoiceState updates the invoice state to the new state. func (s *sqlInvoiceUpdater) UpdateInvoiceState( - newState invpkg.ContractState, preimage *lntypes.Preimage) error { + newState ContractState, preimage *lntypes.Preimage) error { var ( settleIndex sql.NullInt64 @@ -1123,16 +1125,16 @@ func (s *sqlInvoiceUpdater) UpdateInvoiceState( ) switch newState { - case invpkg.ContractSettled: + case ContractSettled: nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx) if err != nil { return err } - settleIndex = SQLInt64(nextSettleIndex) + settleIndex = sqldb.SQLInt64(nextSettleIndex) // If the invoice is settled, we'll also update the settle time. - settledAt = SQLTime(s.updateTime.UTC()) + settledAt = sqldb.SQLTime(s.updateTime.UTC()) err = s.db.OnInvoiceSettled( s.ctx, sqlc.OnInvoiceSettledParams{ @@ -1144,7 +1146,7 @@ func (s *sqlInvoiceUpdater) UpdateInvoiceState( return err } - case invpkg.ContractCanceled: + case ContractCanceled: err := s.db.OnInvoiceCanceled( s.ctx, sqlc.OnInvoiceCanceledParams{ AddedAt: s.updateTime.UTC(), @@ -1177,7 +1179,7 @@ func (s *sqlInvoiceUpdater) UpdateInvoiceState( } if rowsAffected == 0 { - return invpkg.ErrInvoiceNotFound + return ErrInvoiceNotFound } if settleIndex.Valid { @@ -1205,7 +1207,7 @@ func (s *sqlInvoiceUpdater) UpdateInvoiceAmtPaid( // UpdateAmpState updates the state of the AMP sub invoice identified by the // setID. func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte, - newState invpkg.InvoiceStateAMP, _ models.CircuitKey) error { + newState InvoiceStateAMP, _ models.CircuitKey) error { var ( settleIndex sql.NullInt64 @@ -1213,16 +1215,16 @@ func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte, ) switch newState.State { - case invpkg.HtlcStateSettled: + case HtlcStateSettled: nextSettleIndex, err := s.db.NextInvoiceSettleIndex(s.ctx) if err != nil { return err } - settleIndex = SQLInt64(nextSettleIndex) + settleIndex = sqldb.SQLInt64(nextSettleIndex) // If the invoice is settled, we'll also update the settle time. - settledAt = SQLTime(s.updateTime.UTC()) + settledAt = sqldb.SQLTime(s.updateTime.UTC()) err = s.db.OnAMPSubInvoiceSettled( s.ctx, sqlc.OnAMPSubInvoiceSettledParams{ @@ -1235,7 +1237,7 @@ func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte, return err } - case invpkg.HtlcStateCanceled: + case HtlcStateCanceled: err := s.db.OnAMPSubInvoiceCanceled( s.ctx, sqlc.OnAMPSubInvoiceCanceledParams{ AddedAt: s.updateTime.UTC(), @@ -1266,7 +1268,7 @@ func (s *sqlInvoiceUpdater) UpdateAmpState(setID [32]byte, // Finalize finalizes the update before it is written to the database. Note that // we don't use this directly in the SQL implementation, so the function is just // a stub. -func (s *sqlInvoiceUpdater) Finalize(_ invpkg.UpdateType) error { +func (s *sqlInvoiceUpdater) Finalize(_ UpdateType) error { return nil } @@ -1277,14 +1279,14 @@ func (s *sqlInvoiceUpdater) Finalize(_ invpkg.UpdateType) error { // The update is performed inside the same database transaction that fetches the // invoice and is therefore atomic. The fields to update are controlled by the // supplied callback. -func (i *InvoiceStore) UpdateInvoice(ctx context.Context, ref invpkg.InvoiceRef, - _ *invpkg.SetID, callback invpkg.InvoiceUpdateCallback) ( - *invpkg.Invoice, error) { +func (i *SQLStore) UpdateInvoice(ctx context.Context, ref InvoiceRef, + _ *SetID, callback InvoiceUpdateCallback) ( + *Invoice, error) { - var updatedInvoice *invpkg.Invoice + var updatedInvoice *Invoice - txOpt := InvoiceQueriesTxOptions{readOnly: false} - txErr := i.db.ExecTx(ctx, &txOpt, func(db InvoiceQueries) error { + txOpt := SQLInvoiceQueriesTxOptions{readOnly: false} + txErr := i.db.ExecTx(ctx, &txOpt, func(db SQLInvoiceQueries) error { invoice, err := i.fetchInvoice(ctx, db, ref) if err != nil { return err @@ -1299,7 +1301,7 @@ func (i *InvoiceStore) UpdateInvoice(ctx context.Context, ref invpkg.InvoiceRef, } payHash := ref.PayHash() - updatedInvoice, err = invpkg.UpdateInvoice( + updatedInvoice, err = UpdateInvoice( payHash, invoice, updateTime, callback, updater, ) @@ -1308,7 +1310,7 @@ func (i *InvoiceStore) UpdateInvoice(ctx context.Context, ref invpkg.InvoiceRef, if txErr != nil { // If the invoice is already settled, we'll return the // (unchanged) invoice and the ErrInvoiceAlreadySettled error. - if errors.Is(txErr, invpkg.ErrInvoiceAlreadySettled) { + if errors.Is(txErr, ErrInvoiceAlreadySettled) { return updatedInvoice, txErr } @@ -1320,8 +1322,8 @@ func (i *InvoiceStore) UpdateInvoice(ctx context.Context, ref invpkg.InvoiceRef, // DeleteInvoice attempts to delete the passed invoices and all their related // data from the database in one transaction. -func (i *InvoiceStore) DeleteInvoice(ctx context.Context, - invoicesToDelete []invpkg.InvoiceDeleteRef) error { +func (i *SQLStore) DeleteInvoice(ctx context.Context, + invoicesToDelete []InvoiceDeleteRef) error { // All the InvoiceDeleteRef instances include the add index of the // invoice. The rest was added to ensure that the invoices were deleted @@ -1334,15 +1336,17 @@ func (i *InvoiceStore) DeleteInvoice(ctx context.Context, } } - var writeTxOpt InvoiceQueriesTxOptions - err := i.db.ExecTx(ctx, &writeTxOpt, func(db InvoiceQueries) error { + var writeTxOpt SQLInvoiceQueriesTxOptions + err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error { for _, ref := range invoicesToDelete { params := sqlc.DeleteInvoiceParams{ - AddIndex: SQLInt64(ref.AddIndex), + AddIndex: sqldb.SQLInt64(ref.AddIndex), } if ref.SettleIndex != 0 { - params.SettleIndex = SQLInt64(ref.SettleIndex) + params.SettleIndex = sqldb.SQLInt64( + ref.SettleIndex, + ) } if ref.PayHash != lntypes.ZeroHash { @@ -1361,7 +1365,7 @@ func (i *InvoiceStore) DeleteInvoice(ctx context.Context, } if rowsAffected == 0 { return fmt.Errorf("%w: %v", - invpkg.ErrInvoiceNotFound, ref.AddIndex) + ErrInvoiceNotFound, ref.AddIndex) } } @@ -1376,9 +1380,9 @@ func (i *InvoiceStore) DeleteInvoice(ctx context.Context, } // DeleteCanceledInvoices removes all canceled invoices from the database. -func (i *InvoiceStore) DeleteCanceledInvoices(ctx context.Context) error { - var writeTxOpt InvoiceQueriesTxOptions - err := i.db.ExecTx(ctx, &writeTxOpt, func(db InvoiceQueries) error { +func (i *SQLStore) DeleteCanceledInvoices(ctx context.Context) error { + var writeTxOpt SQLInvoiceQueriesTxOptions + err := i.db.ExecTx(ctx, &writeTxOpt, func(db SQLInvoiceQueries) error { _, err := db.DeleteCanceledInvoices(ctx) if err != nil { return fmt.Errorf("unable to delete canceled "+ @@ -1398,9 +1402,9 @@ func (i *InvoiceStore) DeleteCanceledInvoices(ctx context.Context) error { // invoice is AMP and the setID is not nil, then it will also fetch the AMP // state and HTLCs for the given setID, otherwise for all AMP sub invoices of // the invoice. If fetchAmpHtlcs is true, it will also fetch the AMP HTLCs. -func fetchInvoiceData(ctx context.Context, db InvoiceQueries, +func fetchInvoiceData(ctx context.Context, db SQLInvoiceQueries, row sqlc.Invoice, setID *[32]byte, fetchAmpHtlcs bool) (*lntypes.Hash, - *invpkg.Invoice, error) { + *Invoice, error) { // Unmarshal the common data. hash, invoice, err := unmarshalInvoice(row) @@ -1444,7 +1448,7 @@ func fetchInvoiceData(ctx context.Context, db InvoiceQueries, invoice.Htlcs = htlcs var amountPaid lnwire.MilliSatoshi for _, htlc := range htlcs { - if htlc.State == invpkg.HtlcStateSettled { + if htlc.State == HtlcStateSettled { amountPaid += htlc.Amt } } @@ -1455,7 +1459,7 @@ func fetchInvoiceData(ctx context.Context, db InvoiceQueries, } // getInvoiceFeatures fetches the invoice features for the given invoice id. -func getInvoiceFeatures(ctx context.Context, db InvoiceQueries, +func getInvoiceFeatures(ctx context.Context, db SQLInvoiceQueries, invoiceID int64) (*lnwire.FeatureVector, error) { rows, err := db.GetInvoiceFeatures(ctx, invoiceID) @@ -1473,8 +1477,8 @@ func getInvoiceFeatures(ctx context.Context, db InvoiceQueries, } // getInvoiceHtlcs fetches the invoice htlcs for the given invoice id. -func getInvoiceHtlcs(ctx context.Context, db InvoiceQueries, - invoiceID int64) (map[invpkg.CircuitKey]*invpkg.InvoiceHTLC, error) { +func getInvoiceHtlcs(ctx context.Context, db SQLInvoiceQueries, + invoiceID int64) (map[CircuitKey]*InvoiceHTLC, error) { htlcRows, err := db.GetInvoiceHTLCs(ctx, invoiceID) if err != nil { @@ -1505,7 +1509,7 @@ func getInvoiceHtlcs(ctx context.Context, db InvoiceQueries, cr[row.HtlcID][uint64(row.Key)] = value } - htlcs := make(map[invpkg.CircuitKey]*invpkg.InvoiceHTLC, len(htlcRows)) + htlcs := make(map[CircuitKey]*InvoiceHTLC, len(htlcRows)) for _, row := range htlcRows { circuiteKey, htlc, err := unmarshalInvoiceHTLC(row) @@ -1527,7 +1531,7 @@ func getInvoiceHtlcs(ctx context.Context, db InvoiceQueries, } // unmarshalInvoice converts an InvoiceRow to an Invoice. -func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *invpkg.Invoice, +func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *Invoice, error) { var ( @@ -1576,13 +1580,13 @@ func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *invpkg.Invoice, cltvDelta = row.CltvDelta.Int32 } - invoice := &invpkg.Invoice{ + invoice := &Invoice{ SettleIndex: uint64(settleIndex), SettleDate: settledAt, Memo: memo, PaymentRequest: paymentRequest, CreationDate: row.CreatedAt.Local(), - Terms: invpkg.ContractTerm{ + Terms: ContractTerm{ FinalCltvDelta: cltvDelta, Expiry: time.Duration(row.Expiry), PaymentPreimage: preimage, @@ -1590,10 +1594,10 @@ func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *invpkg.Invoice, PaymentAddr: paymentAddr, }, AddIndex: uint64(row.ID), - State: invpkg.ContractState(row.State), + State: ContractState(row.State), AmtPaid: lnwire.MilliSatoshi(row.AmountPaidMsat), - Htlcs: make(map[models.CircuitKey]*invpkg.InvoiceHTLC), - AMPState: invpkg.AMPInvoiceState{}, + Htlcs: make(map[models.CircuitKey]*InvoiceHTLC), + AMPState: AMPInvoiceState{}, HodlInvoice: row.IsHodl, } @@ -1601,34 +1605,34 @@ func unmarshalInvoice(row sqlc.Invoice) (*lntypes.Hash, *invpkg.Invoice, } // unmarshalInvoiceHTLC converts an sqlc.InvoiceHtlc to an InvoiceHTLC. -func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (invpkg.CircuitKey, - *invpkg.InvoiceHTLC, error) { +func unmarshalInvoiceHTLC(row sqlc.InvoiceHtlc) (CircuitKey, + *InvoiceHTLC, error) { uint64ChanID, err := strconv.ParseUint(row.ChanID, 10, 64) if err != nil { - return invpkg.CircuitKey{}, nil, err + return CircuitKey{}, nil, err } chanID := lnwire.NewShortChanIDFromInt(uint64ChanID) if row.HtlcID < 0 { - return invpkg.CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+ + return CircuitKey{}, nil, fmt.Errorf("invalid uint64 "+ "value: %v", row.HtlcID) } htlcID := uint64(row.HtlcID) - circuitKey := invpkg.CircuitKey{ + circuitKey := CircuitKey{ ChanID: chanID, HtlcID: htlcID, } - htlc := &invpkg.InvoiceHTLC{ + htlc := &InvoiceHTLC{ Amt: lnwire.MilliSatoshi(row.AmountMsat), AcceptHeight: uint32(row.AcceptHeight), AcceptTime: row.AcceptTime.Local(), Expiry: uint32(row.ExpiryHeight), - State: invpkg.HtlcState(row.State), + State: HtlcState(row.State), } if row.TotalMppMsat.Valid {