diff --git a/meshtastic/tcp_interface.py b/meshtastic/tcp_interface.py index 732f37ef..5d279298 100644 --- a/meshtastic/tcp_interface.py +++ b/meshtastic/tcp_interface.py @@ -4,6 +4,7 @@ import contextlib import logging import socket +import threading import time from typing import Optional @@ -12,6 +13,7 @@ DEFAULT_TCP_PORT = 4403 logger = logging.getLogger(__name__) + class TCPInterface(StreamInterface): """Interface class for meshtastic devices over a TCP link""" @@ -19,10 +21,10 @@ def __init__( self, hostname: str, debugOut=None, - noProto: bool=False, - connectNow: bool=True, - portNumber: int=DEFAULT_TCP_PORT, - noNodes:bool=False, + noProto: bool = False, + connectNow: bool = True, + portNumber: int = DEFAULT_TCP_PORT, + noNodes: bool = False, timeout: int = 300, ): """Constructor, opens a connection to a specified IP address/hostname @@ -38,13 +40,20 @@ def __init__( self.portNumber: int = portNumber self.socket: Optional[socket.socket] = None + self.reconnectLock = threading.Lock() if connectNow: self.myConnect() else: self.socket = None - super().__init__(debugOut=debugOut, noProto=noProto, connectNow=connectNow, noNodes=noNodes, timeout=timeout) + super().__init__( + debugOut=debugOut, + noProto=noProto, + connectNow=connectNow, + noNodes=noNodes, + timeout=timeout, + ) def __repr__(self): rep = f"TCPInterface({self.hostname!r}" @@ -69,29 +78,35 @@ def _socket_shutdown(self) -> None: self.socket.shutdown(socket.SHUT_RDWR) def myConnect(self) -> None: - """Connect to socket""" - logger.debug(f"Connecting to {self.hostname}") # type: ignore[str-bytes-safe] + """Connect to socket.""" + logger.debug(f"Connecting to {self.hostname}") # type: ignore[str-bytes-safe] server_address = (self.hostname, self.portNumber) self.socket = socket.create_connection(server_address) def close(self) -> None: - """Close a connection to the device""" + """Close a connection to the device.""" logger.debug("Closing TCP stream") super().close() # Sometimes the socket read might be blocked in the reader thread. # Therefore we force the shutdown by closing the socket here self._wantExit = True if self.socket is not None: - with contextlib.suppress(Exception): # Ignore errors in shutdown, because we might have a race with the server + with contextlib.suppress( + Exception + ): # Ignore errors in shutdown, because we might have a race with the server self._socket_shutdown() self.socket.close() self.socket = None def _writeBytes(self, b: bytes) -> None: - """Write an array of bytes to our stream and flush""" + """Write an array of bytes to our stream""" if self.socket is not None: - self.socket.send(b) + try: + self.socket.sendall(b) + except OSError as e: + logger.error(f"Socket send error, reconnecting: {e}") + self._reconnect() def _readBytes(self, length) -> Optional[bytes]: """Read an array of bytes from our stream""" @@ -99,19 +114,28 @@ def _readBytes(self, length) -> Optional[bytes]: data = self.socket.recv(length) # empty byte indicates a disconnected socket, # we need to handle it to avoid an infinite loop reading from null socket - if data == b'': - logger.debug("dead socket, re-connecting") - # cleanup and reconnect socket without breaking reader thread - with contextlib.suppress(Exception): - self._socket_shutdown() - self.socket.close() - self.socket = None - time.sleep(1) - self.myConnect() - self._startConfig() - return None + if data == b"": + logger.debug("Closed socket, re-connecting") + self._reconnect() return data # no socket, break reader thread self._wantExit = True return None + + def _reconnect(self) -> None: + """Reconnect to the socket""" + # Save the socket reference before attempting to acquire the lock. + sock = self.socket + with self.reconnectLock: + # Don't reconnect: someone else already did it. + if sock is not self.socket: + return + + with contextlib.suppress(Exception): + self._socket_shutdown() + self.socket.close() + self.socket = None + time.sleep(1) + self.myConnect() + self._startConfig() diff --git a/meshtastic/tests/test_tcp_interface.py b/meshtastic/tests/test_tcp_interface.py index 44e79de5..e123833e 100644 --- a/meshtastic/tests/test_tcp_interface.py +++ b/meshtastic/tests/test_tcp_interface.py @@ -54,3 +54,44 @@ def test_TCPInterface_without_connecting(): with patch("socket.socket"): iface = TCPInterface(hostname="localhost", noProto=True, connectNow=False) assert iface.socket is None + + +@pytest.mark.unit +def test_TCPInterface_reconnect(): + """Test that _reconnect correctly reconnects""" + with patch("socket.socket") as mock_socket: + with patch("time.sleep"): + iface = TCPInterface(hostname="localhost", noProto=True) + old_socket = iface.socket + assert old_socket is not None + + iface._reconnect() + + assert old_socket.close.called + # We expect socket class to be instantiated at least twice (init + reconnect) + assert mock_socket.call_count >= 2 + + +@pytest.mark.unit +def test_TCPInterface_writeBytes_reconnects(): + """Test that _writeBytes calls _reconnect on OSError""" + with patch("socket.socket"): + iface = TCPInterface(hostname="localhost", noProto=True) + iface.socket.sendall.side_effect = OSError("Broken pipe") + + with patch.object(iface, '_reconnect') as mock_reconnect: + iface._writeBytes(b"some data") + mock_reconnect.assert_called_once() + + +@pytest.mark.unit +def test_TCPInterface_readBytes_reconnects(): + """Test that _readBytes calls _reconnect on empty bytes""" + with patch("socket.socket"): + iface = TCPInterface(hostname="localhost", noProto=True) + # Mock the socket instance on the interface + iface.socket.recv.return_value = b'' + + with patch.object(iface, '_reconnect') as mock_reconnect: + iface._readBytes(10) + mock_reconnect.assert_called_once() diff --git a/protobufs b/protobufs deleted file mode 160000 index 77c8329a..00000000 --- a/protobufs +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 77c8329a59a9c96a61c447b5d5f1a52ca583e4f2