diff --git a/src/test/fuzz/p2p_transport_serialization.cpp b/src/test/fuzz/p2p_transport_serialization.cpp index 2618a2a3986..c29d8e70396 100644 --- a/src/test/fuzz/p2p_transport_serialization.cpp +++ b/src/test/fuzz/p2p_transport_serialization.cpp @@ -3,9 +3,11 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include +#include #include #include #include +#include #include #include @@ -24,7 +26,25 @@ FUZZ_TARGET_INIT(p2p_transport_serialization, initialize_p2p_transport_serializa // Construct deserializer, with a dummy NodeId V1TransportDeserializer deserializer{Params(), (NodeId)0, SER_NETWORK, INIT_PROTO_VERSION}; V1TransportSerializer serializer{}; - Span msg_bytes{buffer}; + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; + + auto checksum_assist = fuzzed_data_provider.ConsumeBool(); + int header_random_bytes_count = checksum_assist ? CMessageHeader::CHECKSUM_OFFSET : CMessageHeader :: HEADER_SIZE; + auto mutable_msg_bytes = fuzzed_data_provider.ConsumeBytes(header_random_bytes_count); + auto payload_bytes = fuzzed_data_provider.ConsumeRemainingBytes(); + + if (checksum_assist && mutable_msg_bytes.size() == CMessageHeader::CHECKSUM_OFFSET) { + CHash256 hasher; + unsigned char hsh[32]; + hasher.Write(payload_bytes); + hasher.Finalize(hsh); + for (size_t i = 0; i < CMessageHeader::CHECKSUM_SIZE; ++i) { + mutable_msg_bytes.push_back(hsh[i]); + } + } + + mutable_msg_bytes.insert(mutable_msg_bytes.end(), payload_bytes.begin(), payload_bytes.end()); + Span msg_bytes{mutable_msg_bytes}; while (msg_bytes.size() > 0) { const int handled = deserializer.Read(msg_bytes); if (handled < 0) { @@ -36,7 +56,7 @@ FUZZ_TARGET_INIT(p2p_transport_serialization, initialize_p2p_transport_serializa std::optional result{deserializer.GetMessage(m_time, out_err_raw_size)}; if (result) { assert(result->m_command.size() <= CMessageHeader::COMMAND_SIZE); - assert(result->m_raw_message_size <= buffer.size()); + assert(result->m_raw_message_size <= mutable_msg_bytes.size()); assert(result->m_raw_message_size == CMessageHeader::HEADER_SIZE + result->m_message_size); assert(result->m_time == m_time);