@@ -226,6 +226,7 @@ def __init__(
226
226
"""
227
227
if family != AF_INET :
228
228
raise RuntimeError ("Only AF_INET family supported by W5K modules." )
229
+ self ._socket_closed = False
229
230
self ._sock_type = type
230
231
self ._buffer = b""
231
232
self ._timeout = _default_socket_timeout
@@ -251,6 +252,17 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
251
252
if time .monotonic () - stamp > 1000 :
252
253
raise RuntimeError ("Failed to close socket" )
253
254
255
+ # This works around problems with using a class method as a decorator.
256
+ def _check_socket_closed (func ): # pylint: disable=no-self-argument
257
+ """Decorator to check whether the socket object has been closed."""
258
+
259
+ def wrapper (self , * args , ** kwargs ):
260
+ if self ._socket_closed : # pylint: disable=protected-access
261
+ raise RuntimeError ("The socket has been closed." )
262
+ return func (self , * args , ** kwargs ) # pylint: disable=not-callable
263
+
264
+ return wrapper
265
+
254
266
@property
255
267
def _status (self ) -> int :
256
268
"""
@@ -288,6 +300,7 @@ def _connected(self) -> bool:
288
300
self .close ()
289
301
return result
290
302
303
+ @_check_socket_closed
291
304
def getpeername (self ) -> Tuple [str , int ]:
292
305
"""
293
306
Return the remote address to which the socket is connected.
@@ -298,6 +311,7 @@ def getpeername(self) -> Tuple[str, int]:
298
311
self ._socknum
299
312
)
300
313
314
+ @_check_socket_closed
301
315
def bind (self , address : Tuple [Optional [str ], int ]) -> None :
302
316
"""
303
317
Bind the socket to address. The socket must not already be bound.
@@ -343,6 +357,7 @@ def _bind(self, address: Tuple[Optional[str], int]) -> None:
343
357
)
344
358
self ._buffer = b""
345
359
360
+ @_check_socket_closed
346
361
def listen (self , backlog : int = 0 ) -> None :
347
362
"""
348
363
Enable a server to accept connections.
@@ -354,6 +369,7 @@ def listen(self, backlog: int = 0) -> None:
354
369
_the_interface .socket_listen (self ._socknum , self ._listen_port )
355
370
self ._buffer = b""
356
371
372
+ @_check_socket_closed
357
373
def accept (
358
374
self ,
359
375
) -> Tuple [socket , Tuple [str , int ]]:
@@ -388,6 +404,7 @@ def accept(
388
404
raise RuntimeError ("Failed to open new listening socket" )
389
405
return client_sock , addr
390
406
407
+ @_check_socket_closed
391
408
def connect (self , address : Tuple [str , int ]) -> None :
392
409
"""
393
410
Connect to a remote socket at address.
@@ -407,6 +424,7 @@ def connect(self, address: Tuple[str, int]) -> None:
407
424
raise RuntimeError ("Failed to connect to host " , address [0 ])
408
425
self ._buffer = b""
409
426
427
+ @_check_socket_closed
410
428
def send (self , data : Union [bytes , bytearray ]) -> int :
411
429
"""
412
430
Send data to the socket. The socket must be connected to a remote socket.
@@ -422,6 +440,7 @@ def send(self, data: Union[bytes, bytearray]) -> int:
422
440
gc .collect ()
423
441
return bytes_sent
424
442
443
+ @_check_socket_closed
425
444
def sendto (self , data : bytearray , * flags_and_or_address : any ) -> int :
426
445
"""
427
446
Send data to the socket. The socket should not be connected to a remote socket, since the
@@ -445,6 +464,7 @@ def sendto(self, data: bytearray, *flags_and_or_address: any) -> int:
445
464
self .connect (address )
446
465
return self .send (data )
447
466
467
+ @_check_socket_closed
448
468
def recv (
449
469
# pylint: disable=too-many-branches
450
470
self ,
@@ -500,6 +520,7 @@ def _embed_recv(
500
520
gc .collect ()
501
521
return ret
502
522
523
+ @_check_socket_closed
503
524
def recvfrom (self , bufsize : int , flags : int = 0 ) -> Tuple [bytes , Tuple [str , int ]]:
504
525
"""
505
526
Receive data from the socket. The return value is a pair (bytes, address) where bytes is
@@ -520,6 +541,7 @@ def recvfrom(self, bufsize: int, flags: int = 0) -> Tuple[bytes, Tuple[str, int]
520
541
),
521
542
)
522
543
544
+ @_check_socket_closed
523
545
def recv_into (self , buffer : bytearray , nbytes : int = 0 , flags : int = 0 ) -> int :
524
546
"""
525
547
Receive up to nbytes bytes from the socket, storing the data into a buffer
@@ -538,6 +560,7 @@ def recv_into(self, buffer: bytearray, nbytes: int = 0, flags: int = 0) -> int:
538
560
buffer [:nbytes ] = bytes_received
539
561
return nbytes
540
562
563
+ @_check_socket_closed
541
564
def recvfrom_into (
542
565
self , buffer : bytearray , nbytes : int = 0 , flags : int = 0
543
566
) -> Tuple [int , Tuple [str , int ]]:
@@ -596,11 +619,13 @@ def _disconnect(self) -> None:
596
619
raise RuntimeError ("Socket must be a TCP socket." )
597
620
_the_interface .socket_disconnect (self ._socknum )
598
621
622
+ @_check_socket_closed
599
623
def close (self ) -> None :
600
624
"""
601
625
Mark the socket closed. Once that happens, all future operations on the socket object
602
626
will fail. The remote end will receive no more data.
603
627
"""
628
+ self ._socket_closed = True
604
629
_the_interface .socket_close (self ._socknum )
605
630
606
631
def _available (self ) -> int :
@@ -611,6 +636,7 @@ def _available(self) -> int:
611
636
"""
612
637
return _the_interface .socket_available (self ._socknum , self ._sock_type )
613
638
639
+ @_check_socket_closed
614
640
def settimeout (self , value : Optional [float ]) -> None :
615
641
"""
616
642
Set a timeout on blocking socket operations. The value argument can be a
@@ -627,6 +653,7 @@ def settimeout(self, value: Optional[float]) -> None:
627
653
else :
628
654
raise ValueError ("Timeout must be None, 0.0 or a positive numeric value." )
629
655
656
+ @_check_socket_closed
630
657
def gettimeout (self ) -> Optional [float ]:
631
658
"""
632
659
Return the timeout in seconds (float) associated with socket operations, or None if no
@@ -636,6 +663,7 @@ def gettimeout(self) -> Optional[float]:
636
663
"""
637
664
return self ._timeout
638
665
666
+ @_check_socket_closed
639
667
def setblocking (self , flag : bool ) -> None :
640
668
"""
641
669
Set blocking or non-blocking mode of the socket: if flag is false, the socket is set
@@ -658,6 +686,7 @@ def setblocking(self, flag: bool) -> None:
658
686
else :
659
687
raise TypeError ("Flag must be a boolean." )
660
688
689
+ @_check_socket_closed
661
690
def getblocking (self ) -> bool :
662
691
"""
663
692
Return True if socket is in blocking mode, False if in non-blocking.
@@ -669,16 +698,19 @@ def getblocking(self) -> bool:
669
698
return self .gettimeout () == 0
670
699
671
700
@property
701
+ @_check_socket_closed
672
702
def family (self ) -> int :
673
703
"""Socket family (always 0x03 in this implementation)."""
674
704
return 3
675
705
676
706
@property
707
+ @_check_socket_closed
677
708
def type (self ):
678
709
"""Socket type."""
679
710
return self ._sock_type
680
711
681
712
@property
713
+ @_check_socket_closed
682
714
def proto (self ):
683
715
"""Socket protocol (always 0x00 in this implementation)."""
684
716
return 0
0 commit comments