|
| 1 | +# SPDX-FileCopyrightText: 2017 Scott Shawcroft, written for Adafruit Industries |
| 2 | +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries |
| 3 | +# |
| 4 | +# SPDX-License-Identifier: MIT |
| 5 | +""" |
| 6 | +`adafruit_connection_manager` |
| 7 | +================================================================================ |
| 8 | +
|
| 9 | +A urllib3.poolmanager/urllib3.connectionpool-like library for managing sockets and connections |
| 10 | +
|
| 11 | +
|
| 12 | +* Author(s): Justin Myers |
| 13 | +
|
| 14 | +Implementation Notes |
| 15 | +-------------------- |
| 16 | +
|
| 17 | +**Software and Dependencies:** |
| 18 | +
|
| 19 | +* Adafruit CircuitPython firmware for the supported boards: |
| 20 | + https://circuitpython.org/downloads |
| 21 | +
|
| 22 | +""" |
| 23 | + |
| 24 | +# imports |
| 25 | + |
| 26 | +__version__ = "0.0.0+auto.0" |
| 27 | +__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager.git" |
| 28 | + |
| 29 | +import errno |
| 30 | +import sys |
| 31 | + |
| 32 | +# typing |
| 33 | + |
| 34 | + |
| 35 | +if not sys.implementation.name == "circuitpython": |
| 36 | + from typing import Optional, Tuple |
| 37 | + |
| 38 | + from circuitpython_typing.socket import ( |
| 39 | + CircuitPythonSocketType, |
| 40 | + InterfaceType, |
| 41 | + SocketpoolModuleType, |
| 42 | + SocketType, |
| 43 | + SSLContextType, |
| 44 | + ) |
| 45 | + |
| 46 | + |
| 47 | +# ssl and pool helpers |
| 48 | + |
| 49 | + |
| 50 | +class _FakeSSLSocket: |
| 51 | + def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None: |
| 52 | + self._socket = socket |
| 53 | + self._mode = tls_mode |
| 54 | + self.settimeout = socket.settimeout |
| 55 | + self.send = socket.send |
| 56 | + self.recv = socket.recv |
| 57 | + self.close = socket.close |
| 58 | + self.recv_into = socket.recv_into |
| 59 | + |
| 60 | + def connect(self, address: Tuple[str, int]) -> None: |
| 61 | + """Connect wrapper to add non-standard mode parameter""" |
| 62 | + try: |
| 63 | + return self._socket.connect(address, self._mode) |
| 64 | + except RuntimeError as error: |
| 65 | + raise OSError(errno.ENOMEM) from error |
| 66 | + |
| 67 | + |
| 68 | +class _FakeSSLContext: |
| 69 | + def __init__(self, iface: InterfaceType) -> None: |
| 70 | + self._iface = iface |
| 71 | + |
| 72 | + # pylint: disable=unused-argument |
| 73 | + def wrap_socket( |
| 74 | + self, socket: CircuitPythonSocketType, server_hostname: Optional[str] = None |
| 75 | + ) -> _FakeSSLSocket: |
| 76 | + """Return the same socket""" |
| 77 | + if hasattr(self._iface, "TLS_MODE"): |
| 78 | + return _FakeSSLSocket(socket, self._iface.TLS_MODE) |
| 79 | + |
| 80 | + raise AttributeError("This radio does not support TLS/HTTPS") |
| 81 | + |
| 82 | + |
| 83 | +def create_fake_ssl_context( |
| 84 | + socket_pool: SocketpoolModuleType, iface: InterfaceType |
| 85 | +) -> _FakeSSLContext: |
| 86 | + """Method to return a fake SSL context for when ssl isn't available to import |
| 87 | +
|
| 88 | + For example when using a: |
| 89 | +
|
| 90 | + * `Adafruit Ethernet FeatherWing <https://www.adafruit.com/product/3201>`_ |
| 91 | + * `Adafruit AirLift – ESP32 WiFi Co-Processor Breakout Board |
| 92 | + <https://www.adafruit.com/product/4201>`_ |
| 93 | + * `Adafruit AirLift FeatherWing – ESP32 WiFi Co-Processor |
| 94 | + <https://www.adafruit.com/product/4264>`_ |
| 95 | + """ |
| 96 | + socket_pool.set_interface(iface) |
| 97 | + return _FakeSSLContext(iface) |
| 98 | + |
| 99 | + |
| 100 | +_global_socketpool = {} |
| 101 | +_global_ssl_contexts = {} |
| 102 | + |
| 103 | + |
| 104 | +def get_radio_socketpool(radio): |
| 105 | + """Helper to get a socket pool for common boards |
| 106 | +
|
| 107 | + Currently supported: |
| 108 | +
|
| 109 | + * Boards with onboard WiFi (ESP32S2, ESP32S3, Pico W, etc) |
| 110 | + * Using the ESP32 WiFi Co-Processor (like the Adafruit AirLift) |
| 111 | + * Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing) |
| 112 | + """ |
| 113 | + class_name = radio.__class__.__name__ |
| 114 | + if class_name not in _global_socketpool: |
| 115 | + if class_name == "Radio": |
| 116 | + import ssl # pylint: disable=import-outside-toplevel |
| 117 | + |
| 118 | + import socketpool # pylint: disable=import-outside-toplevel |
| 119 | + |
| 120 | + pool = socketpool.SocketPool(radio) |
| 121 | + ssl_context = ssl.create_default_context() |
| 122 | + |
| 123 | + elif class_name == "ESP_SPIcontrol": |
| 124 | + import adafruit_esp32spi.adafruit_esp32spi_socket as pool # pylint: disable=import-outside-toplevel |
| 125 | + |
| 126 | + ssl_context = create_fake_ssl_context(pool, radio) |
| 127 | + |
| 128 | + elif class_name == "WIZNET5K": |
| 129 | + import adafruit_wiznet5k.adafruit_wiznet5k_socket as pool # pylint: disable=import-outside-toplevel |
| 130 | + |
| 131 | + # Note: SSL/TLS connections are not supported by the Wiznet5k library at this time |
| 132 | + ssl_context = create_fake_ssl_context(pool, radio) |
| 133 | + |
| 134 | + else: |
| 135 | + raise AttributeError(f"Unsupported radio class: {class_name}") |
| 136 | + |
| 137 | + _global_socketpool[class_name] = pool |
| 138 | + _global_ssl_contexts[class_name] = ssl_context |
| 139 | + |
| 140 | + return _global_socketpool[class_name] |
| 141 | + |
| 142 | + |
| 143 | +def get_radio_ssl_context(radio): |
| 144 | + """Helper to get ssl_contexts for common boards |
| 145 | +
|
| 146 | + Currently supported: |
| 147 | +
|
| 148 | + * Boards with onboard WiFi (ESP32S2, ESP32S3, Pico W, etc) |
| 149 | + * Using the ESP32 WiFi Co-Processor (like the Adafruit AirLift) |
| 150 | + * Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing) |
| 151 | + """ |
| 152 | + class_name = radio.__class__.__name__ |
| 153 | + get_radio_socketpool(radio) |
| 154 | + return _global_ssl_contexts[class_name] |
| 155 | + |
| 156 | + |
| 157 | +# main class |
| 158 | + |
| 159 | + |
| 160 | +class ConnectionManager: |
| 161 | + """Connection manager for sharing open sockets (aka connections).""" |
| 162 | + |
| 163 | + def __init__( |
| 164 | + self, |
| 165 | + socket_pool: SocketpoolModuleType, |
| 166 | + ) -> None: |
| 167 | + self._socket_pool = socket_pool |
| 168 | + # Hang onto open sockets so that we can reuse them. |
| 169 | + self._available_socket = {} |
| 170 | + self._open_sockets = {} |
| 171 | + |
| 172 | + def _free_sockets(self) -> None: |
| 173 | + available_sockets = [] |
| 174 | + for socket, free in self._available_socket.items(): |
| 175 | + if free: |
| 176 | + available_sockets.append(socket) |
| 177 | + |
| 178 | + for socket in available_sockets: |
| 179 | + self.close_socket(socket) |
| 180 | + |
| 181 | + def _get_key_for_socket(self, socket): |
| 182 | + try: |
| 183 | + return next( |
| 184 | + key for key, value in self._open_sockets.items() if value == socket |
| 185 | + ) |
| 186 | + except StopIteration: |
| 187 | + return None |
| 188 | + |
| 189 | + def close_socket(self, socket: SocketType) -> None: |
| 190 | + """Close a previously opened socket.""" |
| 191 | + if socket not in self._open_sockets.values(): |
| 192 | + raise RuntimeError("Socket not managed") |
| 193 | + key = self._get_key_for_socket(socket) |
| 194 | + socket.close() |
| 195 | + del self._available_socket[socket] |
| 196 | + del self._open_sockets[key] |
| 197 | + |
| 198 | + def free_socket(self, socket: SocketType) -> None: |
| 199 | + """Mark a previously opened socket as available so it can be reused if needed.""" |
| 200 | + if socket not in self._open_sockets.values(): |
| 201 | + raise RuntimeError("Socket not managed") |
| 202 | + self._available_socket[socket] = True |
| 203 | + |
| 204 | + # pylint: disable=too-many-branches,too-many-locals,too-many-statements |
| 205 | + def get_socket( |
| 206 | + self, |
| 207 | + host: str, |
| 208 | + port: int, |
| 209 | + proto: str, |
| 210 | + session_id: Optional[str] = None, |
| 211 | + *, |
| 212 | + timeout: float = 1, |
| 213 | + is_ssl: bool = False, |
| 214 | + ssl_context: Optional[SSLContextType] = None, |
| 215 | + ) -> CircuitPythonSocketType: |
| 216 | + """Get a new socket and connect""" |
| 217 | + if session_id: |
| 218 | + session_id = str(session_id) |
| 219 | + key = (host, port, proto, session_id) |
| 220 | + if key in self._open_sockets: |
| 221 | + socket = self._open_sockets[key] |
| 222 | + if self._available_socket[socket]: |
| 223 | + self._available_socket[socket] = False |
| 224 | + return socket |
| 225 | + |
| 226 | + raise RuntimeError(f"Socket already connected to {proto}//{host}:{port}") |
| 227 | + |
| 228 | + if proto == "https:": |
| 229 | + is_ssl = True |
| 230 | + if is_ssl and not ssl_context: |
| 231 | + raise AttributeError( |
| 232 | + "ssl_context must be set before using adafruit_requests for https" |
| 233 | + ) |
| 234 | + |
| 235 | + addr_info = self._socket_pool.getaddrinfo( |
| 236 | + host, port, 0, self._socket_pool.SOCK_STREAM |
| 237 | + )[0] |
| 238 | + |
| 239 | + try_count = 0 |
| 240 | + socket = None |
| 241 | + last_exc = None |
| 242 | + while try_count < 2 and socket is None: |
| 243 | + try_count += 1 |
| 244 | + if try_count > 1: |
| 245 | + if any( |
| 246 | + socket |
| 247 | + for socket, free in self._available_socket.items() |
| 248 | + if free is True |
| 249 | + ): |
| 250 | + self._free_sockets() |
| 251 | + else: |
| 252 | + break |
| 253 | + |
| 254 | + try: |
| 255 | + socket = self._socket_pool.socket(addr_info[0], addr_info[1]) |
| 256 | + except OSError as exc: |
| 257 | + last_exc = exc |
| 258 | + continue |
| 259 | + except RuntimeError as exc: |
| 260 | + last_exc = exc |
| 261 | + continue |
| 262 | + |
| 263 | + if is_ssl: |
| 264 | + socket = ssl_context.wrap_socket(socket, server_hostname=host) |
| 265 | + connect_host = host |
| 266 | + else: |
| 267 | + connect_host = addr_info[-1][0] |
| 268 | + socket.settimeout(timeout) # socket read timeout |
| 269 | + |
| 270 | + try: |
| 271 | + socket.connect((connect_host, port)) |
| 272 | + except MemoryError as exc: |
| 273 | + last_exc = exc |
| 274 | + socket.close() |
| 275 | + socket = None |
| 276 | + except OSError as exc: |
| 277 | + last_exc = exc |
| 278 | + socket.close() |
| 279 | + socket = None |
| 280 | + |
| 281 | + if socket is None: |
| 282 | + raise RuntimeError(f"Error connecting socket: {last_exc}") from last_exc |
| 283 | + |
| 284 | + self._available_socket[socket] = False |
| 285 | + self._open_sockets[key] = socket |
| 286 | + return socket |
| 287 | + |
| 288 | + |
| 289 | +# global helpers |
| 290 | + |
| 291 | + |
| 292 | +_global_connection_manager = None # pylint: disable=invalid-name |
| 293 | + |
| 294 | + |
| 295 | +def get_connection_manager(socket_pool: SocketpoolModuleType) -> None: |
| 296 | + """Get the ConnectionManager singleton""" |
| 297 | + global _global_connection_manager # pylint: disable=global-statement |
| 298 | + if _global_connection_manager is None: |
| 299 | + _global_connection_manager = ConnectionManager(socket_pool) |
| 300 | + return _global_connection_manager |
0 commit comments