diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py
index 9af31d9b..0635e1d8 100644
--- a/adafruit_minimqtt/adafruit_minimqtt.py
+++ b/adafruit_minimqtt/adafruit_minimqtt.py
@@ -77,6 +77,10 @@ class MMQTTException(Exception):
     # pass
 
 
+class TemporaryError(Exception):
+    """Temporary error class used for handling reconnects."""
+
+
 # Legacy ESP32SPI Socket API
 def set_socket(sock, iface=None):
     """Legacy API for setting the socket and network interface.
@@ -137,12 +141,13 @@ class MQTT:
     :param bool use_binary_mode: Messages are passed as bytearray instead of string to callbacks.
     :param int socket_timeout: How often to check socket state for read/write/connect operations,
         in seconds.
-    :param int connect_retries: How many times to try to connect to broker before giving up.
+    :param int connect_retries: How many times to try to connect to the broker before giving up
+        on connect or reconnect. Exponential backoff will be used for the retries.
     :param class user_data: arbitrary data to pass as a second argument to the callbacks.
 
     """
 
-    # pylint: disable=too-many-arguments,too-many-instance-attributes, not-callable, invalid-name, no-member
+    # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-statements, not-callable, invalid-name, no-member
     def __init__(
         self,
         *,
@@ -174,7 +179,6 @@ def __init__(
         )
         self._socket_timeout = socket_timeout
         self._recv_timeout = recv_timeout
-        self._connect_retries = connect_retries
 
         self.keep_alive = keep_alive
         self._user_data = user_data
@@ -184,6 +188,13 @@ def __init__(
         self._timestamp = 0
         self.logger = None
 
+        self._reconnect_attempt = 0
+        self._reconnect_timeout = float(0)
+        self._reconnect_maximum_backoff = 32
+        if connect_retries <= 0:
+            raise MMQTTException("connect_retries must be positive")
+        self._reconnect_attempts_max = connect_retries
+
         self.broker = broker
         self._username = username
         self._password = password
@@ -268,39 +279,37 @@ def _get_connect_socket(self, host, port, *, timeout=1):
             host, port, 0, self._socket_pool.SOCK_STREAM
         )[0]
 
-        sock = None
-        retry_count = 0
-        last_exception = None
-        while retry_count < self._connect_retries and sock is None:
-            retry_count += 1
+        try:
+            sock = self._socket_pool.socket(addr_info[0], addr_info[1])
+        except OSError as exc:
+            # Do not consider this for back-off.
+            if self.logger is not None:
+                self.logger.warning(
+                    f"Failed to create socket for host {addr_info[0]} and port {addr_info[1]}"
+                )
+            raise TemporaryError from exc
 
-            try:
-                sock = self._socket_pool.socket(addr_info[0], addr_info[1])
-            except OSError:
-                continue
+        connect_host = addr_info[-1][0]
+        if port == MQTT_TLS_PORT:
+            sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
+            connect_host = host
+        sock.settimeout(timeout)
 
-            connect_host = addr_info[-1][0]
-            if port == MQTT_TLS_PORT:
-                sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
-                connect_host = host
-            sock.settimeout(timeout)
+        last_exception = None
+        try:
+            sock.connect((connect_host, port))
+        except MemoryError as exc:
+            sock.close()
+            if self.logger is not None:
+                self.logger.warning(f"Failed to allocate memory for connect: {exc}")
+            # Do not consider this for back-off.
+            raise TemporaryError from exc
+        except OSError as exc:
+            sock.close()
+            last_exception = exc
 
-            try:
-                sock.connect((connect_host, port))
-            except MemoryError as exc:
-                sock.close()
-                sock = None
-                last_exception = exc
-            except OSError as exc:
-                sock.close()
-                sock = None
-                last_exception = exc
-
-        if sock is None:
-            if last_exception:
-                raise RuntimeError("Repeated socket failures") from last_exception
-
-            raise RuntimeError("Repeated socket failures")
+        if last_exception:
+            raise last_exception
 
         self._backwards_compatible_sock = not hasattr(sock, "recv_into")
         return sock
@@ -418,8 +427,66 @@ def username_pw_set(self, username, password=None):
         if password is not None:
             self._password = password
 
-    # pylint: disable=too-many-branches, too-many-statements, too-many-locals
     def connect(self, clean_session=True, host=None, port=None, keep_alive=None):
+        """Initiates connection with the MQTT Broker. Will perform exponential back-off
+        on connect failures.
+
+        :param bool clean_session: Establishes a persistent session.
+        :param str host: Hostname or IP address of the remote broker.
+        :param int port: Network port of the remote broker.
+        :param int keep_alive: Maximum period allowed for communication
+            within single connection attempt, in seconds.
+
+        """
+
+        last_exception = None
+        backoff = False
+        for i in range(0, self._reconnect_attempts_max):
+            if i > 0:
+                if backoff:
+                    self._recompute_reconnect_backoff()
+                else:
+                    self._reset_reconnect_backoff()
+            if self.logger is not None:
+                self.logger.debug(
+                    f"Attempting to connect to MQTT broker (attempt #{self._reconnect_attempt})"
+                )
+
+            try:
+                ret = self._connect(
+                    clean_session=clean_session,
+                    host=host,
+                    port=port,
+                    keep_alive=keep_alive,
+                )
+                self._reset_reconnect_backoff()
+                return ret
+            except TemporaryError as e:
+                if self.logger is not None:
+                    self.logger.warning(f"temporary error when connecting: {e}")
+                backoff = False
+            except OSError as e:
+                last_exception = e
+                if self.logger is not None:
+                    self.logger.info(f"failed to connect: {e}")
+                backoff = True
+            except MMQTTException as e:
+                last_exception = e
+                if self.logger is not None:
+                    self.logger.info(f"MMQT error: {e}")
+                backoff = True
+
+        if self._reconnect_attempts_max > 1:
+            exc_msg = "Repeated connect failures"
+        else:
+            exc_msg = "Connect failure"
+        if last_exception:
+            raise MMQTTException(exc_msg) from last_exception
+
+        raise MMQTTException(exc_msg)
+
+    # pylint: disable=too-many-branches, too-many-statements, too-many-locals
+    def _connect(self, clean_session=True, host=None, port=None, keep_alive=None):
         """Initiates connection with the MQTT Broker.
 
         :param bool clean_session: Establishes a persistent session.
@@ -438,6 +505,14 @@ def connect(self, clean_session=True, host=None, port=None, keep_alive=None):
         if self.logger is not None:
             self.logger.debug("Attempting to establish MQTT connection...")
 
+        if self._reconnect_attempt > 0:
+            # self.logger may still be None (its initial value); only the log call is guarded.
+            if self.logger is not None:
+                self.logger.debug(
+                    f"Sleeping for {self._reconnect_timeout:.3} seconds due to connect back-off"
+                )
+            time.sleep(self._reconnect_timeout)
+
         # Get a new socket
         self._sock = self._get_connect_socket(
             self.broker, self.port, timeout=self._socket_timeout
@@ -492,7 +567,7 @@ def connect(self, clean_session=True, host=None, port=None, keep_alive=None):
             fixed_header.append(0x00)
 
         if self.logger is not None:
-            self.logger.debug("Sending CONNECT to broker...")
+            self.logger.debug("Sending CONNECT packet to broker...")
             self.logger.debug(
                 "Fixed Header: %s\nVariable Header: %s", fixed_header, var_header
             )
@@ -521,6 +596,7 @@ def connect(self, clean_session=True, host=None, port=None, keep_alive=None):
                 result = rc[0] & 1
                 if self.on_connect is not None:
                     self.on_connect(self, self._user_data, result, rc[2])
+
                 return result
 
             if op is None:
@@ -782,15 +858,62 @@ def unsubscribe(self, topic):
                             f"No data received from broker for {self._recv_timeout} seconds."
                         )
 
+    def _recompute_reconnect_backoff(self):
+        """
+        Recompute the reconnection timeout. The self._reconnect_timeout will be used
+        in self._connect() to perform the actual sleep.
+
+        """
+        self._reconnect_attempt = self._reconnect_attempt + 1
+        self._reconnect_timeout = 2**self._reconnect_attempt
+        if self.logger is not None:
+            # pylint: disable=consider-using-f-string
+            self.logger.debug(
+                "Reconnect timeout computed to {:.2f}".format(self._reconnect_timeout)
+            )
+
+        if self._reconnect_timeout > self._reconnect_maximum_backoff:
+            if self.logger is not None:
+                self.logger.debug(
+                    f"Truncating reconnect timeout to {self._reconnect_maximum_backoff} seconds"
+                )
+            self._reconnect_timeout = float(self._reconnect_maximum_backoff)
+
+        # Add a sub-second jitter.
+        # Even truncated timeout should have jitter added to it. This is why it is added here.
+        jitter = randint(0, 1000) / 1000
+        if self.logger is not None:
+            # pylint: disable=consider-using-f-string
+            self.logger.debug(
+                "adding jitter {:.2f} to {:.2f} seconds".format(
+                    jitter, self._reconnect_timeout
+                )
+            )
+        self._reconnect_timeout += jitter
+
+    def _reset_reconnect_backoff(self):
+        """
+        Reset reconnect back-off to the initial state.
+
+        """
+        if self.logger is not None:
+            self.logger.debug("Resetting reconnect backoff")
+        self._reconnect_attempt = 0
+        self._reconnect_timeout = float(0)
+
     def reconnect(self, resub_topics=True):
         """Attempts to reconnect to the MQTT broker.
+        Return the value from connect() if successful. Will disconnect first if already connected.
+        Will perform exponential back-off on connect failures.
 
-        :param bool resub_topics: Resubscribe to previously subscribed topics.
+        :param bool resub_topics: Whether to resubscribe to previously subscribed topics.
 
         """
+
         if self.logger is not None:
             self.logger.debug("Attempting to reconnect with MQTT broker")
-        self.connect()
+
+        ret = self.connect()
         if self.logger is not None:
             self.logger.debug("Reconnected with broker")
         if resub_topics:
@@ -804,6 +927,8 @@ def reconnect(self, resub_topics=True):
                 feed = subscribed_topics.pop()
                 self.subscribe(feed)
 
+        return ret
+
     def loop(self, timeout=0):
         # pylint: disable = too-many-return-statements
         """Non-blocking message loop. Use this method to