diff --git a/sweep/store.go b/sweep/store.go index 375436ca3..cfab66381 100644 --- a/sweep/store.go +++ b/sweep/store.go @@ -128,6 +128,13 @@ type SweeperStore interface { // ListSweeps lists all the sweeps we have successfully published. ListSweeps() ([]chainhash.Hash, error) + + // GetTx queries the database to find the tx that matches the given + // txid. Returns ErrTxNotFound if it cannot be found. + GetTx(hash chainhash.Hash) (*TxRecord, error) + + // DeleteTx removes a tx specified by the hash from the store. + DeleteTx(hash chainhash.Hash) error } type sweeperStore struct { @@ -322,5 +329,66 @@ func (s *sweeperStore) ListSweeps() ([]chainhash.Hash, error) { return sweepTxns, nil } +// GetTx queries the database to find the tx that matches the given txid. +// Returns ErrTxNotFound if it cannot be found. +func (s *sweeperStore) GetTx(txid chainhash.Hash) (*TxRecord, error) { + // Create a record. + tr := &TxRecord{} + + var err error + err = kvdb.View(s.db, func(tx kvdb.RTx) error { + txHashesBucket := tx.ReadBucket(txHashesBucketKey) + if txHashesBucket == nil { + return errNoTxHashesBucket + } + + txBytes := txHashesBucket.Get(txid[:]) + if txBytes == nil { + return ErrTxNotFound + } + + // For old records, we'd get an empty byte slice here. We can + // assume it's already been published. Although it's possible + // to calculate the fees and fee rate used here, we skip it as + // it's unlikely we'd perform RBF on these old sweeping + // transactions. + // + // TODO(yy): remove this check once migration is added. + if len(txBytes) == 0 { + tr.Published = true + return nil + } + + tr, err = deserializeTxRecord(bytes.NewReader(txBytes)) + if err != nil { + return err + } + + return nil + }, func() { + tr = &TxRecord{} + }) + if err != nil { + return nil, err + } + + // Attach the txid to the record. + tr.Txid = txid + + return tr, nil +} + +// DeleteTx removes the given tx from db. +func (s *sweeperStore) DeleteTx(txid chainhash.Hash) error { + return kvdb.Update(s.db, func(tx kvdb.RwTx) error { + txHashesBucket := tx.ReadWriteBucket(txHashesBucketKey) + if txHashesBucket == nil { + return errNoTxHashesBucket + } + + return txHashesBucket.Delete(txid[:]) + }, func() {}) +} + // Compile-time constraint to ensure sweeperStore implements SweeperStore. var _ SweeperStore = (*sweeperStore)(nil) diff --git a/sweep/store_mock.go b/sweep/store_mock.go index 8263cb3a8..16f7714a9 100644 --- a/sweep/store_mock.go +++ b/sweep/store_mock.go @@ -41,5 +41,16 @@ func (s *MockSweeperStore) ListSweeps() ([]chainhash.Hash, error) { return txns, nil } +// GetTx queries the database to find the tx that matches the given txid. +// Returns ErrTxNotFound if it cannot be found. +func (s *MockSweeperStore) GetTx(hash chainhash.Hash) (*TxRecord, error) { + return nil, ErrTxNotFound +} + +// DeleteTx removes the given tx from db. +func (s *MockSweeperStore) DeleteTx(txid chainhash.Hash) error { + return nil +} + // Compile-time constraint to ensure MockSweeperStore implements SweeperStore. var _ SweeperStore = (*MockSweeperStore)(nil) diff --git a/sweep/store_test.go b/sweep/store_test.go index 4abf65871..7cfc649c9 100644 --- a/sweep/store_test.go +++ b/sweep/store_test.go @@ -7,6 +7,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/kvdb" "github.com/stretchr/testify/require" ) @@ -138,3 +139,108 @@ func TestTxRecord(t *testing.T) { // Assert the deserialized record is equal to the original. require.Equal(t, tr, result) } + +// TestGetTx asserts that the GetTx method behaves as expected. +func TestGetTx(t *testing.T) { + t.Parallel() + + cdb, err := channeldb.MakeTestDB(t) + require.NoError(t, err) + + // Create a testing store. + chain := chainhash.Hash{} + store, err := NewSweeperStore(cdb, &chain) + require.NoError(t, err) + + // Create a testing record. + txid := chainhash.Hash{1, 2, 3} + tr := &TxRecord{ + Txid: txid, + FeeRate: 1000, + Fee: 10000, + Published: true, + } + + // Assert we can store this tx record. + err = store.StoreTx(tr) + require.NoError(t, err) + + // Assert we can query the tx record. + result, err := store.GetTx(txid) + require.NoError(t, err) + require.Equal(t, tr, result) + + // Assert we get an error when querying a non-existing tx. + _, err = store.GetTx(chainhash.Hash{4, 5, 6}) + require.ErrorIs(t, ErrTxNotFound, err) +} + +// TestGetTxCompatible asserts that when there's old tx record data in the +// database it can be successfully queried. +func TestGetTxCompatible(t *testing.T) { + t.Parallel() + + cdb, err := channeldb.MakeTestDB(t) + require.NoError(t, err) + + // Create a testing store. + chain := chainhash.Hash{} + store, err := NewSweeperStore(cdb, &chain) + require.NoError(t, err) + + // Create a testing txid. + txid := chainhash.Hash{0, 1, 2, 3} + + // Create a record using the old format "hash -> empty byte slice". + err = kvdb.Update(cdb, func(tx kvdb.RwTx) error { + txHashesBucket := tx.ReadWriteBucket(txHashesBucketKey) + return txHashesBucket.Put(txid[:], []byte{}) + }, func() {}) + require.NoError(t, err) + + // Assert we can query the tx record. + result, err := store.GetTx(txid) + require.NoError(t, err) + require.Equal(t, txid, result.Txid) + + // Assert the Published field is true. + require.True(t, result.Published) +} + +// TestDeleteTx asserts that the DeleteTx method behaves as expected. +func TestDeleteTx(t *testing.T) { + t.Parallel() + + cdb, err := channeldb.MakeTestDB(t) + require.NoError(t, err) + + // Create a testing store. + chain := chainhash.Hash{} + store, err := NewSweeperStore(cdb, &chain) + require.NoError(t, err) + + // Create a testing record. + txid := chainhash.Hash{1, 2, 3} + tr := &TxRecord{ + Txid: txid, + FeeRate: 1000, + Fee: 10000, + Published: true, + } + + // Assert we can store this tx record. + err = store.StoreTx(tr) + require.NoError(t, err) + + // Assert we can delete the tx record. + err = store.DeleteTx(txid) + require.NoError(t, err) + + // Query it again should give us an error. + _, err = store.GetTx(txid) + require.ErrorIs(t, ErrTxNotFound, err) + + // Assert deleting a non-existing tx doesn't return an error. + err = store.DeleteTx(chainhash.Hash{4, 5, 6}) + require.NoError(t, err) +}