Simplify removeRecursive

This commit is contained in:
Suhas Daftuar
2025-02-05 09:11:40 -05:00
parent 01d8520038
commit a5a7905d83
2 changed files with 37 additions and 23 deletions

View File

@@ -308,35 +308,41 @@ CTxMemPool::txiter CTxMemPool::CalculateDescendants(const CTxMemPoolEntry& entry
return mapTx.iterator_to(entry);
}
void CTxMemPool::removeRecursive(CTxMemPool::txiter to_remove, MemPoolRemovalReason reason)
{
AssertLockHeld(cs);
Assume(!m_have_changeset);
auto descendants = m_txgraph->GetDescendants(*to_remove, TxGraph::Level::MAIN);
for (auto tx: descendants) {
removeUnchecked(mapTx.iterator_to(static_cast<const CTxMemPoolEntry&>(*tx)), reason);
}
}
void CTxMemPool::removeRecursive(const CTransaction &origTx, MemPoolRemovalReason reason)
{
// Remove transaction from memory pool
AssertLockHeld(cs);
Assume(!m_have_changeset);
setEntries txToRemove;
txiter origit = mapTx.find(origTx.GetHash());
if (origit != mapTx.end()) {
txToRemove.insert(origit);
} else {
// When recursively removing but origTx isn't in the mempool
// be sure to remove any children that are in the pool. This can
// happen during chain re-orgs if origTx isn't re-accepted into
// the mempool for any reason.
for (unsigned int i = 0; i < origTx.vout.size(); i++) {
auto it = mapNextTx.find(COutPoint(origTx.GetHash(), i));
if (it == mapNextTx.end())
continue;
txiter nextit = it->second;
assert(nextit != mapTx.end());
txToRemove.insert(nextit);
}
txiter origit = mapTx.find(origTx.GetHash());
if (origit != mapTx.end()) {
removeRecursive(origit, reason);
} else {
// When recursively removing but origTx isn't in the mempool
// be sure to remove any descendants that are in the pool. This can
// happen during chain re-orgs if origTx isn't re-accepted into
// the mempool for any reason.
auto iter = mapNextTx.lower_bound(COutPoint(origTx.GetHash(), 0));
std::vector<const TxGraph::Ref*> to_remove;
while (iter != mapNextTx.end() && iter->first->hash == origTx.GetHash()) {
to_remove.emplace_back(&*(iter->second));
++iter;
}
setEntries setAllRemoves;
for (txiter it : txToRemove) {
CalculateDescendants(it, setAllRemoves);
auto all_removes = m_txgraph->GetDescendantsUnion(to_remove, TxGraph::Level::MAIN);
for (auto ref : all_removes) {
auto tx = mapTx.iterator_to(static_cast<const CTxMemPoolEntry&>(*ref));
removeUnchecked(tx, reason);
}
RemoveStaged(setAllRemoves, reason);
}
}
void CTxMemPool::removeForReorg(CChain& chain, std::function<bool(txiter)> check_final_and_mature)
@@ -372,7 +378,7 @@ void CTxMemPool::removeConflicts(const CTransaction &tx)
if (Assume(txConflict.GetHash() != tx.GetHash()))
{
ClearPrioritisation(txConflict.GetHash());
removeRecursive(txConflict, MemPoolRemovalReason::CONFLICT);
removeRecursive(it->second, MemPoolRemovalReason::CONFLICT);
}
}
}