mirror of
https://github.com/bitcoin/bitcoin.git
synced 2025-12-18 16:42:54 +01:00
Move checksum check from net_processing to net
This removes the m_valid_checksum member from CNetMessage. Instead, GetMessage() returns an Optional. Additionally, GetMessage() has been given an out parameter to be used to hold error information. For now it is specifically a uint32_t used to hold the raw size of the corrupt message. The checksum check is now done in GetMessage.
This commit is contained in:
47
src/net.cpp
47
src/net.cpp
@@ -595,25 +595,33 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete
|
|||||||
while (nBytes > 0) {
|
while (nBytes > 0) {
|
||||||
// absorb network data
|
// absorb network data
|
||||||
int handled = m_deserializer->Read(pch, nBytes);
|
int handled = m_deserializer->Read(pch, nBytes);
|
||||||
if (handled < 0) return false;
|
if (handled < 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
pch += handled;
|
pch += handled;
|
||||||
nBytes -= handled;
|
nBytes -= handled;
|
||||||
|
|
||||||
if (m_deserializer->Complete()) {
|
if (m_deserializer->Complete()) {
|
||||||
// decompose a transport agnostic CNetMessage from the deserializer
|
// decompose a transport agnostic CNetMessage from the deserializer
|
||||||
CNetMessage msg = m_deserializer->GetMessage(Params().MessageStart(), time);
|
uint32_t out_err_raw_size{0};
|
||||||
|
Optional<CNetMessage> result{m_deserializer->GetMessage(Params().MessageStart(), time, out_err_raw_size)};
|
||||||
|
if (!result) {
|
||||||
|
// store the size of the corrupt message
|
||||||
|
mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER)->second += out_err_raw_size;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
//store received bytes per message command
|
//store received bytes per message command
|
||||||
//to prevent a memory DOS, only allow valid commands
|
//to prevent a memory DOS, only allow valid commands
|
||||||
mapMsgCmdSize::iterator i = mapRecvBytesPerMsgCmd.find(msg.m_command);
|
mapMsgCmdSize::iterator i = mapRecvBytesPerMsgCmd.find(result->m_command);
|
||||||
if (i == mapRecvBytesPerMsgCmd.end())
|
if (i == mapRecvBytesPerMsgCmd.end())
|
||||||
i = mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER);
|
i = mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER);
|
||||||
assert(i != mapRecvBytesPerMsgCmd.end());
|
assert(i != mapRecvBytesPerMsgCmd.end());
|
||||||
i->second += msg.m_raw_message_size;
|
i->second += result->m_raw_message_size;
|
||||||
|
|
||||||
// push the message to the process queue,
|
// push the message to the process queue,
|
||||||
vRecvMsg.push_back(std::move(msg));
|
vRecvMsg.push_back(std::move(*result));
|
||||||
|
|
||||||
complete = true;
|
complete = true;
|
||||||
}
|
}
|
||||||
@@ -679,37 +687,36 @@ const uint256& V1TransportDeserializer::GetMessageHash() const
|
|||||||
return data_hash;
|
return data_hash;
|
||||||
}
|
}
|
||||||
|
|
||||||
CNetMessage V1TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, const std::chrono::microseconds time)
|
Optional<CNetMessage> V1TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, const std::chrono::microseconds time, uint32_t& out_err_raw_size)
|
||||||
{
|
{
|
||||||
// decompose a single CNetMessage from the TransportDeserializer
|
// decompose a single CNetMessage from the TransportDeserializer
|
||||||
CNetMessage msg(std::move(vRecv));
|
Optional<CNetMessage> msg(std::move(vRecv));
|
||||||
|
|
||||||
// store state about valid header, netmagic and checksum
|
// store state about valid header, netmagic and checksum
|
||||||
msg.m_valid_header = hdr.IsValid(message_start);
|
msg->m_valid_header = hdr.IsValid(message_start);
|
||||||
msg.m_valid_netmagic = (memcmp(hdr.pchMessageStart, message_start, CMessageHeader::MESSAGE_START_SIZE) == 0);
|
msg->m_valid_netmagic = (memcmp(hdr.pchMessageStart, message_start, CMessageHeader::MESSAGE_START_SIZE) == 0);
|
||||||
uint256 hash = GetMessageHash();
|
uint256 hash = GetMessageHash();
|
||||||
|
|
||||||
// store command string, payload size
|
// store command string, time, and sizes
|
||||||
msg.m_command = hdr.GetCommand();
|
msg->m_command = hdr.GetCommand();
|
||||||
msg.m_message_size = hdr.nMessageSize;
|
msg->m_time = time;
|
||||||
msg.m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE;
|
msg->m_message_size = hdr.nMessageSize;
|
||||||
|
msg->m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE;
|
||||||
|
|
||||||
// We just received a message off the wire, harvest entropy from the time (and the message checksum)
|
// We just received a message off the wire, harvest entropy from the time (and the message checksum)
|
||||||
RandAddEvent(ReadLE32(hash.begin()));
|
RandAddEvent(ReadLE32(hash.begin()));
|
||||||
|
|
||||||
msg.m_valid_checksum = (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) == 0);
|
if (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) != 0) {
|
||||||
if (!msg.m_valid_checksum) {
|
|
||||||
LogPrint(BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s, peer=%d\n",
|
LogPrint(BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s, peer=%d\n",
|
||||||
SanitizeString(msg.m_command), msg.m_message_size,
|
SanitizeString(msg->m_command), msg->m_message_size,
|
||||||
HexStr(Span<uint8_t>(hash.begin(), hash.begin() + CMessageHeader::CHECKSUM_SIZE)),
|
HexStr(Span<uint8_t>(hash.begin(), hash.begin() + CMessageHeader::CHECKSUM_SIZE)),
|
||||||
HexStr(hdr.pchChecksum),
|
HexStr(hdr.pchChecksum),
|
||||||
m_node_id);
|
m_node_id);
|
||||||
|
out_err_raw_size = msg->m_raw_message_size;
|
||||||
|
msg = nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
// store receive time
|
// Always reset the network deserializer (prepare for the next message)
|
||||||
msg.m_time = time;
|
|
||||||
|
|
||||||
// reset the network deserializer (prepare for the next message)
|
|
||||||
Reset();
|
Reset();
|
||||||
return msg;
|
return msg;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,8 +14,9 @@
|
|||||||
#include <crypto/siphash.h>
|
#include <crypto/siphash.h>
|
||||||
#include <hash.h>
|
#include <hash.h>
|
||||||
#include <limitedmap.h>
|
#include <limitedmap.h>
|
||||||
#include <netaddress.h>
|
|
||||||
#include <net_permissions.h>
|
#include <net_permissions.h>
|
||||||
|
#include <netaddress.h>
|
||||||
|
#include <optional.h>
|
||||||
#include <policy/feerate.h>
|
#include <policy/feerate.h>
|
||||||
#include <protocol.h>
|
#include <protocol.h>
|
||||||
#include <random.h>
|
#include <random.h>
|
||||||
@@ -706,7 +707,6 @@ public:
|
|||||||
std::chrono::microseconds m_time{0}; //!< time of message receipt
|
std::chrono::microseconds m_time{0}; //!< time of message receipt
|
||||||
bool m_valid_netmagic = false;
|
bool m_valid_netmagic = false;
|
||||||
bool m_valid_header = false;
|
bool m_valid_header = false;
|
||||||
bool m_valid_checksum = false;
|
|
||||||
uint32_t m_message_size{0}; //!< size of the payload
|
uint32_t m_message_size{0}; //!< size of the payload
|
||||||
uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum)
|
uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum)
|
||||||
std::string m_command;
|
std::string m_command;
|
||||||
@@ -732,7 +732,7 @@ public:
|
|||||||
// read and deserialize data
|
// read and deserialize data
|
||||||
virtual int Read(const char *data, unsigned int bytes) = 0;
|
virtual int Read(const char *data, unsigned int bytes) = 0;
|
||||||
// decomposes a message from the context
|
// decomposes a message from the context
|
||||||
virtual CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time) = 0;
|
virtual Optional<CNetMessage> GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time, uint32_t& out_err) = 0;
|
||||||
virtual ~TransportDeserializer() {}
|
virtual ~TransportDeserializer() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -790,7 +790,7 @@ public:
|
|||||||
if (ret < 0) Reset();
|
if (ret < 0) Reset();
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time) override;
|
Optional<CNetMessage> GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time, uint32_t& out_err_raw_size) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
/** The TransportSerializer prepares messages for the network transport
|
/** The TransportSerializer prepares messages for the network transport
|
||||||
|
|||||||
@@ -3886,17 +3886,8 @@ bool PeerManager::ProcessMessages(CNode* pfrom, std::atomic<bool>& interruptMsgP
|
|||||||
// Message size
|
// Message size
|
||||||
unsigned int nMessageSize = msg.m_message_size;
|
unsigned int nMessageSize = msg.m_message_size;
|
||||||
|
|
||||||
// Checksum
|
|
||||||
CDataStream& vRecv = msg.m_recv;
|
|
||||||
if (!msg.m_valid_checksum)
|
|
||||||
{
|
|
||||||
LogPrint(BCLog::NET, "%s(%s, %u bytes): CHECKSUM ERROR peer=%d\n", __func__,
|
|
||||||
SanitizeString(msg_type), nMessageSize, pfrom->GetId());
|
|
||||||
return fMoreWork;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
ProcessMessage(*pfrom, msg_type, vRecv, msg.m_time, interruptMsgProc);
|
ProcessMessage(*pfrom, msg_type, msg.m_recv, msg.m_time, interruptMsgProc);
|
||||||
if (interruptMsgProc)
|
if (interruptMsgProc)
|
||||||
return false;
|
return false;
|
||||||
if (!pfrom->vRecvGetData.empty())
|
if (!pfrom->vRecvGetData.empty())
|
||||||
|
|||||||
@@ -32,16 +32,19 @@ void test_one_input(const std::vector<uint8_t>& buffer)
|
|||||||
n_bytes -= handled;
|
n_bytes -= handled;
|
||||||
if (deserializer.Complete()) {
|
if (deserializer.Complete()) {
|
||||||
const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()};
|
const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()};
|
||||||
const CNetMessage msg = deserializer.GetMessage(Params().MessageStart(), m_time);
|
uint32_t out_err_raw_size{0};
|
||||||
assert(msg.m_command.size() <= CMessageHeader::COMMAND_SIZE);
|
Optional<CNetMessage> result{deserializer.GetMessage(Params().MessageStart(), m_time, out_err_raw_size)};
|
||||||
assert(msg.m_raw_message_size <= buffer.size());
|
if (result) {
|
||||||
assert(msg.m_raw_message_size == CMessageHeader::HEADER_SIZE + msg.m_message_size);
|
assert(result->m_command.size() <= CMessageHeader::COMMAND_SIZE);
|
||||||
assert(msg.m_time == m_time);
|
assert(result->m_raw_message_size <= buffer.size());
|
||||||
if (msg.m_valid_header) {
|
assert(result->m_raw_message_size == CMessageHeader::HEADER_SIZE + result->m_message_size);
|
||||||
assert(msg.m_valid_netmagic);
|
assert(result->m_time == m_time);
|
||||||
|
if (result->m_valid_header) {
|
||||||
|
assert(result->m_valid_netmagic);
|
||||||
|
}
|
||||||
|
if (!result->m_valid_netmagic) {
|
||||||
|
assert(!result->m_valid_header);
|
||||||
}
|
}
|
||||||
if (!msg.m_valid_netmagic) {
|
|
||||||
assert(!msg.m_valid_header);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user