Skip to content

Commit 40a0019

Browse files
committed
Merge branch 'main' into loop_vs_keep_alive
# Conflicts: # adafruit_minimqtt/adafruit_minimqtt.py
2 parents d82465d + e19ece6 commit 40a0019

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

adafruit_minimqtt/adafruit_minimqtt.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,13 @@ def loop(self, timeout: float = 0) -> Optional[list[int]]:
10341034
:param float timeout: return after this timeout, in seconds.
10351035
10361036
"""
1037+
if timeout < self._socket_timeout:
1038+
raise MMQTTException(
1039+
# pylint: disable=consider-using-f-string
1040+
"loop timeout ({}) must be bigger ".format(timeout)
1041+
+ "than socket timeout ({}))".format(self._socket_timeout)
1042+
)
1043+
10371044
self._connected()
10381045
self.logger.debug(f"waiting for messages for {timeout} seconds")
10391046

@@ -1065,11 +1072,13 @@ def loop(self, timeout: float = 0) -> Optional[list[int]]:
10651072

10661073
return rcs if rcs else None
10671074

1068-
def _wait_for_msg(self) -> Optional[int]:
1075+
def _wait_for_msg(self, timeout: Optional[float] = None) -> Optional[int]:
10691076
# pylint: disable = too-many-return-statements
10701077

10711078
"""Reads and processes network events.
10721079
Return the packet type or None if there is nothing to be received.
1080+
1081+
:param float timeout: return after this timeout, in seconds.
10731082
"""
10741083
# CPython socket module contains a timeout attribute
10751084
if hasattr(self._socket_pool, "timeout"):
@@ -1079,7 +1088,7 @@ def _wait_for_msg(self) -> Optional[int]:
10791088
return None
10801089
else: # socketpool, esp32spi
10811090
try:
1082-
res = self._sock_exact_recv(1)
1091+
res = self._sock_exact_recv(1, timeout=timeout)
10831092
except OSError as error:
10841093
if error.errno in (errno.ETIMEDOUT, errno.EAGAIN):
10851094
# raised by a socket timeout if 0 bytes were present
@@ -1148,7 +1157,9 @@ def _decode_remaining_length(self) -> int:
11481157
return n
11491158
sh += 7
11501159

1151-
def _sock_exact_recv(self, bufsize: int) -> bytearray:
1160+
def _sock_exact_recv(
1161+
self, bufsize: int, timeout: Optional[float] = None
1162+
) -> bytearray:
11521163
"""Reads _exact_ number of bytes from the connected socket. Will only return
11531164
bytearray with the exact number of bytes requested.
11541165
@@ -1159,6 +1170,7 @@ def _sock_exact_recv(self, bufsize: int) -> bytearray:
11591170
bytes is returned or trigger a timeout exception.
11601171
11611172
:param int bufsize: number of bytes to receive
1173+
:param float timeout: timeout, in seconds. Defaults to keep_alive
11621174
:return: byte array
11631175
"""
11641176
stamp = self.get_monotonic_time()
@@ -1170,7 +1182,7 @@ def _sock_exact_recv(self, bufsize: int) -> bytearray:
11701182
to_read = bufsize - recv_len
11711183
if to_read < 0:
11721184
raise MMQTTException(f"negative number of bytes to read: {to_read}")
1173-
read_timeout = self.keep_alive
1185+
read_timeout = timeout if timeout is not None else self.keep_alive
11741186
mv = mv[recv_len:]
11751187
while to_read > 0:
11761188
recv_len = self._sock.recv_into(mv, to_read)

tests/test_loop.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ class Loop(TestCase):
108108
INITIAL_RCS_VAL = 42
109109
rcs_val = INITIAL_RCS_VAL
110110

111-
def fake_wait_for_msg(self):
111+
def fake_wait_for_msg(self, timeout=1):
112112
"""_wait_for_msg() replacement. Sleeps for 1 second and returns an integer."""
113-
time.sleep(1)
113+
time.sleep(timeout)
114114
retval = self.rcs_val
115115
self.rcs_val += 1
116116
return retval
@@ -151,13 +151,33 @@ def test_loop_basic(self) -> None:
151151

152152
# Check the return value.
153153
assert rcs is not None
154-
assert len(rcs) > 1
154+
assert len(rcs) >= 1
155155
expected_rc = self.INITIAL_RCS_VAL
156156
# pylint: disable=not-an-iterable
157157
for ret_code in rcs:
158158
assert ret_code == expected_rc
159159
expected_rc += 1
160160

161+
# pylint: disable=invalid-name
162+
def test_loop_timeout_vs_socket_timeout(self):
163+
"""
164+
loop() should throw MMQTTException if the timeout argument
165+
is bigger than the socket timeout.
166+
"""
167+
mqtt_client = MQTT.MQTT(
168+
broker="127.0.0.1",
169+
port=1883,
170+
socket_pool=socket,
171+
ssl_context=ssl.create_default_context(),
172+
socket_timeout=1,
173+
)
174+
175+
mqtt_client.is_connected = lambda: True
176+
with self.assertRaises(MQTT.MMQTTException) as context:
177+
mqtt_client.loop(timeout=0.5)
178+
179+
assert "loop timeout" in str(context.exception)
180+
161181
def test_loop_is_connected(self):
162182
"""
163183
loop() should throw MMQTTException if not connected

0 commit comments

Comments
 (0)