Skip to content

Commit 20ba1e3

Browse files
committed
Merge branch 'main' into time_monotonic_ns
# Conflicts: # adafruit_minimqtt/adafruit_minimqtt.py
2 parents 690007e + 6270110 commit 20ba1e3

File tree

3 files changed

+141
-12
lines changed

3 files changed

+141
-12
lines changed

adafruit_minimqtt/adafruit_minimqtt.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,10 @@ class MQTT:
168168
in seconds.
169169
:param int connect_retries: How many times to try to connect to the broker before giving up
170170
on connect or reconnect. Exponential backoff will be used for the retries.
171-
:param class user_data: arbitrary data to pass as a second argument to the callbacks.
171+
:param class user_data: arbitrary data to pass as a second argument to most of the callbacks.
172+
This works with all callbacks but the "on_message" and those added via add_topic_callback();
173+
for those, to get access to the user_data use the 'user_data' member of the MQTT object
174+
passed as 1st argument.
172175
:param bool use_imprecise_time: on boards without time.monotonic_ns() one has to set
173176
this to True in order to operate correctly over more than 24 days or so
174177
@@ -222,7 +225,7 @@ def __init__(
222225
self._recv_timeout = recv_timeout
223226

224227
self.keep_alive = keep_alive
225-
self._user_data = user_data
228+
self.user_data = user_data
226229
self._is_connected = False
227230
self._msg_size_lim = MQTT_MSG_SZ_LIM
228231
self._pid = 0
@@ -440,6 +443,11 @@ def add_topic_callback(self, mqtt_topic: str, callback_method) -> None:
440443
441444
:param str mqtt_topic: MQTT topic identifier.
442445
:param function callback_method: The callback method.
446+
447+
Expected method signature is ``on_message(client, topic, message)``
448+
To get access to the user_data, use the client argument.
449+
450+
If a callback is called for the topic, then any "on_message" callback will not be called.
443451
"""
444452
if mqtt_topic is None or callback_method is None:
445453
raise ValueError("MQTT topic and callback method must both be defined.")
@@ -464,6 +472,7 @@ def on_message(self):
464472
"""Called when a new message has been received on a subscribed topic.
465473
466474
Expected method signature is ``on_message(client, topic, message)``
475+
To get access to the user_data, use the client argument.
467476
"""
468477
return self._on_message
469478

@@ -665,7 +674,7 @@ def _connect(
665674
self._is_connected = True
666675
result = rc[0] & 1
667676
if self.on_connect is not None:
668-
self.on_connect(self, self._user_data, result, rc[2])
677+
self.on_connect(self, self.user_data, result, rc[2])
669678

670679
return result
671680

@@ -688,7 +697,7 @@ def disconnect(self) -> None:
688697
self._is_connected = False
689698
self._subscribed_topics = []
690699
if self.on_disconnect is not None:
691-
self.on_disconnect(self, self._user_data, 0)
700+
self.on_disconnect(self, self.user_data, 0)
692701

693702
def ping(self) -> list[int]:
694703
"""Pings the MQTT Broker to confirm if the broker is alive or if
@@ -784,7 +793,7 @@ def publish(
784793
self._sock.send(pub_hdr_var)
785794
self._sock.send(msg)
786795
if qos == 0 and self.on_publish is not None:
787-
self.on_publish(self, self._user_data, topic, self._pid)
796+
self.on_publish(self, self.user_data, topic, self._pid)
788797
if qos == 1:
789798
stamp = self.get_monotonic_time()
790799
while True:
@@ -796,7 +805,7 @@ def publish(
796805
rcv_pid = rcv_pid_buf[0] << 0x08 | rcv_pid_buf[1]
797806
if self._pid == rcv_pid:
798807
if self.on_publish is not None:
799-
self.on_publish(self, self._user_data, topic, rcv_pid)
808+
self.on_publish(self, self.user_data, topic, rcv_pid)
800809
return
801810

802811
if op is None:
@@ -876,7 +885,7 @@ def subscribe(self, topic: str, qos: int = 0) -> None:
876885

877886
for t, q in topics:
878887
if self.on_subscribe is not None:
879-
self.on_subscribe(self, self._user_data, t, q)
888+
self.on_subscribe(self, self.user_data, t, q)
880889
self._subscribed_topics.append(t)
881890
return
882891

@@ -934,7 +943,7 @@ def unsubscribe(self, topic: str) -> None:
934943
assert rc[1] == packet_id_bytes[0] and rc[2] == packet_id_bytes[1]
935944
for t in topics:
936945
if self.on_unsubscribe is not None:
937-
self.on_unsubscribe(self, self._user_data, t, self._pid)
946+
self.on_unsubscribe(self, self.user_data, t, self._pid)
938947
self._subscribed_topics.remove(t)
939948
return
940949

@@ -1013,7 +1022,7 @@ def loop(self, timeout: float = 0) -> Optional[list[int]]:
10131022
:param float timeout: return after this timeout, in seconds.
10141023
10151024
"""
1016-
1025+
self._connected()
10171026
self.logger.debug(f"waiting for messages for {timeout} seconds")
10181027
if self._timestamp == 0:
10191028
self._timestamp = self.get_monotonic_time()

examples/cpython/user_data.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# SPDX-FileCopyrightText: 2023 Vladimír Kotal
2+
# SPDX-License-Identifier: Unlicense
3+
4+
# pylint: disable=logging-fstring-interpolation
5+
6+
"""
7+
Demonstrate on how to use user_data for various callbacks.
8+
"""
9+
10+
import logging
11+
import socket
12+
import ssl
13+
import sys
14+
15+
import adafruit_minimqtt.adafruit_minimqtt as MQTT
16+
17+
18+
# pylint: disable=unused-argument
19+
def on_connect(mqtt_client, user_data, flags, ret_code):
20+
"""
21+
connect callback
22+
"""
23+
logger = logging.getLogger(__name__)
24+
logger.debug("Connected to MQTT Broker!")
25+
logger.debug(f"Flags: {flags}\n RC: {ret_code}")
26+
27+
28+
# pylint: disable=unused-argument
29+
def on_subscribe(mqtt_client, user_data, topic, granted_qos):
30+
"""
31+
subscribe callback
32+
"""
33+
logger = logging.getLogger(__name__)
34+
logger.debug(f"Subscribed to {topic} with QOS level {granted_qos}")
35+
36+
37+
def on_message(client, topic, message):
38+
"""
39+
received message callback
40+
"""
41+
logger = logging.getLogger(__name__)
42+
logger.debug(f"New message on topic {topic}: {message}")
43+
44+
messages = client.user_data
45+
if not messages.get(topic):
46+
messages[topic] = []
47+
messages[topic].append(message)
48+
49+
50+
# pylint: disable=too-many-statements,too-many-locals
51+
def main():
52+
"""
53+
Main loop.
54+
"""
55+
56+
logging.basicConfig()
57+
logger = logging.getLogger(__name__)
58+
logger.setLevel(logging.DEBUG)
59+
60+
# dictionary/map of topic to list of messages
61+
messages = {}
62+
63+
# connect to MQTT broker
64+
mqtt = MQTT.MQTT(
65+
broker="172.40.0.3",
66+
port=1883,
67+
socket_pool=socket,
68+
ssl_context=ssl.create_default_context(),
69+
user_data=messages,
70+
)
71+
72+
mqtt.on_connect = on_connect
73+
mqtt.on_subscribe = on_subscribe
74+
mqtt.on_message = on_message
75+
76+
logger.info("Connecting to MQTT broker")
77+
mqtt.connect()
78+
logger.info("Subscribing")
79+
mqtt.subscribe("foo/#", qos=0)
80+
mqtt.add_topic_callback("foo/bar", on_message)
81+
82+
i = 0
83+
while True:
84+
i += 1
85+
logger.debug(f"Loop {i}")
86+
# Make sure to stay connected to the broker e.g. in case of keep alive.
87+
mqtt.loop(1)
88+
89+
for topic, msg_list in messages.items():
90+
logger.info(f"Got {len(msg_list)} messages from topic {topic}")
91+
for msg_cnt, msg in enumerate(msg_list):
92+
logger.debug(f"#{msg_cnt}: {msg}")
93+
94+
95+
if __name__ == "__main__":
96+
try:
97+
main()
98+
except KeyboardInterrupt:
99+
sys.exit(0)

tests/test_loop.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,21 @@ def test_loop_basic(self) -> None:
4444
ssl_context=ssl.create_default_context(),
4545
)
4646

47-
with patch.object(mqtt_client, "_wait_for_msg") as mock_method:
48-
mock_method.side_effect = self.fake_wait_for_msg
47+
with patch.object(
48+
mqtt_client, "_wait_for_msg"
49+
) as wait_for_msg_mock, patch.object(
50+
mqtt_client, "is_connected"
51+
) as is_connected_mock:
52+
wait_for_msg_mock.side_effect = self.fake_wait_for_msg
53+
is_connected_mock.side_effect = lambda: True
4954

5055
time_before = time.monotonic()
5156
timeout = random.randint(3, 8)
5257
rcs = mqtt_client.loop(timeout=timeout)
5358
time_after = time.monotonic()
5459

5560
assert time_after - time_before >= timeout
56-
mock_method.assert_called()
61+
wait_for_msg_mock.assert_called()
5762

5863
# Check the return value.
5964
assert rcs is not None
@@ -63,6 +68,22 @@ def test_loop_basic(self) -> None:
6368
assert ret_code == expected_rc
6469
expected_rc += 1
6570

71+
def test_loop_is_connected(self):
72+
"""
73+
loop() should throw MMQTTException if not connected
74+
"""
75+
mqtt_client = MQTT.MQTT(
76+
broker="127.0.0.1",
77+
port=1883,
78+
socket_pool=socket,
79+
ssl_context=ssl.create_default_context(),
80+
)
81+
82+
with self.assertRaises(MQTT.MMQTTException) as context:
83+
mqtt_client.loop(timeout=1)
84+
85+
assert "not connected" in str(context.exception)
86+
6687

6788
if __name__ == "__main__":
6889
main()

0 commit comments

Comments
 (0)