Skip to content

Commit 88f5176

Browse files
authored
Merge pull request #101 from BiffoBear/fix_socket.close()_behaviour
Fix socket.close() behaviour
2 parents c8fabd0 + 0f9f1b6 commit 88f5176

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

adafruit_wiznet5k/adafruit_wiznet5k_socket.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def __init__(
226226
"""
227227
if family != AF_INET:
228228
raise RuntimeError("Only AF_INET family supported by W5K modules.")
229+
self._socket_closed = False
229230
self._sock_type = type
230231
self._buffer = b""
231232
self._timeout = _default_socket_timeout
@@ -251,6 +252,17 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
251252
if time.monotonic() - stamp > 1000:
252253
raise RuntimeError("Failed to close socket")
253254

255+
# This works around problems with using a class method as a decorator.
256+
def _check_socket_closed(func): # pylint: disable=no-self-argument
257+
"""Decorator to check whether the socket object has been closed."""
258+
259+
def wrapper(self, *args, **kwargs):
260+
if self._socket_closed: # pylint: disable=protected-access
261+
raise RuntimeError("The socket has been closed.")
262+
return func(self, *args, **kwargs) # pylint: disable=not-callable
263+
264+
return wrapper
265+
254266
@property
255267
def _status(self) -> int:
256268
"""
@@ -288,6 +300,7 @@ def _connected(self) -> bool:
288300
self.close()
289301
return result
290302

303+
@_check_socket_closed
291304
def getpeername(self) -> Tuple[str, int]:
292305
"""
293306
Return the remote address to which the socket is connected.
@@ -298,6 +311,7 @@ def getpeername(self) -> Tuple[str, int]:
298311
self._socknum
299312
)
300313

314+
@_check_socket_closed
301315
def bind(self, address: Tuple[Optional[str], int]) -> None:
302316
"""
303317
Bind the socket to address. The socket must not already be bound.
@@ -343,6 +357,7 @@ def _bind(self, address: Tuple[Optional[str], int]) -> None:
343357
)
344358
self._buffer = b""
345359

360+
@_check_socket_closed
346361
def listen(self, backlog: int = 0) -> None:
347362
"""
348363
Enable a server to accept connections.
@@ -354,6 +369,7 @@ def listen(self, backlog: int = 0) -> None:
354369
_the_interface.socket_listen(self._socknum, self._listen_port)
355370
self._buffer = b""
356371

372+
@_check_socket_closed
357373
def accept(
358374
self,
359375
) -> Tuple[socket, Tuple[str, int]]:
@@ -388,6 +404,7 @@ def accept(
388404
raise RuntimeError("Failed to open new listening socket")
389405
return client_sock, addr
390406

407+
@_check_socket_closed
391408
def connect(self, address: Tuple[str, int]) -> None:
392409
"""
393410
Connect to a remote socket at address.
@@ -407,6 +424,7 @@ def connect(self, address: Tuple[str, int]) -> None:
407424
raise RuntimeError("Failed to connect to host ", address[0])
408425
self._buffer = b""
409426

427+
@_check_socket_closed
410428
def send(self, data: Union[bytes, bytearray]) -> int:
411429
"""
412430
Send data to the socket. The socket must be connected to a remote socket.
@@ -422,6 +440,7 @@ def send(self, data: Union[bytes, bytearray]) -> int:
422440
gc.collect()
423441
return bytes_sent
424442

443+
@_check_socket_closed
425444
def sendto(self, data: bytearray, *flags_and_or_address: any) -> int:
426445
"""
427446
Send data to the socket. The socket should not be connected to a remote socket, since the
@@ -445,6 +464,7 @@ def sendto(self, data: bytearray, *flags_and_or_address: any) -> int:
445464
self.connect(address)
446465
return self.send(data)
447466

467+
@_check_socket_closed
448468
def recv(
449469
# pylint: disable=too-many-branches
450470
self,
@@ -500,6 +520,7 @@ def _embed_recv(
500520
gc.collect()
501521
return ret
502522

523+
@_check_socket_closed
503524
def recvfrom(self, bufsize: int, flags: int = 0) -> Tuple[bytes, Tuple[str, int]]:
504525
"""
505526
Receive data from the socket. The return value is a pair (bytes, address) where bytes is
@@ -520,6 +541,7 @@ def recvfrom(self, bufsize: int, flags: int = 0) -> Tuple[bytes, Tuple[str, int]
520541
),
521542
)
522543

544+
@_check_socket_closed
523545
def recv_into(self, buffer: bytearray, nbytes: int = 0, flags: int = 0) -> int:
524546
"""
525547
Receive up to nbytes bytes from the socket, storing the data into a buffer
@@ -538,6 +560,7 @@ def recv_into(self, buffer: bytearray, nbytes: int = 0, flags: int = 0) -> int:
538560
buffer[:nbytes] = bytes_received
539561
return nbytes
540562

563+
@_check_socket_closed
541564
def recvfrom_into(
542565
self, buffer: bytearray, nbytes: int = 0, flags: int = 0
543566
) -> Tuple[int, Tuple[str, int]]:
@@ -596,11 +619,13 @@ def _disconnect(self) -> None:
596619
raise RuntimeError("Socket must be a TCP socket.")
597620
_the_interface.socket_disconnect(self._socknum)
598621

622+
@_check_socket_closed
599623
def close(self) -> None:
600624
"""
601625
Mark the socket closed. Once that happens, all future operations on the socket object
602626
will fail. The remote end will receive no more data.
603627
"""
628+
self._socket_closed = True
604629
_the_interface.socket_close(self._socknum)
605630

606631
def _available(self) -> int:
@@ -611,6 +636,7 @@ def _available(self) -> int:
611636
"""
612637
return _the_interface.socket_available(self._socknum, self._sock_type)
613638

639+
@_check_socket_closed
614640
def settimeout(self, value: Optional[float]) -> None:
615641
"""
616642
Set a timeout on blocking socket operations. The value argument can be a
@@ -627,6 +653,7 @@ def settimeout(self, value: Optional[float]) -> None:
627653
else:
628654
raise ValueError("Timeout must be None, 0.0 or a positive numeric value.")
629655

656+
@_check_socket_closed
630657
def gettimeout(self) -> Optional[float]:
631658
"""
632659
Return the timeout in seconds (float) associated with socket operations, or None if no
@@ -636,6 +663,7 @@ def gettimeout(self) -> Optional[float]:
636663
"""
637664
return self._timeout
638665

666+
@_check_socket_closed
639667
def setblocking(self, flag: bool) -> None:
640668
"""
641669
Set blocking or non-blocking mode of the socket: if flag is false, the socket is set
@@ -658,6 +686,7 @@ def setblocking(self, flag: bool) -> None:
658686
else:
659687
raise TypeError("Flag must be a boolean.")
660688

689+
@_check_socket_closed
661690
def getblocking(self) -> bool:
662691
"""
663692
Return True if socket is in blocking mode, False if in non-blocking.
@@ -669,16 +698,19 @@ def getblocking(self) -> bool:
669698
return self.gettimeout() == 0
670699

671700
@property
701+
@_check_socket_closed
672702
def family(self) -> int:
673703
"""Socket family (always 0x03 in this implementation)."""
674704
return 3
675705

676706
@property
707+
@_check_socket_closed
677708
def type(self):
678709
"""Socket type."""
679710
return self._sock_type
680711

681712
@property
713+
@_check_socket_closed
682714
def proto(self):
683715
"""Socket protocol (always 0x00 in this implementation)."""
684716
return 0

0 commit comments

Comments
 (0)