From 07172f88f38e806fc95b102e3a83eae0c2527320 Mon Sep 17 00:00:00 2001 From: Stephen Thorne Date: Thu, 5 Feb 2026 21:53:18 +0100 Subject: [PATCH] Give TCPInterface reconnect logic on write errors * Moving to socket.sendall() is safer, as sendall will send the entire buffer, while send() would return the number of bytes sent and require being called multiple times if the buffer was full. * On exceptions: reconnect to the server. * On reconnection: make sure using a lock that there isn't a race between the readers and the writers triggering a reconnect. --- meshtastic/tcp_interface.py | 68 +++++++++++++++++--------- meshtastic/tests/test_tcp_interface.py | 41 ++++++++++++++++ protobufs | 1 - 3 files changed, 87 insertions(+), 23 deletions(-) delete mode 160000 protobufs diff --git a/meshtastic/tcp_interface.py b/meshtastic/tcp_interface.py index 732f37ef3..5d2792986 100644 --- a/meshtastic/tcp_interface.py +++ b/meshtastic/tcp_interface.py @@ -4,6 +4,7 @@ import contextlib import logging import socket +import threading import time from typing import Optional @@ -12,6 +13,7 @@ DEFAULT_TCP_PORT = 4403 logger = logging.getLogger(__name__) + class TCPInterface(StreamInterface): """Interface class for meshtastic devices over a TCP link""" @@ -19,10 +21,10 @@ def __init__( self, hostname: str, debugOut=None, - noProto: bool=False, - connectNow: bool=True, - portNumber: int=DEFAULT_TCP_PORT, - noNodes:bool=False, + noProto: bool = False, + connectNow: bool = True, + portNumber: int = DEFAULT_TCP_PORT, + noNodes: bool = False, timeout: int = 300, ): """Constructor, opens a connection to a specified IP address/hostname @@ -38,13 +40,20 @@ def __init__( self.portNumber: int = portNumber self.socket: Optional[socket.socket] = None + self.reconnectLock = threading.Lock() if connectNow: self.myConnect() else: self.socket = None - super().__init__(debugOut=debugOut, noProto=noProto, connectNow=connectNow, noNodes=noNodes, timeout=timeout) + super().__init__( + debugOut=debugOut, + noProto=noProto, + connectNow=connectNow, + noNodes=noNodes, + timeout=timeout, + ) def __repr__(self): rep = f"TCPInterface({self.hostname!r}" @@ -69,29 +78,35 @@ def _socket_shutdown(self) -> None: self.socket.shutdown(socket.SHUT_RDWR) def myConnect(self) -> None: - """Connect to socket""" - logger.debug(f"Connecting to {self.hostname}") # type: ignore[str-bytes-safe] + """Connect to socket.""" + logger.debug(f"Connecting to {self.hostname}") # type: ignore[str-bytes-safe] server_address = (self.hostname, self.portNumber) self.socket = socket.create_connection(server_address) def close(self) -> None: - """Close a connection to the device""" + """Close a connection to the device.""" logger.debug("Closing TCP stream") super().close() # Sometimes the socket read might be blocked in the reader thread. # Therefore we force the shutdown by closing the socket here self._wantExit = True if self.socket is not None: - with contextlib.suppress(Exception): # Ignore errors in shutdown, because we might have a race with the server + with contextlib.suppress( + Exception + ): # Ignore errors in shutdown, because we might have a race with the server self._socket_shutdown() self.socket.close() self.socket = None def _writeBytes(self, b: bytes) -> None: - """Write an array of bytes to our stream and flush""" + """Write an array of bytes to our stream""" if self.socket is not None: - self.socket.send(b) + try: + self.socket.sendall(b) + except OSError as e: + logger.error(f"Socket send error, reconnecting: {e}") + self._reconnect() def _readBytes(self, length) -> Optional[bytes]: """Read an array of bytes from our stream""" @@ -99,19 +114,28 @@ def _readBytes(self, length) -> Optional[bytes]: data = self.socket.recv(length) # empty byte indicates a disconnected socket, # we need to handle it to avoid an infinite loop reading from null socket - if data == b'': - logger.debug("dead socket, re-connecting") - # cleanup and reconnect socket without breaking reader thread - with contextlib.suppress(Exception): - self._socket_shutdown() - self.socket.close() - self.socket = None - time.sleep(1) - self.myConnect() - self._startConfig() - return None + if data == b"": + logger.debug("Closed socket, re-connecting") + self._reconnect() return data # no socket, break reader thread self._wantExit = True return None + + def _reconnect(self) -> None: + """Reconnect to the socket""" + # Save the socket reference before attempting to acquire the lock. + sock = self.socket + with self.reconnectLock: + # Don't reconnect: someone else already did it. + if sock is not self.socket: + return + + with contextlib.suppress(Exception): + self._socket_shutdown() + self.socket.close() + self.socket = None + time.sleep(1) + self.myConnect() + self._startConfig() diff --git a/meshtastic/tests/test_tcp_interface.py b/meshtastic/tests/test_tcp_interface.py index 44e79de50..e123833ef 100644 --- a/meshtastic/tests/test_tcp_interface.py +++ b/meshtastic/tests/test_tcp_interface.py @@ -54,3 +54,44 @@ def test_TCPInterface_without_connecting(): with patch("socket.socket"): iface = TCPInterface(hostname="localhost", noProto=True, connectNow=False) assert iface.socket is None + + +@pytest.mark.unit +def test_TCPInterface_reconnect(): + """Test that _reconnect correctly reconnects""" + with patch("socket.socket") as mock_socket: + with patch("time.sleep"): + iface = TCPInterface(hostname="localhost", noProto=True) + old_socket = iface.socket + assert old_socket is not None + + iface._reconnect() + + assert old_socket.close.called + # We expect socket class to be instantiated at least twice (init + reconnect) + assert mock_socket.call_count >= 2 + + +@pytest.mark.unit +def test_TCPInterface_writeBytes_reconnects(): + """Test that _writeBytes calls _reconnect on OSError""" + with patch("socket.socket"): + iface = TCPInterface(hostname="localhost", noProto=True) + iface.socket.sendall.side_effect = OSError("Broken pipe") + + with patch.object(iface, '_reconnect') as mock_reconnect: + iface._writeBytes(b"some data") + mock_reconnect.assert_called_once() + + +@pytest.mark.unit +def test_TCPInterface_readBytes_reconnects(): + """Test that _readBytes calls _reconnect on empty bytes""" + with patch("socket.socket"): + iface = TCPInterface(hostname="localhost", noProto=True) + # Mock the socket instance on the interface + iface.socket.recv.return_value = b'' + + with patch.object(iface, '_reconnect') as mock_reconnect: + iface._readBytes(10) + mock_reconnect.assert_called_once() diff --git a/protobufs b/protobufs deleted file mode 160000 index 77c8329a5..000000000 --- a/protobufs +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 77c8329a59a9c96a61c447b5d5f1a52ca583e4f2