From 6135e0553e6e58fcf506700991fa178f2c50a266 Mon Sep 17 00:00:00 2001 From: Ava Chow Date: Mon, 12 May 2025 13:29:50 -0700 Subject: [PATCH] wallet, rpc: Move (Un)LockCoin WalletBatch creation out of RPC If the locked coin needs to be persisted to the wallet database, insteead of having the RPC figure out when to create a WalletBatch and having LockCoin's behavior depend on it, have LockCoin take whether to persist as a parameter so it makes the batch. Since unlocking a persisted locked coin requires a database write as well, we need to track whether the locked coin was persisted to the wallet database so that it can erase the locked coin when necessary. Keeping track of whether a locked coin was persisted is also useful information for future PRs. --- src/wallet/interfaces.cpp | 6 ++-- src/wallet/rpc/coins.cpp | 8 ++--- src/wallet/rpc/spend.cpp | 2 +- src/wallet/spend.cpp | 2 +- src/wallet/test/wallet_tests.cpp | 4 +-- src/wallet/wallet.cpp | 57 +++++++++++++++++--------------- src/wallet/wallet.h | 15 +++++---- src/wallet/walletdb.cpp | 2 +- 8 files changed, 49 insertions(+), 47 deletions(-) diff --git a/src/wallet/interfaces.cpp b/src/wallet/interfaces.cpp index 0048c025a29..ad75e69eccc 100644 --- a/src/wallet/interfaces.cpp +++ b/src/wallet/interfaces.cpp @@ -249,14 +249,12 @@ public: bool lockCoin(const COutPoint& output, const bool write_to_db) override { LOCK(m_wallet->cs_wallet); - std::unique_ptr batch = write_to_db ? std::make_unique(m_wallet->GetDatabase()) : nullptr; - return m_wallet->LockCoin(output, batch.get()); + return m_wallet->LockCoin(output, write_to_db); } bool unlockCoin(const COutPoint& output) override { LOCK(m_wallet->cs_wallet); - std::unique_ptr batch = std::make_unique(m_wallet->GetDatabase()); - return m_wallet->UnlockCoin(output, batch.get()); + return m_wallet->UnlockCoin(output); } bool isLockedCoin(const COutPoint& output) override { diff --git a/src/wallet/rpc/coins.cpp b/src/wallet/rpc/coins.cpp index cce9b26babe..9351259a5a5 100644 --- a/src/wallet/rpc/coins.cpp +++ b/src/wallet/rpc/coins.cpp @@ -356,16 +356,12 @@ RPCHelpMan lockunspent() outputs.push_back(outpt); } - std::unique_ptr batch = nullptr; - // Unlock is always persistent - if (fUnlock || persistent) batch = std::make_unique(pwallet->GetDatabase()); - // Atomically set (un)locked status for the outputs. for (const COutPoint& outpt : outputs) { if (fUnlock) { - if (!pwallet->UnlockCoin(outpt, batch.get())) throw JSONRPCError(RPC_WALLET_ERROR, "Unlocking coin failed"); + if (!pwallet->UnlockCoin(outpt)) throw JSONRPCError(RPC_WALLET_ERROR, "Unlocking coin failed"); } else { - if (!pwallet->LockCoin(outpt, batch.get())) throw JSONRPCError(RPC_WALLET_ERROR, "Locking coin failed"); + if (!pwallet->LockCoin(outpt, persistent)) throw JSONRPCError(RPC_WALLET_ERROR, "Locking coin failed"); } } diff --git a/src/wallet/rpc/spend.cpp b/src/wallet/rpc/spend.cpp index 4c4b6288369..4421822dd19 100644 --- a/src/wallet/rpc/spend.cpp +++ b/src/wallet/rpc/spend.cpp @@ -1579,7 +1579,7 @@ RPCHelpMan sendall() const bool lock_unspents{options.exists("lock_unspents") ? options["lock_unspents"].get_bool() : false}; if (lock_unspents) { for (const CTxIn& txin : rawTx.vin) { - pwallet->LockCoin(txin.prevout); + pwallet->LockCoin(txin.prevout, /*persist=*/false); } } diff --git a/src/wallet/spend.cpp b/src/wallet/spend.cpp index a330c31a87a..13f7d3d61a7 100644 --- a/src/wallet/spend.cpp +++ b/src/wallet/spend.cpp @@ -1467,7 +1467,7 @@ util::Result FundTransaction(CWallet& wallet, const CM if (lockUnspents) { for (const CTxIn& txin : res->tx->vin) { - wallet.LockCoin(txin.prevout); + wallet.LockCoin(txin.prevout, /*persist=*/false); } } diff --git a/src/wallet/test/wallet_tests.cpp b/src/wallet/test/wallet_tests.cpp index 966c6d2c4ba..fa141696616 100644 --- a/src/wallet/test/wallet_tests.cpp +++ b/src/wallet/test/wallet_tests.cpp @@ -458,7 +458,7 @@ BOOST_FIXTURE_TEST_CASE(ListCoinsTest, ListCoinsTestingSetup) for (const auto& group : list) { for (const auto& coin : group.second) { LOCK(wallet->cs_wallet); - wallet->LockCoin(coin.outpoint); + wallet->LockCoin(coin.outpoint, /*persist=*/false); } } { @@ -486,7 +486,7 @@ void TestCoinsResult(ListCoinsTest& context, OutputType out_type, CAmount amount filter.skip_locked = false; CoinsResult available_coins = AvailableCoins(*context.wallet, nullptr, std::nullopt, filter); // Lock outputs so they are not spent in follow-up transactions - for (uint32_t i = 0; i < wtx.tx->vout.size(); i++) context.wallet->LockCoin({wtx.GetHash(), i}); + for (uint32_t i = 0; i < wtx.tx->vout.size(); i++) context.wallet->LockCoin({wtx.GetHash(), i}, /*persist=*/false); for (const auto& [type, size] : expected_coins_sizes) BOOST_CHECK_EQUAL(size, available_coins.coins[type].size()); } diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index 4a4aa837eaa..6a40cfe97eb 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -785,16 +785,11 @@ bool CWallet::IsSpent(const COutPoint& outpoint) const return false; } -void CWallet::AddToSpends(const COutPoint& outpoint, const Txid& txid, WalletBatch* batch) +void CWallet::AddToSpends(const COutPoint& outpoint, const Txid& txid) { mapTxSpends.insert(std::make_pair(outpoint, txid)); - if (batch) { - UnlockCoin(outpoint, batch); - } else { - WalletBatch temp_batch(GetDatabase()); - UnlockCoin(outpoint, &temp_batch); - } + UnlockCoin(outpoint); std::pair range; range = mapTxSpends.equal_range(outpoint); @@ -802,13 +797,13 @@ void CWallet::AddToSpends(const COutPoint& outpoint, const Txid& txid, WalletBat } -void CWallet::AddToSpends(const CWalletTx& wtx, WalletBatch* batch) +void CWallet::AddToSpends(const CWalletTx& wtx) { if (wtx.IsCoinBase()) // Coinbases don't spend anything! return; for (const CTxIn& txin : wtx.tx->vin) - AddToSpends(txin.prevout, wtx.GetHash(), batch); + AddToSpends(txin.prevout, wtx.GetHash()); } bool CWallet::EncryptWallet(const SecureString& strWalletPassphrase) @@ -1058,7 +1053,7 @@ CWalletTx* CWallet::AddToWallet(CTransactionRef tx, const TxState& state, const wtx.nOrderPos = IncOrderPosNext(&batch); wtx.m_it_wtxOrdered = wtxOrdered.insert(std::make_pair(wtx.nOrderPos, &wtx)); wtx.nTimeSmart = ComputeTimeSmart(wtx, rescanning_old_block); - AddToSpends(wtx, &batch); + AddToSpends(wtx); // Update birth time when tx time is older than it. MaybeUpdateBirthTime(wtx.GetTxTime()); @@ -2622,22 +2617,34 @@ util::Result CWallet::DisplayAddress(const CTxDestination& dest) return util::Error{_("There is no ScriptPubKeyManager for this address")}; } -bool CWallet::LockCoin(const COutPoint& output, WalletBatch* batch) +void CWallet::LoadLockedCoin(const COutPoint& coin, bool persistent) { AssertLockHeld(cs_wallet); - setLockedCoins.insert(output); - if (batch) { - return batch->WriteLockedUTXO(output); + m_locked_coins.emplace(coin, persistent); +} + +bool CWallet::LockCoin(const COutPoint& output, bool persist) +{ + AssertLockHeld(cs_wallet); + LoadLockedCoin(output, persist); + if (persist) { + WalletBatch batch(GetDatabase()); + return batch.WriteLockedUTXO(output); } return true; } -bool CWallet::UnlockCoin(const COutPoint& output, WalletBatch* batch) +bool CWallet::UnlockCoin(const COutPoint& output) { AssertLockHeld(cs_wallet); - bool was_locked = setLockedCoins.erase(output); - if (batch && was_locked) { - return batch->EraseLockedUTXO(output); + auto locked_coin_it = m_locked_coins.find(output); + if (locked_coin_it != m_locked_coins.end()) { + bool persisted = locked_coin_it->second; + m_locked_coins.erase(locked_coin_it); + if (persisted) { + WalletBatch batch(GetDatabase()); + return batch.EraseLockedUTXO(output); + } } return true; } @@ -2647,26 +2654,24 @@ bool CWallet::UnlockAllCoins() AssertLockHeld(cs_wallet); bool success = true; WalletBatch batch(GetDatabase()); - for (auto it = setLockedCoins.begin(); it != setLockedCoins.end(); ++it) { - success &= batch.EraseLockedUTXO(*it); + for (const auto& [coin, persistent] : m_locked_coins) { + if (persistent) success = success && batch.EraseLockedUTXO(coin); } - setLockedCoins.clear(); + m_locked_coins.clear(); return success; } bool CWallet::IsLockedCoin(const COutPoint& output) const { AssertLockHeld(cs_wallet); - return setLockedCoins.count(output) > 0; + return m_locked_coins.count(output) > 0; } void CWallet::ListLockedCoins(std::vector& vOutpts) const { AssertLockHeld(cs_wallet); - for (std::set::iterator it = setLockedCoins.begin(); - it != setLockedCoins.end(); it++) { - COutPoint outpt = (*it); - vOutpts.push_back(outpt); + for (const auto& [coin, _] : m_locked_coins) { + vOutpts.push_back(coin); } } diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index fbc3bed2ab6..7fe557f71d7 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -333,8 +333,8 @@ private: */ typedef std::unordered_multimap TxSpends; TxSpends mapTxSpends GUARDED_BY(cs_wallet); - void AddToSpends(const COutPoint& outpoint, const Txid& txid, WalletBatch* batch = nullptr) EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); - void AddToSpends(const CWalletTx& wtx, WalletBatch* batch = nullptr) EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); + void AddToSpends(const COutPoint& outpoint, const Txid& txid) EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); + void AddToSpends(const CWalletTx& wtx) EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); /** * Add a transaction to the wallet, or update it. confirm.block_* should @@ -497,8 +497,10 @@ public: /** Set of Coins owned by this wallet that we won't try to spend from. A * Coin may be locked if it has already been used to fund a transaction * that hasn't confirmed yet. We wouldn't consider the Coin spent already, - * but also shouldn't try to use it again. */ - std::set setLockedCoins GUARDED_BY(cs_wallet); + * but also shouldn't try to use it again. + * bool to track whether this locked coin is persisted to disk. + */ + std::map m_locked_coins GUARDED_BY(cs_wallet); /** Registered interfaces::Chain::Notifications handler. */ std::unique_ptr m_chain_notifications_handler; @@ -546,8 +548,9 @@ public: util::Result DisplayAddress(const CTxDestination& dest) EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); bool IsLockedCoin(const COutPoint& output) const EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); - bool LockCoin(const COutPoint& output, WalletBatch* batch = nullptr) EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); - bool UnlockCoin(const COutPoint& output, WalletBatch* batch = nullptr) EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); + void LoadLockedCoin(const COutPoint& coin, bool persistent) EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); + bool LockCoin(const COutPoint& output, bool persist) EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); + bool UnlockCoin(const COutPoint& output) EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); bool UnlockAllCoins() EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); void ListLockedCoins(std::vector& vOutpts) const EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); diff --git a/src/wallet/walletdb.cpp b/src/wallet/walletdb.cpp index 3bb5db71f5e..0f0e228fc80 100644 --- a/src/wallet/walletdb.cpp +++ b/src/wallet/walletdb.cpp @@ -1072,7 +1072,7 @@ static DBErrors LoadTxRecords(CWallet* pwallet, DatabaseBatch& batch, std::vecto uint32_t n; key >> hash; key >> n; - pwallet->LockCoin(COutPoint(hash, n)); + pwallet->LoadLockedCoin(COutPoint(hash, n), /*persistent=*/true); return DBErrors::LOAD_OK; }); result = std::max(result, locked_utxo_res.m_result);