Merge #20788: net: add RAII socket and use it instead of bare SOCKET

615ba0eb96 test: add Sock unit tests (Vasil Dimov)
7bd21ce1ef style: rename hSocket to sock (Vasil Dimov)
04ae846904 net: use Sock in InterruptibleRecv() and Socks5() (Vasil Dimov)
ba9d73268f net: add RAII socket and use it instead of bare SOCKET (Vasil Dimov)
dec9b5e850 net: move CloseSocket() from netbase to util/sock (Vasil Dimov)
aa17a44551 net: move MillisToTimeval() from netbase to util/time (Vasil Dimov)

Pull request description:

  Introduce a class to manage the lifetime of a socket - when the object
  that contains the socket goes out of scope, the underlying socket will
  be closed.

  In addition, the new `Sock` class has a `Send()`, `Recv()` and `Wait()`
  methods that can be overridden by unit tests to mock the socket
  operations.

  The `Wait()` method also hides the
  `#ifdef USE_POLL poll() #else select() #endif` technique from higher
  level code.

ACKs for top commit:
  laanwj:
    Re-ACK 615ba0eb96
  jonatack:
    re-ACK 615ba0eb96

Tree-SHA512: 3003e6bc0259295ca0265ccdeb1522ee25b4abe66d32e6ceaa51b55e0a999df7ddee765f86ce558a788c1953ee2009bfa149b09d494593f7d799c0d7d930bee8
This commit is contained in:
Wladimir J. van der Laan
2021-02-11 13:21:31 +01:00
11 changed files with 527 additions and 147 deletions

View File

@@ -20,6 +20,7 @@
#include <protocol.h>
#include <random.h>
#include <scheduler.h>
#include <util/sock.h>
#include <util/strencodings.h>
#include <util/translation.h>
@@ -429,24 +430,26 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo
// Connect
bool connected = false;
SOCKET hSocket = INVALID_SOCKET;
std::unique_ptr<Sock> sock;
proxyType proxy;
if (addrConnect.IsValid()) {
bool proxyConnectionFailed = false;
if (GetProxy(addrConnect.GetNetwork(), proxy)) {
hSocket = CreateSocket(proxy.proxy);
if (hSocket == INVALID_SOCKET) {
sock = CreateSock(proxy.proxy);
if (!sock) {
return nullptr;
}
connected = ConnectThroughProxy(proxy, addrConnect.ToStringIP(), addrConnect.GetPort(), hSocket, nConnectTimeout, proxyConnectionFailed);
connected = ConnectThroughProxy(proxy, addrConnect.ToStringIP(), addrConnect.GetPort(),
*sock, nConnectTimeout, proxyConnectionFailed);
} else {
// no proxy needed (none set for target network)
hSocket = CreateSocket(addrConnect);
if (hSocket == INVALID_SOCKET) {
sock = CreateSock(addrConnect);
if (!sock) {
return nullptr;
}
connected = ConnectSocketDirectly(addrConnect, hSocket, nConnectTimeout, conn_type == ConnectionType::MANUAL);
connected = ConnectSocketDirectly(addrConnect, sock->Get(), nConnectTimeout,
conn_type == ConnectionType::MANUAL);
}
if (!proxyConnectionFailed) {
// If a connection to the node was attempted, and failure (if any) is not caused by a problem connecting to
@@ -454,26 +457,26 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo
addrman.Attempt(addrConnect, fCountFailure);
}
} else if (pszDest && GetNameProxy(proxy)) {
hSocket = CreateSocket(proxy.proxy);
if (hSocket == INVALID_SOCKET) {
sock = CreateSock(proxy.proxy);
if (!sock) {
return nullptr;
}
std::string host;
int port = default_port;
SplitHostPort(std::string(pszDest), port, host);
bool proxyConnectionFailed;
connected = ConnectThroughProxy(proxy, host, port, hSocket, nConnectTimeout, proxyConnectionFailed);
connected = ConnectThroughProxy(proxy, host, port, *sock, nConnectTimeout,
proxyConnectionFailed);
}
if (!connected) {
CloseSocket(hSocket);
return nullptr;
}
// Add node
NodeId id = GetNewNodeId();
uint64_t nonce = GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(id).Finalize();
CAddress addr_bind = GetBindAddress(hSocket);
CNode* pnode = new CNode(id, nLocalServices, hSocket, addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type);
CAddress addr_bind = GetBindAddress(sock->Get());
CNode* pnode = new CNode(id, nLocalServices, sock->Release(), addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type);
pnode->AddRef();
// We're making a new connection, harvest entropy from the time (and our peer count)
@@ -2188,9 +2191,8 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError,
return false;
}
SOCKET hListenSocket = CreateSocket(addrBind);
if (hListenSocket == INVALID_SOCKET)
{
std::unique_ptr<Sock> sock = CreateSock(addrBind);
if (!sock) {
strError = strprintf(Untranslated("Error: Couldn't open socket for incoming connections (socket returned error %s)"), NetworkErrorString(WSAGetLastError()));
LogPrintf("%s\n", strError.original);
return false;
@@ -2198,21 +2200,21 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError,
// Allow binding if the port is still in TIME_WAIT state after
// the program was closed and restarted.
setsockopt(hListenSocket, SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int));
setsockopt(sock->Get(), SOL_SOCKET, SO_REUSEADDR, (sockopt_arg_type)&nOne, sizeof(int));
// some systems don't have IPV6_V6ONLY but are always v6only; others do have the option
// and enable it by default or not. Try to enable it, if possible.
if (addrBind.IsIPv6()) {
#ifdef IPV6_V6ONLY
setsockopt(hListenSocket, IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int));
setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_V6ONLY, (sockopt_arg_type)&nOne, sizeof(int));
#endif
#ifdef WIN32
int nProtLevel = PROTECTION_LEVEL_UNRESTRICTED;
setsockopt(hListenSocket, IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int));
setsockopt(sock->Get(), IPPROTO_IPV6, IPV6_PROTECTION_LEVEL, (const char*)&nProtLevel, sizeof(int));
#endif
}
if (::bind(hListenSocket, (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR)
if (::bind(sock->Get(), (struct sockaddr*)&sockaddr, len) == SOCKET_ERROR)
{
int nErr = WSAGetLastError();
if (nErr == WSAEADDRINUSE)
@@ -2220,21 +2222,19 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError,
else
strError = strprintf(_("Unable to bind to %s on this computer (bind returned error %s)"), addrBind.ToString(), NetworkErrorString(nErr));
LogPrintf("%s\n", strError.original);
CloseSocket(hListenSocket);
return false;
}
LogPrintf("Bound to %s\n", addrBind.ToString());
// Listen for incoming connections
if (listen(hListenSocket, SOMAXCONN) == SOCKET_ERROR)
if (listen(sock->Get(), SOMAXCONN) == SOCKET_ERROR)
{
strError = strprintf(_("Error: Listening for incoming connections failed (listen returned error %s)"), NetworkErrorString(WSAGetLastError()));
LogPrintf("%s\n", strError.original);
CloseSocket(hListenSocket);
return false;
}
vhListenSocket.push_back(ListenSocket(hListenSocket, permissions));
vhListenSocket.push_back(ListenSocket(sock->Release(), permissions));
return true;
}