Skip to content

Commit 6037944

Browse files
committed
Merge branch 'review/pr-901'
Closes #901
2 parents 4d3553b + 6909d3d commit 6037944

2 files changed

Lines changed: 96 additions & 21 deletions

File tree

meshtastic/tcp_interface.py

Lines changed: 55 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import contextlib
55
import logging
66
import socket
7+
import threading
78
import time
89
from typing import Optional
910

@@ -12,17 +13,18 @@
1213
DEFAULT_TCP_PORT = 4403
1314
logger = logging.getLogger(__name__)
1415

16+
1517
class TCPInterface(StreamInterface):
1618
"""Interface class for meshtastic devices over a TCP link"""
1719

1820
def __init__(
1921
self,
2022
hostname: str,
2123
debugOut=None,
22-
noProto: bool=False,
23-
connectNow: bool=True,
24-
portNumber: int=DEFAULT_TCP_PORT,
25-
noNodes:bool=False,
24+
noProto: bool = False,
25+
connectNow: bool = True,
26+
portNumber: int = DEFAULT_TCP_PORT,
27+
noNodes: bool = False,
2628
timeout: int = 300,
2729
):
2830
"""Constructor, opens a connection to a specified IP address/hostname
@@ -35,8 +37,15 @@ def __init__(
3537
self.portNumber: int = portNumber
3638

3739
self.socket: Optional[socket.socket] = None
40+
self.reconnectLock = threading.Lock()
3841

39-
super().__init__(debugOut=debugOut, noProto=noProto, connectNow=connectNow, noNodes=noNodes, timeout=timeout)
42+
super().__init__(
43+
debugOut=debugOut,
44+
noProto=noProto,
45+
connectNow=connectNow,
46+
noNodes=noNodes,
47+
timeout=timeout,
48+
)
4049

4150
def __repr__(self):
4251
rep = f"TCPInterface({self.hostname!r}"
@@ -67,18 +76,20 @@ def connect(self) -> None:
6776

6877
def myConnect(self) -> None:
6978
"""Connect to socket (without attempting to start the interface's receive thread)"""
70-
logger.debug(f"Connecting to {self.hostname}") # type: ignore[str-bytes-safe]
79+
logger.debug(f"Connecting to {self.hostname}") # type: ignore[str-bytes-safe]
7180
server_address = (self.hostname, self.portNumber)
7281
self.socket = socket.create_connection(server_address)
7382

7483
def close(self) -> None:
75-
"""Close a connection to the device"""
84+
"""Close a connection to the device."""
7685
logger.debug("Closing TCP stream")
7786
# Sometimes the socket read might be blocked in the reader thread.
7887
# Therefore force a shutdown first to unblock reader thread reads.
7988
self._wantExit = True
8089
if self.socket is not None:
81-
with contextlib.suppress(Exception): # Ignore errors in shutdown, because we might have a race with the server
90+
with contextlib.suppress(
91+
Exception
92+
): # Ignore errors in shutdown, because we might have a race with the server
8293
self._socket_shutdown()
8394
with contextlib.suppress(Exception):
8495
self.socket.close()
@@ -87,29 +98,52 @@ def close(self) -> None:
8798
super().close()
8899

89100
def _writeBytes(self, b: bytes) -> None:
90-
"""Write an array of bytes to our stream and flush"""
101+
"""Write an array of bytes to our stream"""
91102
if self.socket is not None:
92-
self.socket.send(b)
103+
try:
104+
self.socket.sendall(b)
105+
except OSError as e:
106+
logger.error(f"Socket send error, reconnecting: {e}")
107+
if not self._wantExit:
108+
self._reconnect()
109+
raise
93110

94111
def _readBytes(self, length) -> Optional[bytes]:
95112
"""Read an array of bytes from our stream"""
96113
if self.socket is not None:
97114
data = self.socket.recv(length)
98115
# empty byte indicates a disconnected socket,
99116
# we need to handle it to avoid an infinite loop reading from null socket
100-
if data == b'':
101-
logger.debug("dead socket, re-connecting")
102-
# cleanup and reconnect socket without breaking reader thread
103-
with contextlib.suppress(Exception):
104-
self._socket_shutdown()
105-
self.socket.close()
106-
self.socket = None
107-
time.sleep(1)
108-
self.myConnect()
109-
self._startConfig()
110-
return None
117+
if data == b"":
118+
logger.debug("Closed socket, re-connecting")
119+
if not self._wantExit:
120+
self._reconnect()
111121
return data
112122

113123
# no socket, break reader thread
114124
self._wantExit = True
115125
return None
126+
127+
def _reconnect(self) -> None:
128+
"""Reconnect to the socket"""
129+
# Save the socket reference before attempting to acquire the lock.
130+
sock = self.socket
131+
start_config = False
132+
with self.reconnectLock:
133+
if self._wantExit:
134+
return
135+
# Don't reconnect: someone else already did it.
136+
if sock is not self.socket:
137+
return
138+
139+
with contextlib.suppress(Exception):
140+
self._socket_shutdown()
141+
if self.socket is not None:
142+
self.socket.close()
143+
self.socket = None
144+
time.sleep(1)
145+
self.myConnect()
146+
start_config = True
147+
148+
if start_config and not self._wantExit and self.socket is not None:
149+
self._startConfig()

meshtastic/tests/test_tcp_interface.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,44 @@ def test_TCPInterface_close_shutdowns_socket_before_super_close():
7676
assert call_order == ["shutdown", "super_close"]
7777
sock.close.assert_called_once()
7878
assert iface.socket is None
79+
80+
81+
@pytest.mark.unit
82+
def test_TCPInterface_reconnect():
83+
"""Test that _reconnect correctly reconnects"""
84+
with patch("socket.socket") as mock_socket:
85+
with patch("time.sleep"):
86+
iface = TCPInterface(hostname="localhost", noProto=True)
87+
old_socket = iface.socket
88+
assert old_socket is not None
89+
90+
iface._reconnect()
91+
92+
assert old_socket.close.called
93+
# We expect socket class to be instantiated at least twice (init + reconnect)
94+
assert mock_socket.call_count >= 2
95+
96+
97+
@pytest.mark.unit
98+
def test_TCPInterface_writeBytes_reconnects():
99+
"""Test that _writeBytes reconnects and re-raises on OSError."""
100+
with patch("socket.socket"):
101+
iface = TCPInterface(hostname="localhost", noProto=True)
102+
iface.socket.sendall.side_effect = OSError("Broken pipe")
103+
104+
with patch.object(iface, "_reconnect") as mock_reconnect:
105+
with pytest.raises(OSError, match="Broken pipe"):
106+
iface._writeBytes(b"some data")
107+
mock_reconnect.assert_called_once()
108+
109+
110+
@pytest.mark.unit
111+
def test_TCPInterface_readBytes_reconnects():
112+
"""Test that _readBytes calls _reconnect on empty bytes"""
113+
iface = TCPInterface(hostname="localhost", noProto=True, connectNow=False)
114+
iface.socket = MagicMock()
115+
iface.socket.recv.return_value = b""
116+
117+
with patch.object(iface, "_reconnect") as mock_reconnect:
118+
iface._readBytes(10)
119+
mock_reconnect.assert_called_once()

0 commit comments

Comments
 (0)