Skip to content

PYTHON-4292 Improve TLS read performance #2020

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 38 additions & 53 deletions pymongo/network_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
Union,
)

from pymongo import _csot, ssl_support
from pymongo import ssl_support
from pymongo._asyncio_task import create_task
from pymongo.errors import _OperationCancelled
from pymongo.socket_checker import _errno_from_exception
Expand Down Expand Up @@ -316,62 +316,47 @@ async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLo
return mv


# Sync version:
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
"""Block until at least one byte is read, or a timeout, or a cancel."""
sock = conn.conn
timed_out = False
# Check if the connection's socket has been manually closed
if sock.fileno() == -1:
return
while True:
# SSLSocket can have buffered data which won't be caught by select.
if hasattr(sock, "pending") and sock.pending() > 0:
readable = True
else:
# Wait up to 500ms for the socket to become readable and then
# check for cancellation.
if deadline:
remaining = deadline - time.monotonic()
# When the timeout has expired perform one final check to
# see if the socket is readable. This helps avoid spurious
# timeouts on AWS Lambda and other FaaS environments.
if remaining <= 0:
timed_out = True
timeout = max(min(remaining, _POLL_TIMEOUT), 0)
else:
timeout = _POLL_TIMEOUT
readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
if conn.cancel_context.cancelled:
raise _OperationCancelled("operation cancelled")
if readable:
return
if timed_out:
raise socket.timeout("timed out")


def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
buf = bytearray(length)
mv = memoryview(buf)
bytes_read = 0
while bytes_read < length:
try:
wait_for_read(conn, deadline)
# CSOT: Update timeout. When the timeout has expired perform one
# final non-blocking recv. This helps avoid spurious timeouts when
# the response is actually already buffered on the client.
if _csot.get_timeout() and deadline is not None:
conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
chunk_length = conn.conn.recv_into(mv[bytes_read:])
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
except OSError as exc:
if _errno_from_exception(exc) == errno.EINTR:
# To support cancelling a network read, we shorten the socket timeout and
# check for the cancellation signal after each timeout. Alternatively we
# could close the socket but that does not reliably cancel recv() calls
# on all OSes.
orig_timeout = conn.conn.gettimeout()
try:
while bytes_read < length:
if deadline is not None:
# CSOT: Update timeout. When the timeout has expired perform one
# final non-blocking recv. This helps avoid spurious timeouts when
# the response is actually already buffered on the client.
short_timeout = min(max(deadline - time.monotonic(), 0), _POLL_TIMEOUT)
else:
short_timeout = _POLL_TIMEOUT
conn.set_conn_timeout(short_timeout)
try:
chunk_length = conn.conn.recv_into(mv[bytes_read:])
except BLOCKING_IO_ERRORS:
if conn.cancel_context.cancelled:
raise _OperationCancelled("operation cancelled") from None
# We reached the true deadline.
raise socket.timeout("timed out") from None
except socket.timeout:
if conn.cancel_context.cancelled:
raise _OperationCancelled("operation cancelled") from None
continue
raise
if chunk_length == 0:
raise OSError("connection closed")

bytes_read += chunk_length
except OSError as exc:
if conn.cancel_context.cancelled:
raise _OperationCancelled("operation cancelled") from None
if _errno_from_exception(exc) == errno.EINTR:
continue
raise
if chunk_length == 0:
raise OSError("connection closed")

bytes_read += chunk_length
finally:
conn.set_conn_timeout(orig_timeout)

return mv
Loading