From f8bd13f85ae5404adef23a52719d804a5c36b1e8 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Fri, 23 Apr 2021 09:43:43 +0200 Subject: [PATCH 1/3] net: add new method Sock::Accept() that wraps accept() This will help to increase `Sock` usage and make more code mockable. --- src/test/fuzz/util.cpp | 16 ++++++++++++++++ src/test/fuzz/util.h | 2 ++ src/test/util/net.h | 18 ++++++++++++++++++ src/util/sock.cpp | 27 +++++++++++++++++++++++++++ src/util/sock.h | 9 +++++++++ 5 files changed, 72 insertions(+) diff --git a/src/test/fuzz/util.cpp b/src/test/fuzz/util.cpp index ae5f7a379e0..dae7920eff2 100644 --- a/src/test/fuzz/util.cpp +++ b/src/test/fuzz/util.cpp @@ -10,6 +10,8 @@ #include #include +#include + FuzzedSock::FuzzedSock(FuzzedDataProvider& fuzzed_data_provider) : m_fuzzed_data_provider{fuzzed_data_provider} { @@ -155,6 +157,20 @@ int FuzzedSock::Connect(const sockaddr*, socklen_t) const return 0; } +std::unique_ptr FuzzedSock::Accept(sockaddr* addr, socklen_t* addr_len) const +{ + constexpr std::array accept_errnos{ + ECONNABORTED, + EINTR, + ENOMEM, + }; + if (m_fuzzed_data_provider.ConsumeBool()) { + SetFuzzedErrNo(m_fuzzed_data_provider, accept_errnos); + return std::unique_ptr(); + } + return std::make_unique(m_fuzzed_data_provider); +} + int FuzzedSock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const { constexpr std::array getsockopt_errnos{ diff --git a/src/test/fuzz/util.h b/src/test/fuzz/util.h index 40aaeac63f2..fb42dcd0fc4 100644 --- a/src/test/fuzz/util.h +++ b/src/test/fuzz/util.h @@ -410,6 +410,8 @@ public: int Connect(const sockaddr*, socklen_t) const override; + std::unique_ptr Accept(sockaddr* addr, socklen_t* addr_len) const override; + int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override; bool Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred = nullptr) const override; diff --git a/src/test/util/net.h b/src/test/util/net.h index 2de6e712a25..3ef18400767 100644 --- a/src/test/util/net.h +++ b/src/test/util/net.h @@ -13,6 +13,7 @@ #include #include #include +#include #include struct ConnmanTestMsg : public CConnman { @@ -126,6 +127,23 @@ public: int Connect(const sockaddr*, socklen_t) const override { return 0; } + std::unique_ptr Accept(sockaddr* addr, socklen_t* addr_len) const override + { + if (addr != nullptr) { + // Pretend all connections come from 5.5.5.5:6789 + memset(addr, 0x00, *addr_len); + const socklen_t write_len = static_cast(sizeof(sockaddr_in)); + if (*addr_len >= write_len) { + *addr_len = write_len; + sockaddr_in* addr_in = reinterpret_cast(addr); + addr_in->sin_family = AF_INET; + memset(&addr_in->sin_addr, 0x05, sizeof(addr_in->sin_addr)); + addr_in->sin_port = htons(6789); + } + } + return std::make_unique(""); + }; + int GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const override { std::memset(opt_val, 0x0, *opt_len); diff --git a/src/util/sock.cpp b/src/util/sock.cpp index 1a4d67a65ee..2029d70a374 100644 --- a/src/util/sock.cpp +++ b/src/util/sock.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -73,6 +74,32 @@ int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const return connect(m_socket, addr, addr_len); } +std::unique_ptr Sock::Accept(sockaddr* addr, socklen_t* addr_len) const +{ +#ifdef WIN32 + static constexpr auto ERR = INVALID_SOCKET; +#else + static constexpr auto ERR = SOCKET_ERROR; +#endif + + std::unique_ptr sock; + + const auto socket = accept(m_socket, addr, addr_len); + if (socket != ERR) { + try { + sock = std::make_unique(socket); + } catch (const std::exception&) { +#ifdef WIN32 + closesocket(socket); +#else + close(socket); +#endif + } + } + + return sock; +} + int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const { return getsockopt(m_socket, level, opt_name, static_cast(opt_val), opt_len); diff --git a/src/util/sock.h b/src/util/sock.h index 59cc8c0b1df..75104828578 100644 --- a/src/util/sock.h +++ b/src/util/sock.h @@ -10,6 +10,7 @@ #include #include +#include #include /** @@ -96,6 +97,14 @@ public: */ [[nodiscard]] virtual int Connect(const sockaddr* addr, socklen_t addr_len) const; + /** + * accept(2) wrapper. Equivalent to `std::make_unique(accept(this->Get(), addr, addr_len))`. + * Code that uses this wrapper can be unit tested if this method is overridden by a mock Sock + * implementation. + * The returned unique_ptr is empty if `accept()` failed in which case errno will be set. + */ + [[nodiscard]] virtual std::unique_ptr Accept(sockaddr* addr, socklen_t* addr_len) const; + /** * getsockopt(2) wrapper. Equivalent to * `getsockopt(this->Get(), level, opt_name, opt_val, opt_len)`. Code that uses this From 9e3cbfca7c9efa620c0cce73503772805cc1fa82 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Fri, 23 Apr 2021 12:15:15 +0200 Subject: [PATCH 2/3] net: use Sock in CConnman::ListenSocket Change `CConnman::ListenSocket` to use a pointer to `Sock` instead of a bare `SOCKET` and use `Sock::Accept()` instead of bare `accept()`. This will help mocking / testing / fuzzing more code. --- src/net.cpp | 23 +++++++---------------- src/net.h | 9 +++++++-- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/src/net.cpp b/src/net.cpp index e17db7e54b9..2ed9d82e251 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -1098,10 +1098,10 @@ bool CConnman::AttemptToEvictConnection() void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { struct sockaddr_storage sockaddr; socklen_t len = sizeof(sockaddr); - SOCKET hSocket = accept(hListenSocket.socket, (struct sockaddr*)&sockaddr, &len); + auto sock = hListenSocket.sock->Accept((struct sockaddr*)&sockaddr, &len); CAddress addr; - if (hSocket == INVALID_SOCKET) { + if (!sock) { const int nErr = WSAGetLastError(); if (nErr != WSAEWOULDBLOCK) { LogPrintf("socket error accept failed: %s\n", NetworkErrorString(nErr)); @@ -1115,12 +1115,12 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { addr = CAddress{MaybeFlipIPv6toCJDNS(addr), NODE_NONE}; } - const CAddress addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(hSocket)), NODE_NONE}; + const CAddress addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(sock->Get())), NODE_NONE}; NetPermissionFlags permissionFlags = NetPermissionFlags::None; hListenSocket.AddSocketPermissionFlags(permissionFlags); - CreateNodeFromAcceptedSocket(hSocket, permissionFlags, addr_bind, addr); + CreateNodeFromAcceptedSocket(sock->Release(), permissionFlags, addr_bind, addr); } void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket, @@ -1359,7 +1359,7 @@ bool CConnman::GenerateSelectSet(const std::vector& nodes, std::set& error_set) { for (const ListenSocket& hListenSocket : vhListenSocket) { - recv_set.insert(hListenSocket.socket); + recv_set.insert(hListenSocket.sock->Get()); } for (CNode* pnode : nodes) { @@ -1640,7 +1640,7 @@ void CConnman::SocketHandlerListening(const std::set& recv_set) if (interruptNet) { return; } - if (recv_set.count(listen_socket.socket) > 0) { + if (recv_set.count(listen_socket.sock->Get()) > 0) { AcceptConnection(listen_socket); } } @@ -2391,7 +2391,7 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, return false; } - vhListenSocket.push_back(ListenSocket(sock->Release(), permissions)); + vhListenSocket.emplace_back(std::move(sock), permissions); return true; } @@ -2700,15 +2700,6 @@ void CConnman::StopNodes() DeleteNode(pnode); } - // Close listening sockets. - for (ListenSocket& hListenSocket : vhListenSocket) { - if (hListenSocket.socket != INVALID_SOCKET) { - if (!CloseSocket(hListenSocket.socket)) { - LogPrintf("CloseSocket(hListenSocket) failed with error %s\n", NetworkErrorString(WSAGetLastError())); - } - } - } - for (CNode* pnode : m_nodes_disconnected) { DeleteNode(pnode); } diff --git a/src/net.h b/src/net.h index dd5cc66a048..bb6e7a79ad5 100644 --- a/src/net.h +++ b/src/net.h @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -947,9 +948,13 @@ public: private: struct ListenSocket { public: - SOCKET socket; + std::shared_ptr sock; inline void AddSocketPermissionFlags(NetPermissionFlags& flags) const { NetPermissions::AddFlag(flags, m_permissions); } - ListenSocket(SOCKET socket_, NetPermissionFlags permissions_) : socket(socket_), m_permissions(permissions_) {} + ListenSocket(std::shared_ptr sock_, NetPermissionFlags permissions_) + : sock{sock_}, m_permissions{permissions_} + { + } + private: NetPermissionFlags m_permissions; }; From 6bf6e9fd9dece67878595a5f3361851c25833c49 Mon Sep 17 00:00:00 2001 From: Vasil Dimov Date: Tue, 13 Apr 2021 12:14:57 +0200 Subject: [PATCH 3/3] net: change CreateNodeFromAcceptedSocket() to take Sock Change `CConnman::CreateNodeFromAcceptedSocket()` to take a `Sock` argument instead of `SOCKET`. This makes the method mockable and also a little bit shorter as some `CloseSocket()` calls are removed (the socket will be closed automatically by the `Sock` destructor on early return). --- src/net.cpp | 17 ++++++----------- src/net.h | 4 ++-- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/src/net.cpp b/src/net.cpp index 2ed9d82e251..7f248e44d41 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -1120,10 +1120,10 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { NetPermissionFlags permissionFlags = NetPermissionFlags::None; hListenSocket.AddSocketPermissionFlags(permissionFlags); - CreateNodeFromAcceptedSocket(sock->Release(), permissionFlags, addr_bind, addr); + CreateNodeFromAcceptedSocket(std::move(sock), permissionFlags, addr_bind, addr); } -void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket, +void CConnman::CreateNodeFromAcceptedSocket(std::unique_ptr&& sock, NetPermissionFlags permissionFlags, const CAddress& addr_bind, const CAddress& addr) @@ -1149,27 +1149,24 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket, if (!fNetworkActive) { LogPrint(BCLog::NET, "connection from %s dropped: not accepting new connections\n", addr.ToString()); - CloseSocket(hSocket); return; } - if (!IsSelectableSocket(hSocket)) + if (!IsSelectableSocket(sock->Get())) { LogPrintf("connection from %s dropped: non-selectable socket\n", addr.ToString()); - CloseSocket(hSocket); return; } // According to the internet TCP_NODELAY is not carried into accepted sockets // on all platforms. Set it again here just to be sure. - SetSocketNoDelay(hSocket); + SetSocketNoDelay(sock->Get()); // Don't accept connections from banned peers. bool banned = m_banman && m_banman->IsBanned(addr); if (!NetPermissions::HasFlag(permissionFlags, NetPermissionFlags::NoBan) && banned) { LogPrint(BCLog::NET, "connection from %s dropped (banned)\n", addr.ToString()); - CloseSocket(hSocket); return; } @@ -1178,7 +1175,6 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket, if (!NetPermissions::HasFlag(permissionFlags, NetPermissionFlags::NoBan) && nInbound + 1 >= nMaxInbound && discouraged) { LogPrint(BCLog::NET, "connection from %s dropped (discouraged)\n", addr.ToString()); - CloseSocket(hSocket); return; } @@ -1187,7 +1183,6 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket, if (!AttemptToEvictConnection()) { // No connection to evict, disconnect the new connection LogPrint(BCLog::NET, "failed to find an eviction candidate - connection dropped (full)\n"); - CloseSocket(hSocket); return; } } @@ -1201,7 +1196,7 @@ void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket, } const bool inbound_onion = std::find(m_onion_binds.begin(), m_onion_binds.end(), addr_bind) != m_onion_binds.end(); - CNode* pnode = new CNode(id, nodeServices, hSocket, addr, CalculateKeyedNetGroup(addr), nonce, addr_bind, "", ConnectionType::INBOUND, inbound_onion); + CNode* pnode = new CNode(id, nodeServices, sock->Release(), addr, CalculateKeyedNetGroup(addr), nonce, addr_bind, "", ConnectionType::INBOUND, inbound_onion); pnode->AddRef(); pnode->m_permissionFlags = permissionFlags; pnode->m_prefer_evict = discouraged; @@ -2329,7 +2324,7 @@ void CConnman::ThreadI2PAcceptIncoming() continue; } - CreateNodeFromAcceptedSocket(conn.sock->Release(), NetPermissionFlags::None, + CreateNodeFromAcceptedSocket(std::move(conn.sock), NetPermissionFlags::None, CAddress{conn.me, NODE_NONE}, CAddress{conn.peer, NODE_NONE}); } } diff --git a/src/net.h b/src/net.h index bb6e7a79ad5..13ebb6ac3a2 100644 --- a/src/net.h +++ b/src/net.h @@ -974,12 +974,12 @@ private: /** * Create a `CNode` object from a socket that has just been accepted and add the node to * the `m_nodes` member. - * @param[in] hSocket Connected socket to communicate with the peer. + * @param[in] sock Connected socket to communicate with the peer. * @param[in] permissionFlags The peer's permissions. * @param[in] addr_bind The address and port at our side of the connection. * @param[in] addr The address and port at the peer's side of the connection. */ - void CreateNodeFromAcceptedSocket(SOCKET hSocket, + void CreateNodeFromAcceptedSocket(std::unique_ptr&& sock, NetPermissionFlags permissionFlags, const CAddress& addr_bind, const CAddress& addr);