Move checksum check from net_processing to net

This removes the m_valid_checksum member from CNetMessage.  Instead,
GetMessage() returns an Optional.

Additionally, GetMessage() has been given an out parameter to be used to
hold error information.  For now it is specifically a uint32_t used to
hold the raw size of the corrupt message.

The checksum check is now done in GetMessage.
This commit is contained in:
Troy Giorshev
2020-06-29 14:15:06 -04:00
parent 2716647ebf
commit 890b1d7c2b
4 changed files with 45 additions and 44 deletions

View File

@@ -595,25 +595,33 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete
while (nBytes > 0) { while (nBytes > 0) {
// absorb network data // absorb network data
int handled = m_deserializer->Read(pch, nBytes); int handled = m_deserializer->Read(pch, nBytes);
if (handled < 0) return false; if (handled < 0) {
return false;
}
pch += handled; pch += handled;
nBytes -= handled; nBytes -= handled;
if (m_deserializer->Complete()) { if (m_deserializer->Complete()) {
// decompose a transport agnostic CNetMessage from the deserializer // decompose a transport agnostic CNetMessage from the deserializer
CNetMessage msg = m_deserializer->GetMessage(Params().MessageStart(), time); uint32_t out_err_raw_size{0};
Optional<CNetMessage> result{m_deserializer->GetMessage(Params().MessageStart(), time, out_err_raw_size)};
if (!result) {
// store the size of the corrupt message
mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER)->second += out_err_raw_size;
continue;
}
//store received bytes per message command //store received bytes per message command
//to prevent a memory DOS, only allow valid commands //to prevent a memory DOS, only allow valid commands
mapMsgCmdSize::iterator i = mapRecvBytesPerMsgCmd.find(msg.m_command); mapMsgCmdSize::iterator i = mapRecvBytesPerMsgCmd.find(result->m_command);
if (i == mapRecvBytesPerMsgCmd.end()) if (i == mapRecvBytesPerMsgCmd.end())
i = mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER); i = mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER);
assert(i != mapRecvBytesPerMsgCmd.end()); assert(i != mapRecvBytesPerMsgCmd.end());
i->second += msg.m_raw_message_size; i->second += result->m_raw_message_size;
// push the message to the process queue, // push the message to the process queue,
vRecvMsg.push_back(std::move(msg)); vRecvMsg.push_back(std::move(*result));
complete = true; complete = true;
} }
@@ -679,37 +687,36 @@ const uint256& V1TransportDeserializer::GetMessageHash() const
return data_hash; return data_hash;
} }
CNetMessage V1TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, const std::chrono::microseconds time) Optional<CNetMessage> V1TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, const std::chrono::microseconds time, uint32_t& out_err_raw_size)
{ {
// decompose a single CNetMessage from the TransportDeserializer // decompose a single CNetMessage from the TransportDeserializer
CNetMessage msg(std::move(vRecv)); Optional<CNetMessage> msg(std::move(vRecv));
// store state about valid header, netmagic and checksum // store state about valid header, netmagic and checksum
msg.m_valid_header = hdr.IsValid(message_start); msg->m_valid_header = hdr.IsValid(message_start);
msg.m_valid_netmagic = (memcmp(hdr.pchMessageStart, message_start, CMessageHeader::MESSAGE_START_SIZE) == 0); msg->m_valid_netmagic = (memcmp(hdr.pchMessageStart, message_start, CMessageHeader::MESSAGE_START_SIZE) == 0);
uint256 hash = GetMessageHash(); uint256 hash = GetMessageHash();
// store command string, payload size // store command string, time, and sizes
msg.m_command = hdr.GetCommand(); msg->m_command = hdr.GetCommand();
msg.m_message_size = hdr.nMessageSize; msg->m_time = time;
msg.m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE; msg->m_message_size = hdr.nMessageSize;
msg->m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE;
// We just received a message off the wire, harvest entropy from the time (and the message checksum) // We just received a message off the wire, harvest entropy from the time (and the message checksum)
RandAddEvent(ReadLE32(hash.begin())); RandAddEvent(ReadLE32(hash.begin()));
msg.m_valid_checksum = (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) == 0); if (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) != 0) {
if (!msg.m_valid_checksum) {
LogPrint(BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s, peer=%d\n", LogPrint(BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s, peer=%d\n",
SanitizeString(msg.m_command), msg.m_message_size, SanitizeString(msg->m_command), msg->m_message_size,
HexStr(Span<uint8_t>(hash.begin(), hash.begin() + CMessageHeader::CHECKSUM_SIZE)), HexStr(Span<uint8_t>(hash.begin(), hash.begin() + CMessageHeader::CHECKSUM_SIZE)),
HexStr(hdr.pchChecksum), HexStr(hdr.pchChecksum),
m_node_id); m_node_id);
out_err_raw_size = msg->m_raw_message_size;
msg = nullopt;
} }
// store receive time // Always reset the network deserializer (prepare for the next message)
msg.m_time = time;
// reset the network deserializer (prepare for the next message)
Reset(); Reset();
return msg; return msg;
} }

View File

@@ -14,8 +14,9 @@
#include <crypto/siphash.h> #include <crypto/siphash.h>
#include <hash.h> #include <hash.h>
#include <limitedmap.h> #include <limitedmap.h>
#include <netaddress.h>
#include <net_permissions.h> #include <net_permissions.h>
#include <netaddress.h>
#include <optional.h>
#include <policy/feerate.h> #include <policy/feerate.h>
#include <protocol.h> #include <protocol.h>
#include <random.h> #include <random.h>
@@ -706,7 +707,6 @@ public:
std::chrono::microseconds m_time{0}; //!< time of message receipt std::chrono::microseconds m_time{0}; //!< time of message receipt
bool m_valid_netmagic = false; bool m_valid_netmagic = false;
bool m_valid_header = false; bool m_valid_header = false;
bool m_valid_checksum = false;
uint32_t m_message_size{0}; //!< size of the payload uint32_t m_message_size{0}; //!< size of the payload
uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum) uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum)
std::string m_command; std::string m_command;
@@ -732,7 +732,7 @@ public:
// read and deserialize data // read and deserialize data
virtual int Read(const char *data, unsigned int bytes) = 0; virtual int Read(const char *data, unsigned int bytes) = 0;
// decomposes a message from the context // decomposes a message from the context
virtual CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time) = 0; virtual Optional<CNetMessage> GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time, uint32_t& out_err) = 0;
virtual ~TransportDeserializer() {} virtual ~TransportDeserializer() {}
}; };
@@ -790,7 +790,7 @@ public:
if (ret < 0) Reset(); if (ret < 0) Reset();
return ret; return ret;
} }
CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time) override; Optional<CNetMessage> GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time, uint32_t& out_err_raw_size) override;
}; };
/** The TransportSerializer prepares messages for the network transport /** The TransportSerializer prepares messages for the network transport

View File

@@ -3886,17 +3886,8 @@ bool PeerManager::ProcessMessages(CNode* pfrom, std::atomic<bool>& interruptMsgP
// Message size // Message size
unsigned int nMessageSize = msg.m_message_size; unsigned int nMessageSize = msg.m_message_size;
// Checksum
CDataStream& vRecv = msg.m_recv;
if (!msg.m_valid_checksum)
{
LogPrint(BCLog::NET, "%s(%s, %u bytes): CHECKSUM ERROR peer=%d\n", __func__,
SanitizeString(msg_type), nMessageSize, pfrom->GetId());
return fMoreWork;
}
try { try {
ProcessMessage(*pfrom, msg_type, vRecv, msg.m_time, interruptMsgProc); ProcessMessage(*pfrom, msg_type, msg.m_recv, msg.m_time, interruptMsgProc);
if (interruptMsgProc) if (interruptMsgProc)
return false; return false;
if (!pfrom->vRecvGetData.empty()) if (!pfrom->vRecvGetData.empty())

View File

@@ -32,16 +32,19 @@ void test_one_input(const std::vector<uint8_t>& buffer)
n_bytes -= handled; n_bytes -= handled;
if (deserializer.Complete()) { if (deserializer.Complete()) {
const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()}; const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()};
const CNetMessage msg = deserializer.GetMessage(Params().MessageStart(), m_time); uint32_t out_err_raw_size{0};
assert(msg.m_command.size() <= CMessageHeader::COMMAND_SIZE); Optional<CNetMessage> result{deserializer.GetMessage(Params().MessageStart(), m_time, out_err_raw_size)};
assert(msg.m_raw_message_size <= buffer.size()); if (result) {
assert(msg.m_raw_message_size == CMessageHeader::HEADER_SIZE + msg.m_message_size); assert(result->m_command.size() <= CMessageHeader::COMMAND_SIZE);
assert(msg.m_time == m_time); assert(result->m_raw_message_size <= buffer.size());
if (msg.m_valid_header) { assert(result->m_raw_message_size == CMessageHeader::HEADER_SIZE + result->m_message_size);
assert(msg.m_valid_netmagic); assert(result->m_time == m_time);
if (result->m_valid_header) {
assert(result->m_valid_netmagic);
}
if (!result->m_valid_netmagic) {
assert(!result->m_valid_header);
} }
if (!msg.m_valid_netmagic) {
assert(!msg.m_valid_header);
} }
} }
} }