diff --git a/test/functional/test_framework/socks5.py b/test/functional/test_framework/socks5.py index 085c5a2e324..28aa2e8e383 100644 --- a/test/functional/test_framework/socks5.py +++ b/test/functional/test_framework/socks5.py @@ -49,8 +49,12 @@ def sendall(s, data): raise IOError('send() on socket returned 0') sent += n -def forward_sockets(a, b): - """Forward data received on socket a to socket b and vice versa, until EOF is received on one of the sockets.""" +def forward_sockets(a, b, wakeup_socket, serv): + """Forwards data between sockets a and b until EOF, error, or shutdown. + + Monitors wakeup_socket for a shutdown signal and checks serv.is_running() + to exit gracefully when the server is stopping. + """ # Mark as non-blocking so that we do not end up in a deadlock-like situation # where we block and wait on data from `a` while there is data ready to be # received on `b` and forwarded to `a`. And at the same time the application @@ -58,10 +62,14 @@ def forward_sockets(a, b): # respond. a.setblocking(False) b.setblocking(False) - sockets = [a, b] + sockets = [a, b, wakeup_socket] done = False while not done: - rlist, _, xlist = select.select(sockets, [], sockets) + # Blocking select with timeout + rlist, _, xlist = select.select(sockets, [], sockets, 2) + if not serv.is_running(): + logger.debug("forward_sockets: Exit due to shutdown") + return if len(xlist) > 0: raise IOError('Exceptional condition on socket') for s in rlist: @@ -71,7 +79,7 @@ def forward_sockets(a, b): break if s == a: sendall(b, data) - else: + elif s == b: sendall(a, data) # Implementation classes @@ -113,6 +121,11 @@ class Socks5Connection(): def __init__(self, serv, conn): self.serv = serv self.conn = conn + # Socket-pair used to wake up blocking forwarding select + # Note: a pipe could be used as well, but that does not work with select() on Windows + self.wakeup_socket_pair = socket.socketpair() + # Index of this handler (within the server) + self.handler_index = None def handle(self): """Handle socks5 request according to RFC1928.""" @@ -176,27 +189,63 @@ class Socks5Connection(): requested_to_addr = addr.decode("utf-8") requested_to = format_addr_port(requested_to_addr, port) - if self.serv.conf.destinations_factory is not None: - dest = self.serv.conf.destinations_factory(requested_to_addr, port) - if dest is not None: - logger.debug(f"Serving connection to {requested_to}, will redirect it to " - f"{dest['actual_to_addr']}:{dest['actual_to_port']} instead") - with socket.create_connection((dest["actual_to_addr"], dest["actual_to_port"])) as conn_to: - forward_sockets(self.conn, conn_to) + if self.serv.is_running(): + if self.serv.conf.destinations_factory is not None: + dest = self.serv.conf.destinations_factory(requested_to_addr, port) + if dest is not None: + logger.debug(f"Serving connection to {requested_to}, will redirect it to " + f"{dest['actual_to_addr']}:{dest['actual_to_port']} instead") + with socket.create_connection((dest["actual_to_addr"], dest["actual_to_port"])) as conn_to: + forward_sockets(self.conn, conn_to, self.wakeup_socket_pair[1], self.serv) + conn_to.close() + else: + logger.debug(f"Can't serve the connection to {requested_to}: the destinations factory returned None") else: - logger.debug(f"Can't serve the connection to {requested_to}: the destinations factory returned None") - else: - logger.debug(f"Can't serve the connection to {requested_to}: no destinations factory") + logger.debug(f"Can't serve the connection to {requested_to}: no destinations factory") # Fall through to disconnect except Exception as e: - logger.exception("socks5 request handling failed.") - self.serv.queue.put(e) + logger.exception(f"socks5 request handling failed (running {self.serv.is_running()})") + if self.serv.is_running(): + self.serv.queue.put(e) finally: if not self.serv.keep_alive: self.conn.close() else: logger.debug("Keeping client connection alive") + s0 = self.wakeup_socket_pair[0] + s1 = self.wakeup_socket_pair[1] + self.wakeup_socket_pair = None + try: + s0.close() + s1.close() + except OSError: + pass + self.serv.remove_handler(self.handler_index) + self.handler_index = None + + def wakeup(self): + # Wake up the blocking forwarding select by writing to the wake-up socket + try: + socket_pair = self.wakeup_socket_pair + if socket_pair is not None: + socket_pair[0].send("CloseWakeup".encode()) + logger.debug("Waking up forwarding thread") + except OSError as e: + logger.warning(f"Error waking up forwarding thread: {e}") + pass + + +# Wrapper for thread.join(), which may throw for daemon threads (in late stages of finalization). +# Return True if the thread is no longer active (join succeeded), False otherwise +# See PR #34863 for more details on using daemon threads. +def try_join_daemon_thread(thread, timeout=0) -> bool: + try: + thread.join(timeout=timeout) + return not thread.is_alive() + except Exception as e: + logger.debug(f"Exception in thread.join, {e}") + return True class Socks5Server(): def __init__(self, conf): @@ -212,31 +261,69 @@ class Socks5Server(): # to reflect the actual bound address so callers can use it. self.conf.addr = self.s.getsockname() self.s.listen(5) - self.running = False + # Set to False when stop is initiated + self._running = False + self._running_lock = threading.Lock() self.thread = None self.queue = queue.Queue() # report connections and exceptions to client self.keep_alive = conf.keep_alive + # Store the background handlers, needed for clean shutdown + # Append-only array, completed handlers are set to None + self._handlers = [] + self._handlers_lock = threading.Lock() + + def is_running(self) -> bool: + with self._running_lock: + return self._running + + def set_running(self, new_value: bool): + with self._running_lock: + self._running = new_value def run(self): - while self.running: + while self.is_running(): (sockconn, _) = self.s.accept() - if self.running: + if self.is_running(): conn = Socks5Connection(self, sockconn) - thread = threading.Thread(None, conn.handle) - thread.daemon = True + # Use "daemon" threads, see PR #34863 for more discussion. + thread = threading.Thread(None, conn.handle, daemon=True) + with self._handlers_lock: + conn.handler_index = len(self._handlers) + self._handlers.append((thread, conn)) + assert(conn.handler_index < len(self._handlers)) thread.start() + def remove_handler(self, handler_index): + with self._handlers_lock: + if handler_index < len(self._handlers): + if self._handlers[handler_index] is not None: + self._handlers[handler_index] = None + logger.debug(f"Handler {handler_index} removed") + def start(self): - assert not self.running - self.running = True - self.thread = threading.Thread(None, self.run) - self.thread.daemon = True + assert not self.is_running() + self.set_running(True) + self.thread = threading.Thread(None, self.run, daemon=True) self.thread.start() def stop(self): - self.running = False + self.set_running(False) # connect to self to end run loop s = socket.socket(self.conf.af) s.connect(self.conf.addr) s.close() self.thread.join() + # if there are active handlers, close them + with self._handlers_lock: + items = list(self._handlers) + for i, item in enumerate(items): + if item is None: + continue + thread, conn = item + # check if thread is still active + if not try_join_daemon_thread(thread, timeout=0): + conn.wakeup() + if try_join_daemon_thread(thread, timeout=2): + logger.debug(f"Stop(): Handler {i} thread joined") + else: + logger.warning(f"Stop(): Handler thread {i} didn't finish after force close")