Skip to content

address _wait_for_msg() nits #132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 5, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,13 @@
# MQTT Commands
MQTT_PINGREQ = b"\xc0\0"
MQTT_PINGRESP = const(0xD0)
MQTT_PUBLISH = const(0x30)
MQTT_SUB = b"\x82"
MQTT_UNSUB = b"\xA2"
MQTT_DISCONNECT = b"\xe0\0"

MQTT_PKT_TYPE_MASK = const(0xF0)

# Variable CONNECT header [MQTT 3.1.2]
MQTT_HDR_CONNECT = bytearray(b"\x04MQTT\x04\x02\0\0")

Expand Down Expand Up @@ -210,7 +213,6 @@ def __init__(
# LWT
self._lw_topic = None
self._lw_qos = 0
self._lw_topic = None
self._lw_msg = None
self._lw_retain = False

Expand Down Expand Up @@ -630,7 +632,7 @@ def publish(self, topic, msg, retain=False, qos=0):
), "Quality of Service Level 2 is unsupported by this library."

# fixed header. [3.3.1.2], [3.3.1.3]
pub_hdr_fixed = bytearray([0x30 | retain | qos << 1])
pub_hdr_fixed = bytearray([MQTT_PUBLISH | retain | qos << 1])

# variable header = 2-byte Topic length (big endian)
pub_hdr_var = bytearray(struct.pack(">H", len(topic.encode("utf-8"))))
Expand Down Expand Up @@ -879,7 +881,9 @@ def loop(self, timeout=0):
def _wait_for_msg(self, timeout=0.1):
# pylint: disable = too-many-return-statements

"""Reads and processes network events."""
"""Reads and processes network events.
Return the packet type or None if there is nothing to be received.
"""
# CPython socket module contains a timeout attribute
if hasattr(self._socket_pool, "timeout"):
try:
Expand All @@ -900,7 +904,7 @@ def _wait_for_msg(self, timeout=0.1):
if res in [None, b"", b"\x00"]:
# If we get here, it means that there is nothing to be received
return None
if res[0] == MQTT_PINGRESP:
if res[0] & MQTT_PKT_TYPE_MASK == MQTT_PINGRESP:
if self.logger is not None:
self.logger.debug("Got PINGRESP")
sz = self._sock_exact_recv(1)[0]
Expand All @@ -909,12 +913,21 @@ def _wait_for_msg(self, timeout=0.1):
"Unexpected PINGRESP returned from broker: {}.".format(sz)
)
return MQTT_PINGRESP
if res[0] & 0xF0 != 0x30:

if res[0] & MQTT_PKT_TYPE_MASK != MQTT_PUBLISH:
return res[0]

# Handle only the PUBLISH packet type from now on.
sz = self._recv_len()
# topic length MSB & LSB
topic_len = self._sock_exact_recv(2)
topic_len = (topic_len[0] << 8) | topic_len[1]

if topic_len > sz - 2:
raise MMQTTException(
f"Topic length {topic_len} in PUBLISH packet exceeds remaining length {sz} - 2"
)

topic = self._sock_exact_recv(topic_len)
topic = str(topic, "utf-8")
sz -= topic_len + 2
Expand All @@ -923,12 +936,13 @@ def _wait_for_msg(self, timeout=0.1):
pid = self._sock_exact_recv(2)
pid = pid[0] << 0x08 | pid[1]
sz -= 0x02

# read message contents
raw_msg = self._sock_exact_recv(sz)
msg = raw_msg if self._use_binary_mode else str(raw_msg, "utf-8")
if self.logger is not None:
self.logger.debug(
"Receiving SUBSCRIBE \nTopic: %s\nMsg: %s\n", topic, raw_msg
"Receiving PUBLISH \nTopic: %s\nMsg: %s\n", topic, raw_msg
)
self._handle_on_message(self, topic, msg)
if res[0] & 0x06 == 0x02:
Expand All @@ -937,6 +951,7 @@ def _wait_for_msg(self, timeout=0.1):
self._sock.send(pkt)
elif res[0] & 6 == 4:
assert 0

return res[0]

def _recv_len(self):
Expand Down