Skip to content

Commit 4fc53b8

Browse files
author
BiffoBear
committed
Added a decorator to raise a RuntimeError if method are called after close() has been called.
1 parent 0144787 commit 4fc53b8

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

adafruit_wiznet5k/adafruit_wiznet5k_socket.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def __init__(
226226
"""
227227
if family != AF_INET:
228228
raise RuntimeError("Only AF_INET family supported by W5K modules.")
229+
self._socket_closed = False
229230
self._sock_type = type
230231
self._buffer = b""
231232
self._timeout = _default_socket_timeout
@@ -251,6 +252,19 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
251252
if time.monotonic() - stamp > 1000:
252253
raise RuntimeError("Failed to close socket")
253254

255+
# This works around problems with using a method as a decorator.
256+
def _check_socket_closed(func): # pylint: disable=no-self-argument
257+
"""Decorator to check whether the socket object has been closed."""
258+
259+
def wrapper(self, *args, **kwargs):
260+
if self._socket_closed: # pylint: disable=protected-access
261+
raise RuntimeError("The socket has been closed.")
262+
print(*args)
263+
print(**kwargs)
264+
return func(self, *args, **kwargs) # pylint: disable=not-callable
265+
266+
return wrapper
267+
254268
@property
255269
def _status(self) -> int:
256270
"""
@@ -288,6 +302,7 @@ def _connected(self) -> bool:
288302
self.close()
289303
return result
290304

305+
@_check_socket_closed
291306
def getpeername(self) -> Tuple[str, int]:
292307
"""
293308
Return the remote address to which the socket is connected.
@@ -298,6 +313,7 @@ def getpeername(self) -> Tuple[str, int]:
298313
self._socknum
299314
)
300315

316+
@_check_socket_closed
301317
def bind(self, address: Tuple[Optional[str], int]) -> None:
302318
"""
303319
Bind the socket to address. The socket must not already be bound.
@@ -343,6 +359,7 @@ def _bind(self, address: Tuple[Optional[str], int]) -> None:
343359
)
344360
self._buffer = b""
345361

362+
@_check_socket_closed
346363
def listen(self, backlog: int = 0) -> None:
347364
"""
348365
Enable a server to accept connections.
@@ -354,6 +371,7 @@ def listen(self, backlog: int = 0) -> None:
354371
_the_interface.socket_listen(self._socknum, self._listen_port)
355372
self._buffer = b""
356373

374+
@_check_socket_closed
357375
def accept(
358376
self,
359377
) -> Tuple[socket, Tuple[str, int]]:
@@ -388,6 +406,7 @@ def accept(
388406
raise RuntimeError("Failed to open new listening socket")
389407
return client_sock, addr
390408

409+
@_check_socket_closed
391410
def connect(self, address: Tuple[str, int]) -> None:
392411
"""
393412
Connect to a remote socket at address.
@@ -407,6 +426,7 @@ def connect(self, address: Tuple[str, int]) -> None:
407426
raise RuntimeError("Failed to connect to host ", address[0])
408427
self._buffer = b""
409428

429+
@_check_socket_closed
410430
def send(self, data: Union[bytes, bytearray]) -> int:
411431
"""
412432
Send data to the socket. The socket must be connected to a remote socket.
@@ -422,6 +442,7 @@ def send(self, data: Union[bytes, bytearray]) -> int:
422442
gc.collect()
423443
return bytes_sent
424444

445+
@_check_socket_closed
425446
def sendto(self, data: bytearray, *flags_and_or_address: any) -> int:
426447
"""
427448
Send data to the socket. The socket should not be connected to a remote socket, since the
@@ -445,6 +466,7 @@ def sendto(self, data: bytearray, *flags_and_or_address: any) -> int:
445466
self.connect(address)
446467
return self.send(data)
447468

469+
@_check_socket_closed
448470
def recv(
449471
# pylint: disable=too-many-branches
450472
self,
@@ -500,6 +522,7 @@ def _embed_recv(
500522
gc.collect()
501523
return ret
502524

525+
@_check_socket_closed
503526
def recvfrom(self, bufsize: int, flags: int = 0) -> Tuple[bytes, Tuple[str, int]]:
504527
"""
505528
Receive data from the socket. The return value is a pair (bytes, address) where bytes is
@@ -520,6 +543,7 @@ def recvfrom(self, bufsize: int, flags: int = 0) -> Tuple[bytes, Tuple[str, int]
520543
),
521544
)
522545

546+
@_check_socket_closed
523547
def recv_into(self, buffer: bytearray, nbytes: int = 0, flags: int = 0) -> int:
524548
"""
525549
Receive up to nbytes bytes from the socket, storing the data into a buffer
@@ -538,6 +562,7 @@ def recv_into(self, buffer: bytearray, nbytes: int = 0, flags: int = 0) -> int:
538562
buffer[:nbytes] = bytes_received
539563
return nbytes
540564

565+
@_check_socket_closed
541566
def recvfrom_into(
542567
self, buffer: bytearray, nbytes: int = 0, flags: int = 0
543568
) -> Tuple[int, Tuple[str, int]]:
@@ -596,11 +621,13 @@ def _disconnect(self) -> None:
596621
raise RuntimeError("Socket must be a TCP socket.")
597622
_the_interface.socket_disconnect(self._socknum)
598623

624+
@_check_socket_closed
599625
def close(self) -> None:
600626
"""
601627
Mark the socket closed. Once that happens, all future operations on the socket object
602628
will fail. The remote end will receive no more data.
603629
"""
630+
self._socket_closed = True
604631
_the_interface.socket_close(self._socknum)
605632

606633
def _available(self) -> int:
@@ -611,6 +638,7 @@ def _available(self) -> int:
611638
"""
612639
return _the_interface.socket_available(self._socknum, self._sock_type)
613640

641+
@_check_socket_closed
614642
def settimeout(self, value: Optional[float]) -> None:
615643
"""
616644
Set a timeout on blocking socket operations. The value argument can be a
@@ -627,6 +655,7 @@ def settimeout(self, value: Optional[float]) -> None:
627655
else:
628656
raise ValueError("Timeout must be None, 0.0 or a positive numeric value.")
629657

658+
@_check_socket_closed
630659
def gettimeout(self) -> Optional[float]:
631660
"""
632661
Return the timeout in seconds (float) associated with socket operations, or None if no
@@ -636,6 +665,7 @@ def gettimeout(self) -> Optional[float]:
636665
"""
637666
return self._timeout
638667

668+
@_check_socket_closed
639669
def setblocking(self, flag: bool) -> None:
640670
"""
641671
Set blocking or non-blocking mode of the socket: if flag is false, the socket is set
@@ -658,6 +688,7 @@ def setblocking(self, flag: bool) -> None:
658688
else:
659689
raise TypeError("Flag must be a boolean.")
660690

691+
@_check_socket_closed
661692
def getblocking(self) -> bool:
662693
"""
663694
Return True if socket is in blocking mode, False if in non-blocking.
@@ -668,16 +699,19 @@ def getblocking(self) -> bool:
668699
"""
669700
return self.gettimeout() == 0
670701

702+
@_check_socket_closed
671703
@property
672704
def family(self) -> int:
673705
"""Socket family (always 0x03 in this implementation)."""
674706
return 3
675707

708+
@_check_socket_closed
676709
@property
677710
def type(self):
678711
"""Socket type."""
679712
return self._sock_type
680713

714+
@_check_socket_closed
681715
@property
682716
def proto(self):
683717
"""Socket protocol (always 0x00 in this implementation)."""

0 commit comments

Comments
 (0)