Skip to content

Commit 0a4f745

Browse files
authored
Merge pull request #16 from dhalbert/socket-retry
Re-try in more cases when socket cannot first be created
2 parents 2c79732 + da3cd5f commit 0a4f745

6 files changed

+68
-86
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ repos:
2323
- id: end-of-file-fixer
2424
- id: trailing-whitespace
2525
- repo: https://github.com/pycqa/pylint
26-
rev: v2.17.4
26+
rev: v3.1.0
2727
hooks:
2828
- id: pylint
2929
name: pylint (library code)

adafruit_connection_manager.py

+55-67
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
2222
"""
2323

24-
# imports
25-
2624
__version__ = "0.0.0+auto.0"
2725
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager.git"
2826

@@ -31,9 +29,6 @@
3129

3230
WIZNET5K_SSL_SUPPORT_VERSION = (9, 1)
3331

34-
# typing
35-
36-
3732
if not sys.implementation.name == "circuitpython":
3833
from typing import List, Optional, Tuple
3934

@@ -46,9 +41,6 @@
4641
)
4742

4843

49-
# ssl and pool helpers
50-
51-
5244
class _FakeSSLSocket:
5345
def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None:
5446
self._socket = socket
@@ -82,7 +74,7 @@ def wrap_socket( # pylint: disable=unused-argument
8274
if hasattr(self._iface, "TLS_MODE"):
8375
return _FakeSSLSocket(socket, self._iface.TLS_MODE)
8476

85-
raise AttributeError("This radio does not support TLS/HTTPS")
77+
raise ValueError("This radio does not support TLS/HTTPS")
8678

8779

8880
def create_fake_ssl_context(
@@ -167,7 +159,7 @@ def get_radio_socketpool(radio):
167159
ssl_context = create_fake_ssl_context(pool, radio)
168160

169161
else:
170-
raise AttributeError(f"Unsupported radio class: {class_name}")
162+
raise ValueError(f"Unsupported radio class: {class_name}")
171163

172164
_global_key_by_socketpool[pool] = key
173165
_global_socketpools[key] = pool
@@ -189,11 +181,8 @@ def get_radio_ssl_context(radio):
189181
return _global_ssl_contexts[_get_radio_hash_key(radio)]
190182

191183

192-
# main class
193-
194-
195184
class ConnectionManager:
196-
"""A library for managing sockets accross libraries."""
185+
"""A library for managing sockets across multiple hardware platforms and libraries."""
197186

198187
def __init__(
199188
self,
@@ -215,6 +204,11 @@ def _free_sockets(self, force: bool = False) -> None:
215204
for socket in open_sockets:
216205
self.close_socket(socket)
217206

207+
def _register_connected_socket(self, key, socket):
208+
"""Register a socket as managed."""
209+
self._key_by_managed_socket[socket] = key
210+
self._managed_socket_by_key[key] = socket
211+
218212
def _get_connected_socket( # pylint: disable=too-many-arguments
219213
self,
220214
addr_info: List[Tuple[int, int, int, str, Tuple[str, int]]],
@@ -224,23 +218,24 @@ def _get_connected_socket( # pylint: disable=too-many-arguments
224218
is_ssl: bool,
225219
ssl_context: Optional[SSLContextType] = None,
226220
):
227-
try:
228-
socket = self._socket_pool.socket(addr_info[0], addr_info[1])
229-
except (OSError, RuntimeError) as exc:
230-
return exc
221+
222+
socket = self._socket_pool.socket(addr_info[0], addr_info[1])
231223

232224
if is_ssl:
233225
socket = ssl_context.wrap_socket(socket, server_hostname=host)
234226
connect_host = host
235227
else:
236228
connect_host = addr_info[-1][0]
237-
socket.settimeout(timeout) # socket read timeout
229+
230+
# Set socket read and connect timeout.
231+
socket.settimeout(timeout)
238232

239233
try:
240234
socket.connect((connect_host, port))
241-
except (MemoryError, OSError) as exc:
235+
except (MemoryError, OSError):
236+
# If any connect problems, clean up and re-raise the problem exception.
242237
socket.close()
243-
return exc
238+
raise
244239

245240
return socket
246241

@@ -269,82 +264,78 @@ def close_socket(self, socket: SocketType) -> None:
269264
self._available_sockets.remove(socket)
270265

271266
def free_socket(self, socket: SocketType) -> None:
272-
"""Mark a managed socket as available so it can be reused."""
267+
"""Mark a managed socket as available so it can be reused. The socket is not closed."""
273268
if socket not in self._managed_socket_by_key.values():
274269
raise RuntimeError("Socket not managed")
275270
self._available_sockets.add(socket)
276271

272+
# pylint: disable=too-many-arguments
277273
def get_socket(
278274
self,
279275
host: str,
280276
port: int,
281277
proto: str,
282278
session_id: Optional[str] = None,
283279
*,
284-
timeout: float = 1,
280+
timeout: float = 1.0,
285281
is_ssl: bool = False,
286282
ssl_context: Optional[SSLContextType] = None,
287283
) -> CircuitPythonSocketType:
288284
"""
289-
Get a new socket and connect.
290-
291-
- **host** *(str)* – The host you are want to connect to: "www.adaftuit.com"
292-
- **port** *(int)* – The port you want to connect to: 80
293-
- **proto** *(str)* – The protocal you want to use: "http:"
294-
- **session_id** *(Optional[str])* – A unique Session ID, when wanting to have multiple open
295-
connections to the same host
296-
- **timeout** *(float)* – Time timeout used for connecting
297-
- **is_ssl** *(bool)* – If the connection is to be over SSL (auto set when proto is
298-
"https:")
299-
- **ssl_context** *(Optional[SSLContextType])* – The SSL context to use when making SSL
300-
requests
285+
Get a new socket and connect to the given host.
286+
287+
:param str host: host to connect to, such as ``"www.example.org"``
288+
:param int port: port to use for connection, such as ``80`` or ``443``
289+
:param str proto: connection protocol: ``"http:"``, ``"https:"``, etc.
290+
:param Optional[str]: unique session ID,
291+
used for multiple simultaneous connections to the same host
292+
:param float timeout: how long to wait to connect
293+
:param bool is_ssl: ``True`` If the connection is to be over SSL;
294+
automatically set when ``proto`` is ``"https:"``
295+
:param Optional[SSLContextType]: SSL context to use when making SSL requests
301296
"""
302297
if session_id:
303298
session_id = str(session_id)
304299
key = (host, port, proto, session_id)
300+
301+
# Do we have already have a socket available for the requested connection?
305302
if key in self._managed_socket_by_key:
306303
socket = self._managed_socket_by_key[key]
307304
if socket in self._available_sockets:
308305
self._available_sockets.remove(socket)
309306
return socket
310307

311-
raise RuntimeError(f"Socket already connected to {proto}//{host}:{port}")
308+
raise RuntimeError(
309+
f"An existing socket is already connected to {proto}//{host}:{port}"
310+
)
312311

313312
if proto == "https:":
314313
is_ssl = True
315314
if is_ssl and not ssl_context:
316-
raise AttributeError(
317-
"ssl_context must be set before using adafruit_requests for https"
318-
)
315+
raise ValueError("ssl_context must be provided if using ssl")
319316

320317
addr_info = self._socket_pool.getaddrinfo(
321318
host, port, 0, self._socket_pool.SOCK_STREAM
322319
)[0]
323320

324-
first_exception = None
325-
result = self._get_connected_socket(
326-
addr_info, host, port, timeout, is_ssl, ssl_context
327-
)
328-
if isinstance(result, Exception):
329-
# Got an error, if there are any available sockets, free them and try again
321+
try:
322+
socket = self._get_connected_socket(
323+
addr_info, host, port, timeout, is_ssl, ssl_context
324+
)
325+
self._register_connected_socket(key, socket)
326+
return socket
327+
except (MemoryError, OSError, RuntimeError):
328+
# Could not get a new socket (or two, if SSL).
329+
# If there are any available sockets, free them all and try again.
330330
if self.available_socket_count:
331-
first_exception = result
332331
self._free_sockets()
333-
result = self._get_connected_socket(
332+
socket = self._get_connected_socket(
334333
addr_info, host, port, timeout, is_ssl, ssl_context
335334
)
336-
if isinstance(result, Exception):
337-
last_result = f", first error: {first_exception}" if first_exception else ""
338-
raise RuntimeError(
339-
f"Error connecting socket: {result}{last_result}"
340-
) from result
341-
342-
self._key_by_managed_socket[result] = key
343-
self._managed_socket_by_key[key] = result
344-
return result
345-
346-
347-
# global helpers
335+
self._register_connected_socket(key, socket)
336+
return socket
337+
# Re-raise exception if no sockets could be freed.
338+
raise
348339

349340

350341
def connection_manager_close_all(
@@ -353,10 +344,10 @@ def connection_manager_close_all(
353344
"""
354345
Close all open sockets for pool, optionally release references.
355346
356-
- **socket_pool** *(Optional[SocketpoolModuleType])* – A specifc SocketPool you want to close
357-
sockets for, leave blank for all SocketPools
358-
- **release_references** *(bool)* – Set to True if you want to also clear stored references to
359-
the SocketPool and SSL contexts
347+
:param Optional[SocketpoolModuleType] socket_pool:
348+
a specific socket pool whose sockets you want to close; ``None`` means all socket pools
349+
:param bool release_references: ``True`` if you also want the `ConnectionManager` to forget
350+
all the socket pools and SSL contexts it knows about
360351
"""
361352
if socket_pool:
362353
socket_pools = [socket_pool]
@@ -383,10 +374,7 @@ def connection_manager_close_all(
383374

384375
def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager:
385376
"""
386-
Get the ConnectionManager singleton for the given pool.
387-
388-
- **socket_pool** *(Optional[SocketpoolModuleType])* – The SocketPool you want the
389-
ConnectionManager for
377+
Get or create the ConnectionManager singleton for the given pool.
390378
"""
391379
if socket_pool not in _global_connection_managers:
392380
_global_connection_managers[socket_pool] = ConnectionManager(socket_pool)

tests/get_radio_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_get_radio_socketpool_wiznet5k( # pylint: disable=unused-argument
5555

5656
def test_get_radio_socketpool_unsupported():
5757
radio = mocket.MockRadio.Unsupported()
58-
with pytest.raises(AttributeError) as context:
58+
with pytest.raises(ValueError) as context:
5959
adafruit_connection_manager.get_radio_socketpool(radio)
6060
assert "Unsupported radio class" in str(context)
6161

@@ -100,7 +100,7 @@ def test_get_radio_ssl_context_wiznet5k( # pylint: disable=unused-argument
100100

101101
def test_get_radio_ssl_context_unsupported():
102102
radio = mocket.MockRadio.Unsupported()
103-
with pytest.raises(AttributeError) as context:
103+
with pytest.raises(ValueError) as context:
104104
adafruit_connection_manager.get_radio_ssl_context(radio)
105105
assert "Unsupported radio class" in str(context)
106106

tests/get_socket_test.py

+7-13
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_get_socket_not_flagged_free():
9191
# get a socket for the same host, should be a different one
9292
with pytest.raises(RuntimeError) as context:
9393
socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
94-
assert "Socket already connected" in str(context)
94+
assert "An existing socket is already connected" in str(context)
9595

9696

9797
def test_get_socket_os_error():
@@ -105,9 +105,8 @@ def test_get_socket_os_error():
105105
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)
106106

107107
# try to get a socket that returns a OSError
108-
with pytest.raises(RuntimeError) as context:
108+
with pytest.raises(OSError):
109109
connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
110-
assert "Error connecting socket: OSError" in str(context)
111110

112111

113112
def test_get_socket_runtime_error():
@@ -121,9 +120,8 @@ def test_get_socket_runtime_error():
121120
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)
122121

123122
# try to get a socket that returns a RuntimeError
124-
with pytest.raises(RuntimeError) as context:
123+
with pytest.raises(RuntimeError):
125124
connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
126-
assert "Error connecting socket: RuntimeError" in str(context)
127125

128126

129127
def test_get_socket_connect_memory_error():
@@ -139,9 +137,8 @@ def test_get_socket_connect_memory_error():
139137
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)
140138

141139
# try to connect a socket that returns a MemoryError
142-
with pytest.raises(RuntimeError) as context:
140+
with pytest.raises(MemoryError):
143141
connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
144-
assert "Error connecting socket: MemoryError" in str(context)
145142

146143

147144
def test_get_socket_connect_os_error():
@@ -157,9 +154,8 @@ def test_get_socket_connect_os_error():
157154
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)
158155

159156
# try to connect a socket that returns a OSError
160-
with pytest.raises(RuntimeError) as context:
157+
with pytest.raises(OSError):
161158
connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:")
162-
assert "Error connecting socket: OSError" in str(context)
163159

164160

165161
def test_get_socket_runtime_error_ties_again_at_least_one_free():
@@ -211,9 +207,8 @@ def test_get_socket_runtime_error_ties_again_only_once():
211207
free_sockets_mock.assert_not_called()
212208

213209
# try to get a socket that returns a RuntimeError twice
214-
with pytest.raises(RuntimeError) as context:
210+
with pytest.raises(RuntimeError):
215211
connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:")
216-
assert "Error connecting socket: error 2, first error: error 1" in str(context)
217212
free_sockets_mock.assert_called_once()
218213

219214

@@ -248,8 +243,7 @@ def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument
248243
ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio)
249244
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)
250245

251-
with pytest.raises(RuntimeError) as context:
246+
with pytest.raises(OSError):
252247
connection_manager.get_socket(
253248
mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context
254249
)
255-
assert "Error connecting socket: [Errno 12] RuntimeError" in str(context)

tests/protocol_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ def test_get_https_no_ssl():
1818
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)
1919

2020
# verify not sending in a SSL context for a HTTPS call errors
21-
with pytest.raises(AttributeError) as context:
21+
with pytest.raises(ValueError) as context:
2222
connection_manager.get_socket(mocket.MOCK_HOST_1, 443, "https:")
23-
assert "ssl_context must be set" in str(context)
23+
assert "ssl_context must be provided if using ssl" in str(context)
2424

2525

2626
def test_connect_https():

tests/ssl_context_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_connect_wiznet5k_https_not_supported( # pylint: disable=unused-argumen
5858
connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool)
5959

6060
# verify a HTTPS call for a board without built in WiFi and SSL support errors
61-
with pytest.raises(AttributeError) as context:
61+
with pytest.raises(ValueError) as context:
6262
connection_manager.get_socket(
6363
mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context
6464
)

0 commit comments

Comments
 (0)