Skip to content

PYTHON-4636 - Avoid blocking I/O calls in async code paths #1870

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 5 additions & 76 deletions pymongo/asynchronous/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@
"""Internal network layer helper methods."""
from __future__ import annotations

import asyncio
import datetime
import errno
import logging
import socket
import time
from typing import (
TYPE_CHECKING,
Expand All @@ -40,19 +37,16 @@
NotPrimaryError,
OperationFailure,
ProtocolError,
_OperationCancelled,
)
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.monitoring import _is_speculative_authenticate
from pymongo.network_layer import (
_POLL_TIMEOUT,
_UNPACK_COMPRESSION_HEADER,
_UNPACK_HEADER,
BLOCKING_IO_ERRORS,
async_receive_data,
async_sendall,
)
from pymongo.socket_checker import _errno_from_exception

if TYPE_CHECKING:
from bson import CodecOptions
Expand Down Expand Up @@ -318,9 +312,7 @@ async def receive_message(
else:
deadline = None
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(
await _receive_data_on_socket(conn, 16, deadline)
)
length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data(conn, 16, deadline))
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
Expand All @@ -336,11 +328,11 @@ async def receive_message(
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
await _receive_data_on_socket(conn, 9, deadline)
await async_receive_data(conn, 9, deadline)
)
data = decompress(await _receive_data_on_socket(conn, length - 25, deadline), compressor_id)
data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id)
else:
data = await _receive_data_on_socket(conn, length - 16, deadline)
data = await async_receive_data(conn, length - 16, deadline)

try:
unpack_reply = _UNPACK_REPLY[op_code]
Expand All @@ -349,66 +341,3 @@ async def receive_message(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)


async def wait_for_read(conn: AsyncConnection, deadline: Optional[float]) -> None:
    """Wait until the connection's socket is readable, the deadline passes, or
    the operation is cancelled.

    Returns immediately if the socket has already been closed (fileno -1).
    Raises ``_OperationCancelled`` when ``conn.cancel_context.cancelled`` is
    set and ``socket.timeout`` once the deadline has passed without the socket
    becoming readable.
    """
    raw = conn.conn
    # A manually closed socket reports fd -1; there is nothing to wait for.
    if raw.fileno() == -1:
        return
    expired = False
    while True:
        # SSLSocket can hold buffered data that select() will not report.
        if hasattr(raw, "pending") and raw.pending() > 0:
            ready = True
        else:
            # Poll for at most _POLL_TIMEOUT so cancellation is noticed promptly.
            if not deadline:
                wait = _POLL_TIMEOUT
            else:
                left = deadline - time.monotonic()
                # Once the deadline has passed, do one last zero-timeout check
                # before giving up. This avoids spurious timeouts on AWS
                # Lambda and other FaaS environments.
                expired = expired or left <= 0
                wait = max(min(left, _POLL_TIMEOUT), 0)
            ready = conn.socket_checker.select(raw, read=True, timeout=wait)
        if conn.cancel_context.cancelled:
            raise _OperationCancelled("operation cancelled")
        if ready:
            return
        if expired:
            raise socket.timeout("timed out")
        # Yield to the event loop between polls.
        await asyncio.sleep(0)


async def _receive_data_on_socket(
conn: AsyncConnection, length: int, deadline: Optional[float]
) -> memoryview:
buf = bytearray(length)
mv = memoryview(buf)
bytes_read = 0
while bytes_read < length:
try:
await wait_for_read(conn, deadline)
# CSOT: Update timeout. When the timeout has expired perform one
# final non-blocking recv. This helps avoid spurious timeouts when
# the response is actually already buffered on the client.
if _csot.get_timeout() and deadline is not None:
conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
chunk_length = conn.conn.recv_into(mv[bytes_read:])
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
except OSError as exc:
if _errno_from_exception(exc) == errno.EINTR:
continue
raise
if chunk_length == 0:
raise OSError("connection closed")

bytes_read += chunk_length

return mv
187 changes: 186 additions & 1 deletion pymongo/network_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@
from __future__ import annotations

import asyncio
import errno
import socket
import struct
import sys
import time
from asyncio import AbstractEventLoop, Future
from typing import (
TYPE_CHECKING,
Optional,
Union,
)

from pymongo import ssl_support
from pymongo import _csot, ssl_support
from pymongo.errors import _OperationCancelled
from pymongo.socket_checker import _errno_from_exception

try:
from ssl import SSLError, SSLSocket
Expand All @@ -51,6 +57,10 @@
BLOCKING_IO_WRITE_ERROR,
)

if TYPE_CHECKING:
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.synchronous.pool import Connection

_UNPACK_HEADER = struct.Struct("<iiii").unpack
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
_POLL_TIMEOUT = 0.5
Expand Down Expand Up @@ -111,6 +121,47 @@ def _is_ready(fut: Future) -> None:
loop.add_reader(fd, _is_ready, fut)
loop.add_writer(fd, _is_ready, fut)
await fut

async def _async_receive_ssl(
conn: _sslConn, length: int, loop: AbstractEventLoop
) -> memoryview:
mv = memoryview(bytearray(length))
fd = conn.fileno()
total_read = 0

def _is_ready(fut: Future) -> None:
loop.remove_writer(fd)
loop.remove_reader(fd)
if fut.done():
return
fut.set_result(None)

while total_read < length:
try:
read = conn.recv_into(mv[total_read:])
if read == 0:
raise OSError("connection closed")
total_read += read
except BLOCKING_IO_ERRORS as exc:
fd = conn.fileno()
# Check for closed socket.
if fd == -1:
raise SSLError("Underlying socket has been closed") from None
if isinstance(exc, BLOCKING_IO_READ_ERROR):
fut = loop.create_future()
loop.add_reader(fd, _is_ready, fut)
await fut
if isinstance(exc, BLOCKING_IO_WRITE_ERROR):
fut = loop.create_future()
loop.add_writer(fd, _is_ready, fut)
await fut
if _HAVE_PYOPENSSL and isinstance(exc, BLOCKING_IO_LOOKUP_ERROR):
fut = loop.create_future()
loop.add_reader(fd, _is_ready, fut)
loop.add_writer(fd, _is_ready, fut)
await fut
return mv

else:
# The default Windows asyncio event loop does not support loop.add_reader/add_writer:
# https://docs.python.org/3/library/asyncio-platforms.html#asyncio-platform-support
Expand All @@ -128,6 +179,140 @@ async def _async_sendall_ssl(
sent = 0
total_sent += sent

async def _async_receive_ssl(
conn: _sslConn, length: int, dummy: AbstractEventLoop
) -> memoryview:
mv = memoryview(bytearray(length))
total_read = 0
while total_read < length:
try:
read = conn.recv_into(mv[total_read:])
except BLOCKING_IO_ERRORS:
await asyncio.sleep(0.5)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RIP windows perf. Can you schedule some of the windows tasks to see the impact?

For anyone else following along we'll fix this when we migrate to asyncio streams in: https://jira.mongodb.org/browse/PYTHON-4493

Copy link
Contributor Author

@NoahStapp NoahStapp Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't a new change, it's already in master:

await asyncio.sleep(0.5)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right but the existing code was only for send() whereas this new code is for recv so the perf impact could be different.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I'll schedule some Windows tasks.

read = 0
total_read += read
return mv


def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
    """Send all of ``buf`` on ``sock`` by delegating to ``sock.sendall``."""
    sock.sendall(buf)


async def _poll_cancellation(conn: AsyncConnection) -> None:
    """Return once ``conn.cancel_context.cancelled`` is set.

    The flag is checked first and then re-checked every ``_POLL_TIMEOUT``
    seconds.
    """
    while not conn.cancel_context.cancelled:
        await asyncio.sleep(_POLL_TIMEOUT)


async def async_receive_data(
    conn: AsyncConnection, length: int, deadline: Optional[float]
) -> memoryview:
    """Read exactly ``length`` bytes from ``conn`` without blocking the event loop.

    The socket is switched to non-blocking mode and the read runs as a task
    raced against a task polling ``conn.cancel_context``. The socket's
    original timeout is restored before returning or raising.

    :param conn: The connection to read from; ``conn.conn`` is the socket.
    :param length: Number of bytes to read.
    :param deadline: Absolute ``time.monotonic()`` deadline. When falsy, the
        socket's own timeout is used as the wait timeout instead.
    :return: A memoryview over the bytes read.
    :raises socket.timeout: If neither task finished within the timeout.
    :raises _OperationCancelled: If the cancel context fired before the read
        completed.
    """
    sock = conn.conn
    sock_timeout = sock.gettimeout()
    timeout: Optional[Union[float, int]]
    if deadline:
        # When the timeout has expired perform one final check to
        # see if the socket is readable. This helps avoid spurious
        # timeouts on AWS Lambda and other FaaS environments.
        timeout = max(deadline - time.monotonic(), 0)
    else:
        timeout = sock_timeout

    sock.settimeout(0.0)
    loop = asyncio.get_running_loop()
    tasks: list[asyncio.Task] = [asyncio.create_task(_poll_cancellation(conn))]
    try:
        if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
            read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop))  # type: ignore[arg-type]
        else:
            read_task = asyncio.create_task(_async_receive(sock, length, loop))  # type: ignore[arg-type]
        tasks.append(read_task)
        done, pending = await asyncio.wait(
            tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
        )
        for task in pending:
            task.cancel()
        if pending:
            # Let the cancellations finish so the read task has released the
            # socket before its timeout is restored below.
            await asyncio.wait(pending)
        if not done:
            raise socket.timeout("timed out")
        # Prefer the read result when both tasks finished in the same loop
        # iteration; the order of the ``done`` set is arbitrary.
        if read_task in done:
            return read_task.result()
        raise _OperationCancelled("operation cancelled")
    finally:
        # Also reached when this coroutine is itself cancelled while waiting:
        # asyncio.wait() does not cancel its tasks, so do it here. Cancelling
        # an already finished task is a no-op.
        for task in tasks:
            task.cancel()
        sock.settimeout(sock_timeout)


async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview:
mv = memoryview(bytearray(length))
bytes_read = 0
while bytes_read < length:
chunk_length = await loop.sock_recv_into(conn, mv[bytes_read:])
if chunk_length == 0:
raise OSError("connection closed")
bytes_read += chunk_length
return mv


# Sync version:
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
    """Block until the connection's socket is readable, the deadline passes,
    or the operation is cancelled.

    Returns immediately if the socket has already been closed (fileno -1).
    Raises ``_OperationCancelled`` when ``conn.cancel_context.cancelled`` is
    set and ``socket.timeout`` once the deadline has passed without the socket
    becoming readable.
    """
    sock = conn.conn
    # A manually closed socket reports fd -1; there is nothing to poll.
    if sock.fileno() == -1:
        return
    deadline_passed = False
    while True:
        # SSLSocket can have buffered data which won't be caught by select.
        if hasattr(sock, "pending") and sock.pending() > 0:
            is_readable = True
        else:
            # Wait up to _POLL_TIMEOUT for readability, then check for
            # cancellation.
            if not deadline:
                poll_for = _POLL_TIMEOUT
            else:
                time_left = deadline - time.monotonic()
                # When the timeout has expired perform one final check to
                # see if the socket is readable. This helps avoid spurious
                # timeouts on AWS Lambda and other FaaS environments.
                if time_left <= 0:
                    deadline_passed = True
                poll_for = max(min(time_left, _POLL_TIMEOUT), 0)
            is_readable = conn.socket_checker.select(sock, read=True, timeout=poll_for)
        if conn.cancel_context.cancelled:
            raise _OperationCancelled("operation cancelled")
        if is_readable:
            return
        if deadline_passed:
            raise socket.timeout("timed out")


def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
    """Read exactly ``length`` bytes from ``conn`` and return them as a memoryview.

    Raises ``socket.timeout`` when the read would block, and
    ``OSError("connection closed")`` when the peer closes the connection
    before ``length`` bytes arrive. Reads interrupted by EINTR are retried.
    """
    view = memoryview(bytearray(length))
    offset = 0
    while offset < length:
        try:
            wait_for_read(conn, deadline)
            # CSOT: Update timeout. When the timeout has expired perform one
            # final non-blocking recv. This helps avoid spurious timeouts when
            # the response is actually already buffered on the client.
            if _csot.get_timeout() and deadline is not None:
                conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
            got = conn.conn.recv_into(view[offset:])
        except BLOCKING_IO_ERRORS:
            raise socket.timeout("timed out") from None
        except OSError as exc:
            # Only an interrupted system call is retried; anything else propagates.
            if _errno_from_exception(exc) != errno.EINTR:
                raise
            continue
        if got == 0:
            raise OSError("connection closed")
        offset += got
    return view
11 changes: 9 additions & 2 deletions pymongo/pyopenssl_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,16 @@ def _ragged_eof(exc: BaseException) -> bool:
# https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets
class _sslConn(_SSL.Connection):
def __init__(
self, ctx: _SSL.Context, sock: Optional[_socket.socket], suppress_ragged_eofs: bool
self,
ctx: _SSL.Context,
sock: Optional[_socket.socket],
suppress_ragged_eofs: bool,
is_async: bool = False,
):
self.socket_checker = _SocketChecker()
self.suppress_ragged_eofs = suppress_ragged_eofs
super().__init__(ctx, sock)
self._is_async = is_async

def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
timeout = self.gettimeout()
Expand All @@ -119,6 +124,8 @@ def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
try:
return call(*args, **kwargs)
except BLOCKING_IO_ERRORS as exc:
if self._is_async:
raise exc
# Check for closed socket.
if self.fileno() == -1:
if timeout and _time.monotonic() - start > timeout:
Expand Down Expand Up @@ -381,7 +388,7 @@ async def a_wrap_socket(
"""Wrap an existing Python socket connection and return a TLS socket
object.
"""
ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)
ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs, True)
loop = asyncio.get_running_loop()
if session:
ssl_conn.set_session(session)
Expand Down
Loading
Loading