|
26 | 26 | * Adafruit CircuitPython firmware for the supported boards:
|
27 | 27 | https://github.com/adafruit/circuitpython/releases
|
28 | 28 |
|
| 29 | +* Adafruit's Connection Manager library: |
| 30 | + https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager |
| 31 | +
|
29 | 32 | """
|
30 | 33 | import errno
|
31 | 34 | import struct
|
32 | 35 | import time
|
33 | 36 | from random import randint
|
34 | 37 |
|
| 38 | +from adafruit_connection_manager import get_connection_manager |
| 39 | + |
35 | 40 | try:
|
36 | 41 | from typing import List, Optional, Tuple, Type, Union
|
37 | 42 | except ImportError:
|
|
82 | 87 | class MMQTTException(Exception):
|
83 | 88 | """MiniMQTT Exception class."""
|
84 | 89 |
|
85 |
| - # pylint: disable=unnecessary-pass |
86 |
| - # pass |
87 |
| - |
88 |
| - |
89 |
| -class TemporaryError(Exception): |
90 |
| - """Temporary error class used for handling reconnects.""" |
91 |
| - |
92 |
| - |
93 |
| -# Legacy ESP32SPI Socket API |
94 |
| -def set_socket(sock, iface=None) -> None: |
95 |
| - """Legacy API for setting the socket and network interface. |
96 |
| -
|
97 |
| - :param sock: socket object. |
98 |
| - :param iface: internet interface object |
99 |
| -
|
100 |
| - """ |
101 |
| - global _default_sock # pylint: disable=invalid-name, global-statement |
102 |
| - global _fake_context # pylint: disable=invalid-name, global-statement |
103 |
| - _default_sock = sock |
104 |
| - if iface: |
105 |
| - _default_sock.set_interface(iface) |
106 |
| - _fake_context = _FakeSSLContext(iface) |
107 |
| - |
108 |
| - |
109 |
| -class _FakeSSLSocket: |
110 |
| - def __init__(self, socket, tls_mode) -> None: |
111 |
| - self._socket = socket |
112 |
| - self._mode = tls_mode |
113 |
| - self.settimeout = socket.settimeout |
114 |
| - self.send = socket.send |
115 |
| - self.recv = socket.recv |
116 |
| - self.close = socket.close |
117 |
| - |
118 |
| - def connect(self, address): |
119 |
| - """connect wrapper to add non-standard mode parameter""" |
120 |
| - try: |
121 |
| - return self._socket.connect(address, self._mode) |
122 |
| - except RuntimeError as error: |
123 |
| - raise OSError(errno.ENOMEM) from error |
124 |
| - |
125 |
| - |
126 |
| -class _FakeSSLContext: |
127 |
| - def __init__(self, iface) -> None: |
128 |
| - self._iface = iface |
129 |
| - |
130 |
| - def wrap_socket(self, socket, server_hostname=None) -> _FakeSSLSocket: |
131 |
| - """Return the same socket""" |
132 |
| - # pylint: disable=unused-argument |
133 |
| - return _FakeSSLSocket(socket, self._iface.TLS_MODE) |
134 |
| - |
135 | 90 |
|
136 | 91 | class NullLogger:
|
137 | 92 | """Fake logger class that does not do anything"""
|
138 | 93 |
|
139 | 94 | # pylint: disable=unused-argument
|
140 | 95 | def nothing(self, msg: str, *args) -> None:
|
141 | 96 | """no action"""
|
142 |
| - pass |
143 | 97 |
|
144 | 98 | def __init__(self) -> None:
|
145 | 99 | for log_level in ["debug", "info", "warning", "error", "critical"]:
|
@@ -194,6 +148,7 @@ def __init__(
|
194 | 148 | user_data=None,
|
195 | 149 | use_imprecise_time: Optional[bool] = None,
|
196 | 150 | ) -> None:
|
| 151 | + self._connection_manager = get_connection_manager(socket_pool) |
197 | 152 | self._socket_pool = socket_pool
|
198 | 153 | self._ssl_context = ssl_context
|
199 | 154 | self._sock = None
|
@@ -300,75 +255,6 @@ def get_monotonic_time(self) -> float:
|
300 | 255 |
|
301 | 256 | return time.monotonic()
|
302 | 257 |
|
303 |
| - # pylint: disable=too-many-branches |
304 |
| - def _get_connect_socket(self, host: str, port: int, *, timeout: int = 1): |
305 |
| - """Obtains a new socket and connects to a broker. |
306 |
| -
|
307 |
| - :param str host: Desired broker hostname |
308 |
| - :param int port: Desired broker port |
309 |
| - :param int timeout: Desired socket timeout, in seconds |
310 |
| - """ |
311 |
| - # For reconnections - check if we're using a socket already and close it |
312 |
| - if self._sock: |
313 |
| - self._sock.close() |
314 |
| - self._sock = None |
315 |
| - |
316 |
| - # Legacy API - use the interface's socket instead of a passed socket pool |
317 |
| - if self._socket_pool is None: |
318 |
| - self._socket_pool = _default_sock |
319 |
| - |
320 |
| - # Legacy API - fake the ssl context |
321 |
| - if self._ssl_context is None: |
322 |
| - self._ssl_context = _fake_context |
323 |
| - |
324 |
| - if not isinstance(port, int): |
325 |
| - raise RuntimeError("Port must be an integer") |
326 |
| - |
327 |
| - if self._is_ssl and not self._ssl_context: |
328 |
| - raise RuntimeError( |
329 |
| - "ssl_context must be set before using adafruit_mqtt for secure MQTT." |
330 |
| - ) |
331 |
| - |
332 |
| - if self._is_ssl: |
333 |
| - self.logger.info(f"Establishing a SECURE SSL connection to {host}:{port}") |
334 |
| - else: |
335 |
| - self.logger.info(f"Establishing an INSECURE connection to {host}:{port}") |
336 |
| - |
337 |
| - addr_info = self._socket_pool.getaddrinfo( |
338 |
| - host, port, 0, self._socket_pool.SOCK_STREAM |
339 |
| - )[0] |
340 |
| - |
341 |
| - try: |
342 |
| - sock = self._socket_pool.socket(addr_info[0], addr_info[1]) |
343 |
| - except OSError as exc: |
344 |
| - # Do not consider this for back-off. |
345 |
| - self.logger.warning( |
346 |
| - f"Failed to create socket for host {addr_info[0]} and port {addr_info[1]}" |
347 |
| - ) |
348 |
| - raise TemporaryError from exc |
349 |
| - |
350 |
| - connect_host = addr_info[-1][0] |
351 |
| - if self._is_ssl: |
352 |
| - sock = self._ssl_context.wrap_socket(sock, server_hostname=host) |
353 |
| - connect_host = host |
354 |
| - sock.settimeout(timeout) |
355 |
| - |
356 |
| - try: |
357 |
| - sock.connect((connect_host, port)) |
358 |
| - except MemoryError as exc: |
359 |
| - sock.close() |
360 |
| - self.logger.warning(f"Failed to allocate memory for connect: {exc}") |
361 |
| - # Do not consider this for back-off. |
362 |
| - raise TemporaryError from exc |
363 |
| - except OSError as exc: |
364 |
| - sock.close() |
365 |
| - self.logger.warning(f"Failed to connect: {exc}") |
366 |
| - # Do not consider this for back-off. |
367 |
| - raise TemporaryError from exc |
368 |
| - |
369 |
| - self._backwards_compatible_sock = not hasattr(sock, "recv_into") |
370 |
| - return sock |
371 |
| - |
372 | 258 | def __enter__(self):
|
373 | 259 | return self
|
374 | 260 |
|
@@ -538,8 +424,8 @@ def connect(
|
538 | 424 | )
|
539 | 425 | self._reset_reconnect_backoff()
|
540 | 426 | return ret
|
541 |
| - except TemporaryError as e: |
542 |
| - self.logger.warning(f"temporary error when connecting: {e}") |
| 427 | + except RuntimeError as e: |
| 428 | + self.logger.warning(f"Socket error when connecting: {e}") |
543 | 429 | backoff = False
|
544 | 430 | except MMQTTException as e:
|
545 | 431 | last_exception = e
|
@@ -587,9 +473,15 @@ def _connect(
|
587 | 473 | time.sleep(self._reconnect_timeout)
|
588 | 474 |
|
589 | 475 | # Get a new socket
|
590 |
| - self._sock = self._get_connect_socket( |
591 |
| - self.broker, self.port, timeout=self._socket_timeout |
| 476 | + self._sock = self._connection_manager.get_socket( |
| 477 | + self.broker, |
| 478 | + self.port, |
| 479 | + proto="mqtt:", |
| 480 | + timeout=self._socket_timeout, |
| 481 | + is_ssl=self._is_ssl, |
| 482 | + ssl_context=self._ssl_context, |
592 | 483 | )
|
| 484 | + self._backwards_compatible_sock = not hasattr(self._sock, "recv_into") |
593 | 485 |
|
594 | 486 | fixed_header = bytearray([0x10])
|
595 | 487 |
|
@@ -686,7 +578,7 @@ def disconnect(self) -> None:
|
686 | 578 | except RuntimeError as e:
|
687 | 579 | self.logger.warning(f"Unable to send DISCONNECT packet: {e}")
|
688 | 580 | self.logger.debug("Closing socket")
|
689 |
| - self._sock.close() |
| 581 | + self._connection_manager.free_socket(self._sock) |
690 | 582 | self._is_connected = False
|
691 | 583 | self._subscribed_topics = []
|
692 | 584 | self._last_msg_sent_timestamp = 0
|
|
0 commit comments