Skip to content

Commit 4486578

Browse files
authored
Merge branch 'main' into timestamp_ns
2 parents 4272252 + f871143 commit 4486578

File tree

4 files changed

+140
-20
lines changed

4 files changed

+140
-20
lines changed

adafruit_minimqtt/adafruit_minimqtt.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,18 @@
7373
MQTT_PKT_TYPE_MASK = const(0xF0)
7474

7575

76+
CONNACK_ERROR_INCORRECT_PROTOCOL = const(0x01)
77+
CONNACK_ERROR_ID_REJECTED = const(0x02)
78+
CONNACK_ERROR_SERVER_UNAVAILABLE = const(0x03)
79+
CONNACK_ERROR_INCORECT_USERNAME_PASSWORD = const(0x04)
80+
CONNACK_ERROR_UNAUTHORIZED = const(0x05)
81+
7682
CONNACK_ERRORS = {
77-
const(0x01): "Connection Refused - Incorrect Protocol Version",
78-
const(0x02): "Connection Refused - ID Rejected",
79-
const(0x03): "Connection Refused - Server unavailable",
80-
const(0x04): "Connection Refused - Incorrect username/password",
81-
const(0x05): "Connection Refused - Unauthorized",
83+
CONNACK_ERROR_INCORRECT_PROTOCOL: "Connection Refused - Incorrect Protocol Version",
84+
CONNACK_ERROR_ID_REJECTED: "Connection Refused - ID Rejected",
85+
CONNACK_ERROR_SERVER_UNAVAILABLE: "Connection Refused - Server unavailable",
86+
CONNACK_ERROR_INCORECT_USERNAME_PASSWORD: "Connection Refused - Incorrect username/password",
87+
CONNACK_ERROR_UNAUTHORIZED: "Connection Refused - Unauthorized",
8288
}
8389

8490
_default_sock = None # pylint: disable=invalid-name
@@ -88,6 +94,10 @@
8894
class MMQTTException(Exception):
8995
"""MiniMQTT Exception class."""
9096

97+
def __init__(self, error, code=None):
98+
super().__init__(error, code)
99+
self.code = code
100+
91101

92102
class NullLogger:
93103
"""Fake logger class that does not do anything"""
@@ -397,21 +407,31 @@ def connect(
397407
)
398408
self._reset_reconnect_backoff()
399409
return ret
400-
except RuntimeError as e:
410+
except (MemoryError, OSError, RuntimeError) as e:
411+
if isinstance(e, RuntimeError) and e.args == ("pystack exhausted",):
412+
raise
401413
self.logger.warning(f"Socket error when connecting: {e}")
414+
last_exception = e
402415
backoff = False
403416
except MMQTTException as e:
404-
last_exception = e
417+
self._close_socket()
405418
self.logger.info(f"MMQT error: {e}")
419+
if e.code in [
420+
CONNACK_ERROR_INCORECT_USERNAME_PASSWORD,
421+
CONNACK_ERROR_UNAUTHORIZED,
422+
]:
423+
# No sense trying these again, re-raise
424+
raise
425+
last_exception = e
406426
backoff = True
407427

408428
if self._reconnect_attempts_max > 1:
409429
exc_msg = "Repeated connect failures"
410430
else:
411431
exc_msg = "Connect failure"
432+
412433
if last_exception:
413434
raise MMQTTException(exc_msg) from last_exception
414-
415435
raise MMQTTException(exc_msg)
416436

417437
# pylint: disable=too-many-branches, too-many-statements, too-many-locals
@@ -508,7 +528,7 @@ def _connect(
508528
rc = self._sock_exact_recv(3)
509529
assert rc[0] == 0x02
510530
if rc[2] != 0x00:
511-
raise MMQTTException(CONNACK_ERRORS[rc[2]])
531+
raise MMQTTException(CONNACK_ERRORS[rc[2]], code=rc[2])
512532
self._is_connected = True
513533
result = rc[0] & 1
514534
if self.on_connect is not None:
@@ -522,6 +542,12 @@ def _connect(
522542
f"No data received from broker for {self._recv_timeout} seconds."
523543
)
524544

545+
def _close_socket(self):
546+
if self._sock:
547+
self.logger.debug("Closing socket")
548+
self._connection_manager.close_socket(self._sock)
549+
self._sock = None
550+
525551
# pylint: disable=no-self-use
526552
def _encode_remaining_length(
527553
self, fixed_header: bytearray, remaining_length: int
@@ -550,8 +576,7 @@ def disconnect(self) -> None:
550576
self._sock.send(MQTT_DISCONNECT)
551577
except RuntimeError as e:
552578
self.logger.warning(f"Unable to send DISCONNECT packet: {e}")
553-
self.logger.debug("Closing socket")
554-
self._connection_manager.close_socket(self._sock)
579+
self._close_socket()
555580
self._is_connected = False
556581
self._subscribed_topics = []
557582
self._last_msg_sent_timestamp = 0
@@ -568,6 +593,7 @@ def ping(self) -> list[int]:
568593
self._sock.send(MQTT_PINGREQ)
569594
ping_timeout = self.keep_alive
570595
stamp = ticks_ms()
596+
571597
self._last_msg_sent_timestamp = stamp
572598
rc, rcs = None, []
573599
while rc != MQTT_PINGRESP:
@@ -946,7 +972,7 @@ def _wait_for_msg(self, timeout: Optional[float] = None) -> Optional[int]:
946972
res = self._sock_exact_recv(1)
947973
except self._socket_pool.timeout:
948974
return None
949-
else: # socketpool, esp32spi
975+
else: # socketpool, esp32spi, wiznet5k
950976
try:
951977
res = self._sock_exact_recv(1, timeout=timeout)
952978
except OSError as error:
@@ -1035,14 +1061,14 @@ def _sock_exact_recv(
10351061
"""
10361062
stamp = ticks_ms()
10371063
if not self._backwards_compatible_sock:
1038-
# CPython/Socketpool Impl.
1064+
# CPython, socketpool, esp32spi, wiznet5k
10391065
rc = bytearray(bufsize)
10401066
mv = memoryview(rc)
10411067
recv_len = self._sock.recv_into(rc, bufsize)
10421068
to_read = bufsize - recv_len
10431069
if to_read < 0:
10441070
raise MMQTTException(f"negative number of bytes to read: {to_read}")
1045-
read_timeout = timeout if timeout is not None else self.keep_alive
1071+
read_timeout = timeout if timeout is not None else self._recv_timeout
10461072
mv = mv[recv_len:]
10471073
while to_read > 0:
10481074
recv_len = self._sock.recv_into(mv, to_read)
@@ -1052,8 +1078,8 @@ def _sock_exact_recv(
10521078
raise MMQTTException(
10531079
f"Unable to receive {to_read} bytes within {read_timeout} seconds."
10541080
)
1055-
else: # ESP32SPI Impl.
1056-
# This will timeout with socket timeout (not keepalive timeout)
1081+
else: # Legacy: fona, esp_atcontrol
1082+
# This will time out with socket timeout (not receive timeout).
10571083
rc = self._sock.recv(bufsize)
10581084
if not rc:
10591085
self.logger.debug("_sock_exact_recv timeout")
@@ -1063,7 +1089,7 @@ def _sock_exact_recv(
10631089
# or raise exception if wait longer than read_timeout
10641090
to_read = bufsize - len(rc)
10651091
assert to_read >= 0
1066-
read_timeout = self.keep_alive
1092+
read_timeout = self._recv_timeout
10671093
while to_read > 0:
10681094
recv = self._sock.recv(to_read)
10691095
to_read -= len(recv)

tests/test_backoff.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,24 @@ class TestExpBackOff:
1818
"""basic exponential back-off test"""
1919

2020
connect_times = []
21+
raise_exception = None
2122

2223
# pylint: disable=unused-argument
2324
def fake_connect(self, arg):
2425
"""connect() replacement that records the call times and always raises OSError"""
2526
self.connect_times.append(time.monotonic())
26-
raise OSError("this connect failed")
27+
raise self.raise_exception
2728

2829
def test_failing_connect(self) -> None:
2930
"""test that exponential back-off is used when connect() always raises OSError"""
3031
# use RFC 1918 address to avoid dealing with IPv6 in the call list below
3132
host = "172.40.0.3"
3233
port = 1883
34+
self.connect_times = []
35+
error_code = MQTT.CONNACK_ERROR_SERVER_UNAVAILABLE
36+
self.raise_exception = MQTT.MMQTTException(
37+
MQTT.CONNACK_ERRORS[error_code], code=error_code
38+
)
3339

3440
with patch.object(socket.socket, "connect") as mock_method:
3541
mock_method.side_effect = self.fake_connect
@@ -45,6 +51,7 @@ def test_failing_connect(self) -> None:
4551
print("connecting")
4652
with pytest.raises(MQTT.MMQTTException) as context:
4753
mqtt_client.connect()
54+
assert mqtt_client._sock is None
4855
assert "Repeated connect failures" in str(context)
4956

5057
mock_method.assert_called()
@@ -54,3 +61,34 @@ def test_failing_connect(self) -> None:
5461
print(f"connect() call times: {self.connect_times}")
5562
for i in range(1, connect_retries):
5663
assert self.connect_times[i] >= 2**i
64+
65+
def test_unauthorized(self) -> None:
66+
"""test that exponential back-off is used when connect() always raises OSError"""
67+
# use RFC 1918 address to avoid dealing with IPv6 in the call list below
68+
host = "172.40.0.3"
69+
port = 1883
70+
self.connect_times = []
71+
error_code = MQTT.CONNACK_ERROR_UNAUTHORIZED
72+
self.raise_exception = MQTT.MMQTTException(
73+
MQTT.CONNACK_ERRORS[error_code], code=error_code
74+
)
75+
76+
with patch.object(socket.socket, "connect") as mock_method:
77+
mock_method.side_effect = self.fake_connect
78+
79+
connect_retries = 3
80+
mqtt_client = MQTT.MQTT(
81+
broker=host,
82+
port=port,
83+
socket_pool=socket,
84+
ssl_context=ssl.create_default_context(),
85+
connect_retries=connect_retries,
86+
)
87+
print("connecting")
88+
with pytest.raises(MQTT.MMQTTException) as context:
89+
mqtt_client.connect()
90+
assert mqtt_client._sock is None
91+
assert "Connection Refused - Unauthorized" in str(context)
92+
93+
mock_method.assert_called()
94+
assert len(self.connect_times) == 1

tests/test_port_ssl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,6 @@ def test_tls_without_ssl_context(self) -> None:
120120
connect_retries=1,
121121
)
122122

123-
with pytest.raises(AttributeError) as context:
123+
with pytest.raises(ValueError) as context:
124124
mqtt_client.connect()
125-
assert "ssl_context must be set" in str(context)
125+
assert "ssl_context must be provided if using ssl" in str(context)

tests/test_recv_timeout.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# SPDX-FileCopyrightText: 2024 Vladimír Kotal
2+
#
3+
# SPDX-License-Identifier: Unlicense
4+
5+
"""receive timeout tests"""
6+
7+
import socket
8+
import time
9+
from unittest import TestCase, main
10+
from unittest.mock import Mock
11+
12+
import adafruit_minimqtt.adafruit_minimqtt as MQTT
13+
14+
15+
class RecvTimeout(TestCase):
16+
"""This class contains tests for receive timeout handling."""
17+
18+
def test_recv_timeout_vs_keepalive(self) -> None:
19+
"""verify that receive timeout as used via ping() is different to keep alive timeout"""
20+
21+
for side_effect in [lambda ret_buf, buf_size: 0, socket.timeout]:
22+
with self.subTest():
23+
host = "127.0.0.1"
24+
25+
recv_timeout = 4
26+
keep_alive = recv_timeout * 2
27+
mqtt_client = MQTT.MQTT(
28+
broker=host,
29+
socket_pool=socket,
30+
connect_retries=1,
31+
socket_timeout=recv_timeout // 2,
32+
recv_timeout=recv_timeout,
33+
keep_alive=keep_alive,
34+
)
35+
36+
# Create a mock socket that will accept anything and return nothing.
37+
socket_mock = Mock()
38+
socket_mock.recv_into = Mock(side_effect=side_effect)
39+
# pylint: disable=protected-access
40+
mqtt_client._sock = socket_mock
41+
42+
mqtt_client._connected = lambda: True
43+
start = time.monotonic()
44+
with self.assertRaises(MQTT.MMQTTException):
45+
mqtt_client.ping()
46+
47+
# Verify the mock interactions.
48+
socket_mock.send.assert_called_once()
49+
socket_mock.recv_into.assert_called()
50+
51+
now = time.monotonic()
52+
assert recv_timeout <= (now - start) < keep_alive
53+
54+
55+
if __name__ == "__main__":
56+
main()

0 commit comments

Comments
 (0)