diff --git a/src/dbwrapper.cpp b/src/dbwrapper.cpp index 3c574f89391..3743d434820 100644 --- a/src/dbwrapper.cpp +++ b/src/dbwrapper.cpp @@ -163,6 +163,8 @@ CDBBatch::CDBBatch(const CDBWrapper& _parent) : parent{_parent}, m_impl_batch{std::make_unique()} { + m_key_scratch.reserve(DBWRAPPER_PREALLOC_KEY_SIZE); + m_value_scratch.reserve(DBWRAPPER_PREALLOC_VALUE_SIZE); Clear(); }; @@ -171,13 +173,15 @@ CDBBatch::~CDBBatch() = default; void CDBBatch::Clear() { m_impl_batch->batch.Clear(); + assert(m_key_scratch.empty()); + assert(m_value_scratch.empty()); } -void CDBBatch::WriteImpl(std::span key, DataStream& ssValue) +void CDBBatch::WriteImpl(std::span key, DataStream& value) { leveldb::Slice slKey(CharCast(key.data()), key.size()); - dbwrapper_private::GetObfuscation(parent)(ssValue); - leveldb::Slice slValue(CharCast(ssValue.data()), ssValue.size()); + dbwrapper_private::GetObfuscation(parent)(value); + leveldb::Slice slValue(CharCast(value.data()), value.size()); m_impl_batch->batch.Put(slKey, slValue); } @@ -356,7 +360,10 @@ struct CDBIterator::IteratorImpl { }; CDBIterator::CDBIterator(const CDBWrapper& _parent, std::unique_ptr _piter) : parent(_parent), - m_impl_iter(std::move(_piter)) {} + m_impl_iter(std::move(_piter)) +{ + m_scratch.reserve(DBWRAPPER_PREALLOC_KEY_SIZE); +} CDBIterator* CDBWrapper::NewIterator() { diff --git a/src/dbwrapper.h b/src/dbwrapper.h index 4710af16d9e..51775ee6365 100644 --- a/src/dbwrapper.h +++ b/src/dbwrapper.h @@ -79,10 +79,10 @@ private: struct WriteBatchImpl; const std::unique_ptr m_impl_batch; - DataStream ssKey{}; - DataStream ssValue{}; + DataStream m_key_scratch{}; + DataStream m_value_scratch{}; - void WriteImpl(std::span key, DataStream& ssValue); + void WriteImpl(std::span key, DataStream& value); void EraseImpl(std::span key); public: @@ -96,22 +96,18 @@ public: template void Write(const K& key, const V& value) { - ssKey.reserve(DBWRAPPER_PREALLOC_KEY_SIZE); - ssValue.reserve(DBWRAPPER_PREALLOC_VALUE_SIZE); - ssKey << key; - ssValue << value; - WriteImpl(ssKey, ssValue); - ssKey.clear(); - ssValue.clear(); + ScopedDataStreamUsage scoped_key{m_key_scratch}, scoped_value{m_value_scratch}; + m_key_scratch << key; + m_value_scratch << value; + WriteImpl(m_key_scratch, m_value_scratch); } template void Erase(const K& key) { - ssKey.reserve(DBWRAPPER_PREALLOC_KEY_SIZE); - ssKey << key; - EraseImpl(ssKey); - ssKey.clear(); + ScopedDataStreamUsage scoped_key{m_key_scratch}; + m_key_scratch << key; + EraseImpl(m_key_scratch); } size_t ApproximateSize() const; @@ -125,6 +121,7 @@ public: private: const CDBWrapper &parent; const std::unique_ptr m_impl_iter; + DataStream m_scratch{}; void SeekImpl(std::span key); std::span GetKeyImpl() const; @@ -144,10 +141,9 @@ public: void SeekToFirst(); template void Seek(const K& key) { - DataStream ssKey{}; - ssKey.reserve(DBWRAPPER_PREALLOC_KEY_SIZE); - ssKey << key; - SeekImpl(ssKey); + ScopedDataStreamUsage scoped_scratch{m_scratch}; + m_scratch << key; + SeekImpl(m_scratch); } void Next(); @@ -164,9 +160,10 @@ public: template bool GetValue(V& value) { try { - DataStream ssValue{GetValueImpl()}; - dbwrapper_private::GetObfuscation(parent)(ssValue); - ssValue >> value; + ScopedDataStreamUsage scoped_scratch{m_scratch}; + m_scratch.write(GetValueImpl()); + dbwrapper_private::GetObfuscation(parent)(m_scratch); + m_scratch >> value; } catch (const std::exception&) { return false; } diff --git a/src/streams.h b/src/streams.h index 96cea55eb49..e939dd189f7 100644 --- a/src/streams.h +++ b/src/streams.h @@ -265,6 +265,20 @@ public: size_t GetMemoryUsage() const noexcept; }; +// Require empty scratch streams on entry and reset them on exit. +class ScopedDataStreamUsage +{ + DataStream& m_stream; + +public: + explicit ScopedDataStreamUsage(DataStream& stream) : m_stream{stream} { assert(m_stream.empty()); } + + ScopedDataStreamUsage(const ScopedDataStreamUsage&) = delete; + ScopedDataStreamUsage& operator=(const ScopedDataStreamUsage&) = delete; + + ~ScopedDataStreamUsage() { m_stream.clear(); } +}; + template class BitStreamReader { diff --git a/src/test/dbwrapper_tests.cpp b/src/test/dbwrapper_tests.cpp index a57b77e3aeb..185bf491e57 100644 --- a/src/test/dbwrapper_tests.cpp +++ b/src/test/dbwrapper_tests.cpp @@ -184,6 +184,13 @@ BOOST_AUTO_TEST_CASE(dbwrapper_batch) // key3 should've never been written BOOST_CHECK(dbw.Read(key3, res) == false); + + batch.Clear(); + batch.Write(key3, in3); + dbw.WriteBatch(batch); + + BOOST_CHECK(dbw.Read(key3, res)); + BOOST_CHECK_EQUAL(res.ToString(), in3.ToString()); } } @@ -212,18 +219,36 @@ BOOST_AUTO_TEST_CASE(dbwrapper_iterator) BOOST_CHECK(!it->GetKey(key_too_large)); uint8_t key_res; - uint256 val_res; BOOST_REQUIRE(it->GetKey(key_res)); - BOOST_REQUIRE(it->GetValue(val_res)); BOOST_CHECK_EQUAL(key_res, key); + // A failed value decode must not leave the iterator's scratch stream dirty. + std::pair value_too_large; + BOOST_CHECK(!it->GetValue(value_too_large)); + + uint256 val_res; + BOOST_REQUIRE(it->GetValue(val_res)); + BOOST_CHECK_EQUAL(val_res.ToString(), in.ToString()); + + it->Seek(key2); + + BOOST_REQUIRE(it->GetKey(key_res)); + BOOST_CHECK_EQUAL(key_res, key2); + BOOST_REQUIRE(it->GetValue(val_res)); + BOOST_CHECK_EQUAL(val_res.ToString(), in2.ToString()); + + it->Seek(key); + + BOOST_REQUIRE(it->GetKey(key_res)); + BOOST_CHECK_EQUAL(key_res, key); + BOOST_REQUIRE(it->GetValue(val_res)); BOOST_CHECK_EQUAL(val_res.ToString(), in.ToString()); it->Next(); BOOST_REQUIRE(it->GetKey(key_res)); - BOOST_REQUIRE(it->GetValue(val_res)); BOOST_CHECK_EQUAL(key_res, key2); + BOOST_REQUIRE(it->GetValue(val_res)); BOOST_CHECK_EQUAL(val_res.ToString(), in2.ToString()); it->Next(); diff --git a/src/test/streams_tests.cpp b/src/test/streams_tests.cpp index 6a6026bf410..6a3be3cce86 100644 --- a/src/test/streams_tests.cpp +++ b/src/test/streams_tests.cpp @@ -88,6 +88,24 @@ BOOST_AUTO_TEST_CASE(obfuscation_empty) BOOST_CHECK(non_null_obf); } +BOOST_AUTO_TEST_CASE(streams_scoped_data_stream_usage) +{ + DataStream stream{}; + { + ScopedDataStreamUsage usage{stream}; + stream << uint8_t{42}; + BOOST_CHECK_GT(stream.size(), 0U); + } + BOOST_CHECK(stream.empty()); + + { + ScopedDataStreamUsage usage{stream}; + stream << uint16_t{42}; + BOOST_CHECK_GT(stream.size(), 0U); + } + BOOST_CHECK(stream.empty()); +} + BOOST_AUTO_TEST_CASE(xor_file) { fs::path xor_path{m_args.GetDataDirBase() / "test_xor.bin"};