net: use Sock in CConnman::ListenSocket

Change `CConnman::ListenSocket` to use a pointer to `Sock` instead of a
bare `SOCKET` and use `Sock::Accept()` instead of bare `accept()`. This
will help mocking / testing / fuzzing more code.
This commit is contained in:
Vasil Dimov
2021-04-23 12:15:15 +02:00
parent f8bd13f85a
commit 9e3cbfca7c
2 changed files with 14 additions and 18 deletions

View File

@@ -1098,10 +1098,10 @@ bool CConnman::AttemptToEvictConnection()
void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { void CConnman::AcceptConnection(const ListenSocket& hListenSocket) {
struct sockaddr_storage sockaddr; struct sockaddr_storage sockaddr;
socklen_t len = sizeof(sockaddr); socklen_t len = sizeof(sockaddr);
SOCKET hSocket = accept(hListenSocket.socket, (struct sockaddr*)&sockaddr, &len); auto sock = hListenSocket.sock->Accept((struct sockaddr*)&sockaddr, &len);
CAddress addr; CAddress addr;
if (hSocket == INVALID_SOCKET) { if (!sock) {
const int nErr = WSAGetLastError(); const int nErr = WSAGetLastError();
if (nErr != WSAEWOULDBLOCK) { if (nErr != WSAEWOULDBLOCK) {
LogPrintf("socket error accept failed: %s\n", NetworkErrorString(nErr)); LogPrintf("socket error accept failed: %s\n", NetworkErrorString(nErr));
@@ -1115,12 +1115,12 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) {
addr = CAddress{MaybeFlipIPv6toCJDNS(addr), NODE_NONE}; addr = CAddress{MaybeFlipIPv6toCJDNS(addr), NODE_NONE};
} }
const CAddress addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(hSocket)), NODE_NONE}; const CAddress addr_bind{MaybeFlipIPv6toCJDNS(GetBindAddress(sock->Get())), NODE_NONE};
NetPermissionFlags permissionFlags = NetPermissionFlags::None; NetPermissionFlags permissionFlags = NetPermissionFlags::None;
hListenSocket.AddSocketPermissionFlags(permissionFlags); hListenSocket.AddSocketPermissionFlags(permissionFlags);
CreateNodeFromAcceptedSocket(hSocket, permissionFlags, addr_bind, addr); CreateNodeFromAcceptedSocket(sock->Release(), permissionFlags, addr_bind, addr);
} }
void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket, void CConnman::CreateNodeFromAcceptedSocket(SOCKET hSocket,
@@ -1359,7 +1359,7 @@ bool CConnman::GenerateSelectSet(const std::vector<CNode*>& nodes,
std::set<SOCKET>& error_set) std::set<SOCKET>& error_set)
{ {
for (const ListenSocket& hListenSocket : vhListenSocket) { for (const ListenSocket& hListenSocket : vhListenSocket) {
recv_set.insert(hListenSocket.socket); recv_set.insert(hListenSocket.sock->Get());
} }
for (CNode* pnode : nodes) { for (CNode* pnode : nodes) {
@@ -1640,7 +1640,7 @@ void CConnman::SocketHandlerListening(const std::set<SOCKET>& recv_set)
if (interruptNet) { if (interruptNet) {
return; return;
} }
if (recv_set.count(listen_socket.socket) > 0) { if (recv_set.count(listen_socket.sock->Get()) > 0) {
AcceptConnection(listen_socket); AcceptConnection(listen_socket);
} }
} }
@@ -2391,7 +2391,7 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError,
return false; return false;
} }
vhListenSocket.push_back(ListenSocket(sock->Release(), permissions)); vhListenSocket.emplace_back(std::move(sock), permissions);
return true; return true;
} }
@@ -2700,15 +2700,6 @@ void CConnman::StopNodes()
DeleteNode(pnode); DeleteNode(pnode);
} }
// Close listening sockets.
for (ListenSocket& hListenSocket : vhListenSocket) {
if (hListenSocket.socket != INVALID_SOCKET) {
if (!CloseSocket(hListenSocket.socket)) {
LogPrintf("CloseSocket(hListenSocket) failed with error %s\n", NetworkErrorString(WSAGetLastError()));
}
}
}
for (CNode* pnode : m_nodes_disconnected) { for (CNode* pnode : m_nodes_disconnected) {
DeleteNode(pnode); DeleteNode(pnode);
} }

View File

@@ -26,6 +26,7 @@
#include <threadinterrupt.h> #include <threadinterrupt.h>
#include <uint256.h> #include <uint256.h>
#include <util/check.h> #include <util/check.h>
#include <util/sock.h>
#include <atomic> #include <atomic>
#include <condition_variable> #include <condition_variable>
@@ -947,9 +948,13 @@ public:
private: private:
struct ListenSocket { struct ListenSocket {
public: public:
SOCKET socket; std::shared_ptr<Sock> sock;
inline void AddSocketPermissionFlags(NetPermissionFlags& flags) const { NetPermissions::AddFlag(flags, m_permissions); } inline void AddSocketPermissionFlags(NetPermissionFlags& flags) const { NetPermissions::AddFlag(flags, m_permissions); }
ListenSocket(SOCKET socket_, NetPermissionFlags permissions_) : socket(socket_), m_permissions(permissions_) {} ListenSocket(std::shared_ptr<Sock> sock_, NetPermissionFlags permissions_)
: sock{sock_}, m_permissions{permissions_}
{
}
private: private:
NetPermissionFlags m_permissions; NetPermissionFlags m_permissions;
}; };