diff --git a/test/functional/test_framework/p2p.py b/test/functional/test_framework/p2p.py index 4ababe01dc1..aa9c093b650 100755 --- a/test/functional/test_framework/p2p.py +++ b/test/functional/test_framework/p2p.py @@ -168,6 +168,7 @@ class P2PConnection(asyncio.Protocol): # p2p_lock must not be acquired after _send_lock as it could result in deadlocks. self._send_lock = threading.Lock() self.v2_state = None # EncryptedP2PState object needed for v2 p2p connections + self.reconnect = False # set if reconnection needs to happen @property def is_connected(self): @@ -197,8 +198,9 @@ class P2PConnection(asyncio.Protocol): coroutine = loop.create_connection(lambda: self, host=self.dstaddr, port=self.dstport) return lambda: loop.call_soon_threadsafe(loop.create_task, coroutine) - def peer_accept_connection(self, connect_id, connect_cb=lambda: None, *, net, timeout_factor, supports_v2_p2p): + def peer_accept_connection(self, connect_id, connect_cb=lambda: None, *, net, timeout_factor, supports_v2_p2p, reconnect): self.peer_connect_helper('0', 0, net, timeout_factor) + self.reconnect = reconnect if supports_v2_p2p: self.v2_state = EncryptedP2PState(initiating=False, net=net) @@ -222,14 +224,16 @@ class P2PConnection(asyncio.Protocol): send_handshake_bytes = self.v2_state.initiate_v2_handshake() self.send_raw_message(send_handshake_bytes) # if v2 connection, send `on_connection_send_msg` after initial v2 handshake. - if self.on_connection_send_msg and not self.supports_v2_p2p: + # if reconnection situation, send `on_connection_send_msg` after version message is received in `on_version()`. + if self.on_connection_send_msg and not self.supports_v2_p2p and not self.reconnect: self.send_message(self.on_connection_send_msg) self.on_connection_send_msg = None # Never used again self.on_open() def connection_lost(self, exc): """asyncio callback when a connection is closed.""" - if exc: + # don't display warning if reconnection needs to be attempted using v1 P2P + if exc and not self.reconnect: logger.warning("Connection lost to {}:{} due to {}".format(self.dstaddr, self.dstport, exc)) else: logger.debug("Closed connection to: %s:%d" % (self.dstaddr, self.dstport)) @@ -279,9 +283,9 @@ class P2PConnection(asyncio.Protocol): if not is_mac_auth: raise ValueError("invalid v2 mac tag in handshake authentication") self.recvbuf = self.recvbuf[length:] - while self.v2_state.tried_v2_handshake and self.queue_messages: - message = self.queue_messages.pop(0) - self.send_message(message) + if self.v2_state.tried_v2_handshake and self.on_connection_send_msg: + self.send_message(self.on_connection_send_msg) + self.on_connection_send_msg = None # Socket read methods @@ -350,7 +354,8 @@ class P2PConnection(asyncio.Protocol): self._log_message("receive", t) self.on_message(t) except Exception as e: - logger.exception('Error reading message:', repr(e)) + if not self.reconnect: + logger.exception('Error reading message:', repr(e)) raise def on_message(self, message): @@ -549,6 +554,12 @@ class P2PInterface(P2PConnection): def on_version(self, message): assert message.nVersion >= MIN_P2P_VERSION_SUPPORTED, "Version {} received. Test framework only supports versions greater than {}".format(message.nVersion, MIN_P2P_VERSION_SUPPORTED) + # reconnection using v1 P2P has happened since version message can be processed, previously unsent version message is sent using v1 P2P here + if self.reconnect: + if self.on_connection_send_msg: + self.send_message(self.on_connection_send_msg) + self.on_connection_send_msg = None + self.reconnect = False if message.nVersion >= 70016 and self.wtxidrelay: self.send_message(msg_wtxidrelay()) if self.support_addrv2: @@ -721,6 +732,11 @@ class NetworkThread(threading.Thread): if addr is None: addr = '127.0.0.1' + def exception_handler(loop, context): + if not p2p.reconnect: + loop.default_exception_handler(context) + + cls.network_event_loop.set_exception_handler(exception_handler) coroutine = cls.create_listen_server(addr, port, callback, p2p) cls.network_event_loop.call_soon_threadsafe(cls.network_event_loop.create_task, coroutine) @@ -734,7 +750,9 @@ class NetworkThread(threading.Thread): protocol function from that dict, and returns it so the event loop can start executing it.""" response = cls.protos.get((addr, port)) - cls.protos[(addr, port)] = None + # remove protocol function from dict only when reconnection doesn't need to happen/already happened + if not proto.reconnect: + cls.protos[(addr, port)] = None return response if (addr, port) not in cls.listeners: diff --git a/test/functional/test_framework/test_node.py b/test/functional/test_framework/test_node.py index 444976e54f7..9088783d122 100755 --- a/test/functional/test_framework/test_node.py +++ b/test/functional/test_framework/test_node.py @@ -702,7 +702,7 @@ class TestNode(): self.addconnection('%s:%d' % (address, port), connection_type) p2p_conn.p2p_connected_to_node = False - p2p_conn.peer_accept_connection(connect_cb=addconnection_callback, connect_id=p2p_idx + 1, net=self.chain, timeout_factor=self.timeout_factor, supports_v2_p2p=supports_v2_p2p, **kwargs)() + p2p_conn.peer_accept_connection(connect_cb=addconnection_callback, connect_id=p2p_idx + 1, net=self.chain, timeout_factor=self.timeout_factor, supports_v2_p2p=supports_v2_p2p, reconnect=False, **kwargs)() if connection_type == "feeler": # feeler connections are closed as soon as the node receives a `version` message