Skip to content

exponential backoff for (re)connect #151

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 160 additions & 37 deletions adafruit_minimqtt/adafruit_minimqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ class MMQTTException(Exception):
# pass


class TemporaryError(Exception):
"""Temporary error class used for handling reconnects."""


# Legacy ESP32SPI Socket API
def set_socket(sock, iface=None):
"""Legacy API for setting the socket and network interface.
Expand Down Expand Up @@ -137,12 +141,13 @@ class MQTT:
:param bool use_binary_mode: Messages are passed as bytearray instead of string to callbacks.
:param int socket_timeout: How often to check socket state for read/write/connect operations,
in seconds.
:param int connect_retries: How many times to try to connect to broker before giving up.
:param int connect_retries: How many times to try to connect to the broker before giving up
on connect or reconnect. Exponential backoff will be used for the retries.
:param class user_data: arbitrary data to pass as a second argument to the callbacks.

"""

# pylint: disable=too-many-arguments,too-many-instance-attributes, not-callable, invalid-name, no-member
# pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-statements, not-callable, invalid-name, no-member
def __init__(
self,
*,
Expand Down Expand Up @@ -174,7 +179,6 @@ def __init__(
)
self._socket_timeout = socket_timeout
self._recv_timeout = recv_timeout
self._connect_retries = connect_retries

self.keep_alive = keep_alive
self._user_data = user_data
Expand All @@ -184,6 +188,13 @@ def __init__(
self._timestamp = 0
self.logger = None

self._reconnect_attempt = 0
self._reconnect_timeout = float(0)
self._reconnect_maximum_backoff = 32
if connect_retries <= 0:
raise MMQTTException("connect_retries must be positive")
self._reconnect_attempts_max = connect_retries

self.broker = broker
self._username = username
self._password = password
Expand Down Expand Up @@ -268,39 +279,37 @@ def _get_connect_socket(self, host, port, *, timeout=1):
host, port, 0, self._socket_pool.SOCK_STREAM
)[0]

sock = None
retry_count = 0
last_exception = None
while retry_count < self._connect_retries and sock is None:
retry_count += 1
try:
sock = self._socket_pool.socket(addr_info[0], addr_info[1])
except OSError as exc:
# Do not consider this for back-off.
if self.logger is not None:
self.logger.warning(
f"Failed to create socket for host {addr_info[0]} and port {addr_info[1]}"
)
raise TemporaryError from exc

try:
sock = self._socket_pool.socket(addr_info[0], addr_info[1])
except OSError:
continue
connect_host = addr_info[-1][0]
if port == MQTT_TLS_PORT:
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
connect_host = host
sock.settimeout(timeout)

connect_host = addr_info[-1][0]
if port == MQTT_TLS_PORT:
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
connect_host = host
sock.settimeout(timeout)
last_exception = None
try:
sock.connect((connect_host, port))
except MemoryError as exc:
sock.close()
if self.logger is not None:
self.logger.warning(f"Failed to allocate memory for connect: {exc}")
# Do not consider this for back-off.
raise TemporaryError from exc
except OSError as exc:
sock.close()
last_exception = exc

try:
sock.connect((connect_host, port))
except MemoryError as exc:
sock.close()
sock = None
last_exception = exc
except OSError as exc:
sock.close()
sock = None
last_exception = exc

if sock is None:
if last_exception:
raise RuntimeError("Repeated socket failures") from last_exception

raise RuntimeError("Repeated socket failures")
if last_exception:
raise last_exception

self._backwards_compatible_sock = not hasattr(sock, "recv_into")
return sock
Expand Down Expand Up @@ -418,8 +427,66 @@ def username_pw_set(self, username, password=None):
if password is not None:
self._password = password

# pylint: disable=too-many-branches, too-many-statements, too-many-locals
def connect(self, clean_session=True, host=None, port=None, keep_alive=None):
"""Initiates connection with the MQTT Broker. Will perform exponential back-off
on connect failures.

:param bool clean_session: Establishes a persistent session.
:param str host: Hostname or IP address of the remote broker.
:param int port: Network port of the remote broker.
:param int keep_alive: Maximum period allowed for communication
within single connection attempt, in seconds.

"""

last_exception = None
backoff = False
for i in range(0, self._reconnect_attempts_max):
if i > 0:
if backoff:
self._recompute_reconnect_backoff()
else:
self._reset_reconnect_backoff()
if self.logger is not None:
self.logger.debug(
f"Attempting to connect to MQTT broker (attempt #{self._reconnect_attempt})"
)

try:
ret = self._connect(
clean_session=clean_session,
host=host,
port=port,
keep_alive=keep_alive,
)
self._reset_reconnect_backoff()
return ret
except TemporaryError as e:
if self.logger is not None:
self.logger.warning(f"temporary error when connecting: {e}")
backoff = False
except OSError as e:
last_exception = e
if self.logger is not None:
self.logger.info(f"failed to connect: {e}")
backoff = True
except MMQTTException as e:
last_exception = e
if self.logger is not None:
self.logger.info(f"MMQT error: {e}")
backoff = True

if self._reconnect_attempts_max > 1:
exc_msg = "Repeated connect failures"
else:
exc_msg = "Connect failure"
if last_exception:
raise MMQTTException(exc_msg) from last_exception

raise MMQTTException(exc_msg)

# pylint: disable=too-many-branches, too-many-statements, too-many-locals
def _connect(self, clean_session=True, host=None, port=None, keep_alive=None):
"""Initiates connection with the MQTT Broker.

:param bool clean_session: Establishes a persistent session.
Expand All @@ -438,6 +505,12 @@ def connect(self, clean_session=True, host=None, port=None, keep_alive=None):
if self.logger is not None:
self.logger.debug("Attempting to establish MQTT connection...")

if self._reconnect_attempt > 0:
self.logger.debug(
f"Sleeping for {self._reconnect_timeout:.3} seconds due to connect back-off"
)
time.sleep(self._reconnect_timeout)

# Get a new socket
self._sock = self._get_connect_socket(
self.broker, self.port, timeout=self._socket_timeout
Expand Down Expand Up @@ -492,7 +565,7 @@ def connect(self, clean_session=True, host=None, port=None, keep_alive=None):
fixed_header.append(0x00)

if self.logger is not None:
self.logger.debug("Sending CONNECT to broker...")
self.logger.debug("Sending CONNECT packet to broker...")
self.logger.debug(
"Fixed Header: %s\nVariable Header: %s", fixed_header, var_header
)
Expand Down Expand Up @@ -521,6 +594,7 @@ def connect(self, clean_session=True, host=None, port=None, keep_alive=None):
result = rc[0] & 1
if self.on_connect is not None:
self.on_connect(self, self._user_data, result, rc[2])

return result

if op is None:
Expand Down Expand Up @@ -782,15 +856,62 @@ def unsubscribe(self, topic):
f"No data received from broker for {self._recv_timeout} seconds."
)

def _recompute_reconnect_backoff(self):
"""
Recompute the reconnection timeout. The self._reconnect_timeout will be used
in self._connect() to perform the actual sleep.

"""
self._reconnect_attempt = self._reconnect_attempt + 1
self._reconnect_timeout = 2**self._reconnect_attempt
if self.logger is not None:
# pylint: disable=consider-using-f-string
self.logger.debug(
"Reconnect timeout computed to {:.2f}".format(self._reconnect_timeout)
)

if self._reconnect_timeout > self._reconnect_maximum_backoff:
if self.logger is not None:
self.logger.debug(
f"Truncating reconnect timeout to {self._reconnect_maximum_backoff} seconds"
)
self._reconnect_timeout = float(self._reconnect_maximum_backoff)

# Add a sub-second jitter.
# Even truncated timeout should have jitter added to it. This is why it is added here.
jitter = randint(0, 1000) / 1000
if self.logger is not None:
# pylint: disable=consider-using-f-string
self.logger.debug(
"adding jitter {:.2f} to {:.2f} seconds".format(
jitter, self._reconnect_timeout
)
)
self._reconnect_timeout += jitter

def _reset_reconnect_backoff(self):
"""
Reset reconnect back-off to the initial state.

"""
if self.logger is not None:
self.logger.debug("Resetting reconnect backoff")
self._reconnect_attempt = 0
self._reconnect_timeout = float(0)

def reconnect(self, resub_topics=True):
"""Attempts to reconnect to the MQTT broker.
Return the value from connect() if successful. Will disconnect first if already connected.
Will perform exponential back-off on connect failures.

:param bool resub_topics: Resubscribe to previously subscribed topics.
:param bool resub_topics: Whether to resubscribe to previously subscribed topics.

"""

if self.logger is not None:
self.logger.debug("Attempting to reconnect with MQTT broker")
self.connect()

ret = self.connect()
if self.logger is not None:
self.logger.debug("Reconnected with broker")
if resub_topics:
Expand All @@ -804,6 +925,8 @@ def reconnect(self, resub_topics=True):
feed = subscribed_topics.pop()
self.subscribe(feed)

return ret

def loop(self, timeout=0):
# pylint: disable = too-many-return-statements
"""Non-blocking message loop. Use this method to
Expand Down