Skip to content

Commit d282022

Browse files
committed
Merge branch 'main' into fix_out_of_sockets
# Conflicts: # adafruit_wiznet5k/adafruit_wiznet5k_socket.py
2 parents 3d96b1c + 88f5176 commit d282022

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

adafruit_wiznet5k/adafruit_wiznet5k_socket.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def __init__(
226226
"""
227227
if family != AF_INET:
228228
raise RuntimeError("Only AF_INET family supported by W5K modules.")
229+
self._socket_closed = False
229230
self._sock_type = type
230231
self._buffer = b""
231232
self._timeout = _default_socket_timeout
@@ -255,6 +256,17 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
255256
if time.monotonic() - stamp > 1000:
256257
raise RuntimeError("Failed to close socket")
257258

259+
# This works around problems with using a class method as a decorator.
260+
def _check_socket_closed(func): # pylint: disable=no-self-argument
261+
"""Decorator to check whether the socket object has been closed."""
262+
263+
def wrapper(self, *args, **kwargs):
264+
if self._socket_closed: # pylint: disable=protected-access
265+
raise RuntimeError("The socket has been closed.")
266+
return func(self, *args, **kwargs) # pylint: disable=not-callable
267+
268+
return wrapper
269+
258270
@property
259271
def _status(self) -> int:
260272
"""
@@ -292,6 +304,7 @@ def _connected(self) -> bool:
292304
self.close()
293305
return result
294306

307+
@_check_socket_closed
295308
def getpeername(self) -> Tuple[str, int]:
296309
"""
297310
Return the remote address to which the socket is connected.
@@ -302,6 +315,7 @@ def getpeername(self) -> Tuple[str, int]:
302315
self._socknum
303316
)
304317

318+
@_check_socket_closed
305319
def bind(self, address: Tuple[Optional[str], int]) -> None:
306320
"""
307321
Bind the socket to address. The socket must not already be bound.
@@ -347,6 +361,7 @@ def _bind(self, address: Tuple[Optional[str], int]) -> None:
347361
)
348362
self._buffer = b""
349363

364+
@_check_socket_closed
350365
def listen(self, backlog: int = 0) -> None:
351366
"""
352367
Enable a server to accept connections.
@@ -358,6 +373,7 @@ def listen(self, backlog: int = 0) -> None:
358373
_the_interface.socket_listen(self._socknum, self._listen_port)
359374
self._buffer = b""
360375

376+
@_check_socket_closed
361377
def accept(
362378
self,
363379
) -> Tuple[socket, Tuple[str, int]]:
@@ -392,6 +408,7 @@ def accept(
392408
raise RuntimeError("Failed to open new listening socket")
393409
return client_sock, addr
394410

411+
@_check_socket_closed
395412
def connect(self, address: Tuple[str, int]) -> None:
396413
"""
397414
Connect to a remote socket at address.
@@ -411,6 +428,7 @@ def connect(self, address: Tuple[str, int]) -> None:
411428
raise RuntimeError("Failed to connect to host ", address[0])
412429
self._buffer = b""
413430

431+
@_check_socket_closed
414432
def send(self, data: Union[bytes, bytearray]) -> int:
415433
"""
416434
Send data to the socket. The socket must be connected to a remote socket.
@@ -426,6 +444,7 @@ def send(self, data: Union[bytes, bytearray]) -> int:
426444
gc.collect()
427445
return bytes_sent
428446

447+
@_check_socket_closed
429448
def sendto(self, data: bytearray, *flags_and_or_address: any) -> int:
430449
"""
431450
Send data to the socket. The socket should not be connected to a remote socket, since the
@@ -449,6 +468,7 @@ def sendto(self, data: bytearray, *flags_and_or_address: any) -> int:
449468
self.connect(address)
450469
return self.send(data)
451470

471+
@_check_socket_closed
452472
def recv(
453473
# pylint: disable=too-many-branches
454474
self,
@@ -504,6 +524,7 @@ def _embed_recv(
504524
gc.collect()
505525
return ret
506526

527+
@_check_socket_closed
507528
def recvfrom(self, bufsize: int, flags: int = 0) -> Tuple[bytes, Tuple[str, int]]:
508529
"""
509530
Receive data from the socket. The return value is a pair (bytes, address) where bytes is
@@ -524,6 +545,7 @@ def recvfrom(self, bufsize: int, flags: int = 0) -> Tuple[bytes, Tuple[str, int]
524545
),
525546
)
526547

548+
@_check_socket_closed
527549
def recv_into(self, buffer: bytearray, nbytes: int = 0, flags: int = 0) -> int:
528550
"""
529551
Receive up to nbytes bytes from the socket, storing the data into a buffer
@@ -542,6 +564,7 @@ def recv_into(self, buffer: bytearray, nbytes: int = 0, flags: int = 0) -> int:
542564
buffer[:nbytes] = bytes_received
543565
return nbytes
544566

567+
@_check_socket_closed
545568
def recvfrom_into(
546569
self, buffer: bytearray, nbytes: int = 0, flags: int = 0
547570
) -> Tuple[int, Tuple[str, int]]:
@@ -600,13 +623,15 @@ def _disconnect(self) -> None:
600623
raise RuntimeError("Socket must be a TCP socket.")
601624
_the_interface.socket_disconnect(self._socknum)
602625

626+
@_check_socket_closed
603627
def close(self) -> None:
604628
"""
605629
Mark the socket closed. Once that happens, all future operations on the socket object
606630
will fail. The remote end will receive no more data.
607631
"""
608632
_the_interface.release_socket(self._socknum)
609633
_the_interface.socket_close(self._socknum)
634+
self._socket_closed = True
610635

611636
def _available(self) -> int:
612637
"""
@@ -616,6 +641,7 @@ def _available(self) -> int:
616641
"""
617642
return _the_interface.socket_available(self._socknum, self._sock_type)
618643

644+
@_check_socket_closed
619645
def settimeout(self, value: Optional[float]) -> None:
620646
"""
621647
Set a timeout on blocking socket operations. The value argument can be a
@@ -632,6 +658,7 @@ def settimeout(self, value: Optional[float]) -> None:
632658
else:
633659
raise ValueError("Timeout must be None, 0.0 or a positive numeric value.")
634660

661+
@_check_socket_closed
635662
def gettimeout(self) -> Optional[float]:
636663
"""
637664
Return the timeout in seconds (float) associated with socket operations, or None if no
@@ -641,6 +668,7 @@ def gettimeout(self) -> Optional[float]:
641668
"""
642669
return self._timeout
643670

671+
@_check_socket_closed
644672
def setblocking(self, flag: bool) -> None:
645673
"""
646674
Set blocking or non-blocking mode of the socket: if flag is false, the socket is set
@@ -663,6 +691,7 @@ def setblocking(self, flag: bool) -> None:
663691
else:
664692
raise TypeError("Flag must be a boolean.")
665693

694+
@_check_socket_closed
666695
def getblocking(self) -> bool:
667696
"""
668697
Return True if socket is in blocking mode, False if in non-blocking.
@@ -674,16 +703,19 @@ def getblocking(self) -> bool:
674703
return self.gettimeout() == 0
675704

676705
@property
706+
@_check_socket_closed
677707
def family(self) -> int:
678708
"""Socket family (always 0x03 in this implementation)."""
679709
return 3
680710

681711
@property
712+
@_check_socket_closed
682713
def type(self):
683714
"""Socket type."""
684715
return self._sock_type
685716

686717
@property
718+
@_check_socket_closed
687719
def proto(self):
688720
"""Socket protocol (always 0x00 in this implementation)."""
689721
return 0

0 commit comments

Comments
 (0)