diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 6ab6db2f7d..beffba6d18 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -28,7 +28,7 @@ Union, ) -from pymongo import _csot, ssl_support +from pymongo import ssl_support from pymongo._asyncio_task import create_task from pymongo.errors import _OperationCancelled from pymongo.socket_checker import _errno_from_exception @@ -316,62 +316,47 @@ async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLo return mv -# Sync version: -def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: - """Block until at least one byte is read, or a timeout, or a cancel.""" - sock = conn.conn - timed_out = False - # Check if the connection's socket has been manually closed - if sock.fileno() == -1: - return - while True: - # SSLSocket can have buffered data which won't be caught by select. - if hasattr(sock, "pending") and sock.pending() > 0: - readable = True - else: - # Wait up to 500ms for the socket to become readable and then - # check for cancellation. - if deadline: - remaining = deadline - time.monotonic() - # When the timeout has expired perform one final check to - # see if the socket is readable. This helps avoid spurious - # timeouts on AWS Lambda and other FaaS environments. - if remaining <= 0: - timed_out = True - timeout = max(min(remaining, _POLL_TIMEOUT), 0) - else: - timeout = _POLL_TIMEOUT - readable = conn.socket_checker.select(sock, read=True, timeout=timeout) - if conn.cancel_context.cancelled: - raise _OperationCancelled("operation cancelled") - if readable: - return - if timed_out: - raise socket.timeout("timed out") - - def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview: buf = bytearray(length) mv = memoryview(buf) bytes_read = 0 - while bytes_read < length: - try: - wait_for_read(conn, deadline) - # CSOT: Update timeout. When the timeout has expired perform one - # final non-blocking recv. This helps avoid spurious timeouts when - # the response is actually already buffered on the client. - if _csot.get_timeout() and deadline is not None: - conn.set_conn_timeout(max(deadline - time.monotonic(), 0)) - chunk_length = conn.conn.recv_into(mv[bytes_read:]) - except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") from None - except OSError as exc: - if _errno_from_exception(exc) == errno.EINTR: + # To support cancelling a network read, we shorten the socket timeout and + # check for the cancellation signal after each timeout. Alternatively we + # could close the socket but that does not reliably cancel recv() calls + # on all OSes. + orig_timeout = conn.conn.gettimeout() + try: + while bytes_read < length: + if deadline is not None: + # CSOT: Update timeout. When the timeout has expired perform one + # final non-blocking recv. This helps avoid spurious timeouts when + # the response is actually already buffered on the client. + short_timeout = min(max(deadline - time.monotonic(), 0), _POLL_TIMEOUT) + else: + short_timeout = _POLL_TIMEOUT + conn.set_conn_timeout(short_timeout) + try: + chunk_length = conn.conn.recv_into(mv[bytes_read:]) + except BLOCKING_IO_ERRORS: + if conn.cancel_context.cancelled: + raise _OperationCancelled("operation cancelled") from None + # We reached the true deadline. + raise socket.timeout("timed out") from None + except socket.timeout: + if conn.cancel_context.cancelled: + raise _OperationCancelled("operation cancelled") from None continue - raise - if chunk_length == 0: - raise OSError("connection closed") - - bytes_read += chunk_length + except OSError as exc: + if conn.cancel_context.cancelled: + raise _OperationCancelled("operation cancelled") from None + if _errno_from_exception(exc) == errno.EINTR: + continue + raise + if chunk_length == 0: + raise OSError("connection closed") + + bytes_read += chunk_length + finally: + conn.set_conn_timeout(orig_timeout) return mv