Skip to content

PYTHON-4324 CSOT avoid connection churn when operations timeout #2269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pymongo/asynchronous/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,10 @@ async def command(
spec = orig = await client._encrypter.encrypt(dbname, spec, codec_options)

# Support CSOT
applied_csot = False
if client:
conn.apply_timeout(client, spec)
res = conn.apply_timeout(client, spec)
applied_csot = bool(res)
_csot.apply_write_concern(spec, write_concern)

if use_op_msg:
Expand Down Expand Up @@ -195,7 +197,7 @@ async def command(
reply = None
response_doc: _DocumentOut = {"ok": 1}
else:
reply = await async_receive_message(conn, request_id)
reply = await async_receive_message(conn, request_id, enable_pending=applied_csot)
conn.more_to_come = reply.more_to_come
unpacked_docs = reply.unpack_response(
codec_options=codec_options, user_fields=user_fields
Expand Down
58 changes: 52 additions & 6 deletions pymongo/asynchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
)

from bson import DEFAULT_CODEC_OPTIONS
from pymongo import _csot, helpers_shared
from pymongo import _csot, helpers_shared, network_layer
from pymongo.asynchronous.client_session import _validate_session_write_concern
from pymongo.asynchronous.helpers import _handle_reauth
from pymongo.asynchronous.network import command
Expand Down Expand Up @@ -188,6 +188,42 @@ def __init__(
self.creation_time = time.monotonic()
# For gossiping $clusterTime from the connection handshake to the client.
self._cluster_time = None
self.pending_response = False
self.pending_bytes = 0
self.pending_deadline = 0.0

def mark_pending(self, nbytes: int) -> None:
"""Mark this connection as having a pending response."""
self.pending_response = True
self.pending_bytes = nbytes
self.pending_deadline = time.monotonic() + 3 # 3 seconds timeout for pending response

async def complete_pending(self) -> None:
"""Complete a pending response."""
if not self.pending_response:
return

if _csot.get_timeout():
deadline = min(_csot.get_deadline(), self.pending_deadline)
else:
timeout = self.conn.gettimeout
if timeout is not None:
deadline = min(time.monotonic() + timeout, self.pending_deadline)
else:
deadline = self.pending_deadline

if not _IS_SYNC:
# In async the reader task reads the whole message at once.
# TODO: respect deadline
await self.receive_message(None, True)
else:
try:
network_layer.receive_data(self, self.pending_bytes, deadline, True) # type:ignore[arg-type]
except BaseException as error:
await self._raise_connection_failure(error)
self.pending_response = False
self.pending_bytes = 0
self.pending_deadline = 0.0

def set_conn_timeout(self, timeout: Optional[float]) -> None:
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
Expand Down Expand Up @@ -454,13 +490,17 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None:
except BaseException as error:
await self._raise_connection_failure(error)

async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]:
async def receive_message(
self, request_id: Optional[int], enable_pending: bool = False
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise ConnectionFailure.

If any exception is raised, the socket is closed.
"""
try:
return await async_receive_message(self, request_id, self.max_message_size)
return await async_receive_message(
self, request_id, self.max_message_size, enable_pending
)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
await self._raise_connection_failure(error)
Expand Down Expand Up @@ -495,7 +535,9 @@ async def write_command(
:param msg: bytes, the command message.
"""
await self.send_message(msg, 0)
reply = await self.receive_message(request_id)
reply = await self.receive_message(
request_id, enable_pending=(_csot.get_timeout() is not None)
)
result = reply.command_response(codec_options)

# Raises NotPrimaryError or OperationFailure.
Expand Down Expand Up @@ -635,7 +677,10 @@ async def _raise_connection_failure(self, error: BaseException) -> NoReturn:
reason = None
else:
reason = ConnectionClosedReason.ERROR
await self.close_conn(reason)

# Pending connections should be placed back in the pool.
if not self.pending_response:
await self.close_conn(reason)
# SSLError from PyOpenSSL inherits directly from Exception.
if isinstance(error, (IOError, OSError, SSLError)):
details = _get_timeout_details(self.opts)
Expand Down Expand Up @@ -1076,7 +1121,7 @@ async def checkout(

This method should always be used in a with-statement::

with pool.get_conn() as connection:
with pool.checkout() as connection:
connection.send_message(msg)
data = connection.receive_message(op_code, request_id)

Expand Down Expand Up @@ -1388,6 +1433,7 @@ async def _perished(self, conn: AsyncConnection) -> bool:
pool, to keep performance reasonable - we can't avoid AutoReconnects
completely anyway.
"""
await conn.complete_pending()
idle_time_seconds = conn.idle_time_seconds()
# If socket is idle, open a new one.
if (
Expand Down
2 changes: 1 addition & 1 deletion pymongo/asynchronous/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ async def run_operation(
reply = await conn.receive_message(None)
else:
await conn.send_message(data, max_doc_size)
reply = await conn.receive_message(request_id)
reply = await conn.receive_message(request_id, operation.pending_enabled())

# Unpack and check for command errors.
if use_cmd:
Expand Down
18 changes: 16 additions & 2 deletions pymongo/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,6 +1569,7 @@ class _Query:
"allow_disk_use",
"_as_command",
"exhaust",
"_pending_enabled",
)

# For compatibility with the _GetMore class.
Expand Down Expand Up @@ -1612,6 +1613,10 @@ def __init__(
self.name = "find"
self._as_command: Optional[tuple[dict[str, Any], str]] = None
self.exhaust = exhaust
self._pending_enabled = False

def pending_enabled(self) -> bool:
return self._pending_enabled

def reset(self) -> None:
self._as_command = None
Expand Down Expand Up @@ -1673,7 +1678,9 @@ def as_command(
conn.send_cluster_time(cmd, self.session, self.client) # type: ignore[arg-type]
# Support CSOT
if apply_timeout:
conn.apply_timeout(self.client, cmd=cmd) # type: ignore[arg-type]
res = conn.apply_timeout(self.client, cmd=cmd) # type: ignore[arg-type]
if res is not None:
self._pending_enabled = True
self._as_command = cmd, self.db
return self._as_command

Expand Down Expand Up @@ -1747,6 +1754,7 @@ class _GetMore:
"_as_command",
"exhaust",
"comment",
"_pending_enabled",
)

name = "getMore"
Expand Down Expand Up @@ -1779,6 +1787,10 @@ def __init__(
self._as_command: Optional[tuple[dict[str, Any], str]] = None
self.exhaust = exhaust
self.comment = comment
self._pending_enabled = False

def pending_enabled(self) -> bool:
return self._pending_enabled

def reset(self) -> None:
self._as_command = None
Expand Down Expand Up @@ -1822,7 +1834,9 @@ def as_command(
conn.send_cluster_time(cmd, self.session, self.client) # type: ignore[arg-type]
# Support CSOT
if apply_timeout:
conn.apply_timeout(self.client, cmd=None) # type: ignore[arg-type]
res = conn.apply_timeout(self.client, cmd=None) # type: ignore[arg-type]
if res is not None:
self._pending_enabled = True
self._as_command = cmd, self.db
return self._as_command

Expand Down
33 changes: 25 additions & 8 deletions pymongo/network_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,9 @@ def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
raise socket.timeout("timed out")


def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
def receive_data(
conn: Connection, length: int, deadline: Optional[float], enable_pending: bool = False
) -> memoryview:
buf = bytearray(length)
mv = memoryview(buf)
bytes_read = 0
Expand All @@ -336,7 +338,7 @@ def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> me
# When the timeout has expired we perform one final non-blocking recv.
# This helps avoid spurious timeouts when the response is actually already
# buffered on the client.
orig_timeout = conn.conn.gettimeout()
orig_timeout = conn.conn.gettimeout
try:
while bytes_read < length:
try:
Expand All @@ -357,12 +359,16 @@ def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> me
if conn.cancel_context.cancelled:
raise _OperationCancelled("operation cancelled") from None
# We reached the true deadline.
if enable_pending:
conn.mark_pending(length - bytes_read)
raise socket.timeout("timed out") from None
except socket.timeout:
if conn.cancel_context.cancelled:
raise _OperationCancelled("operation cancelled") from None
if _PYPY:
# We reached the true deadline.
if enable_pending:
conn.mark_pending(length - bytes_read)
raise
continue
except OSError as exc:
Expand Down Expand Up @@ -438,6 +444,7 @@ class NetworkingInterface(NetworkingInterfaceBase):
def __init__(self, conn: Union[socket.socket, _sslConn]):
super().__init__(conn)

@property
def gettimeout(self) -> float | None:
return self.conn.gettimeout()

Expand Down Expand Up @@ -692,6 +699,7 @@ async def async_receive_message(
conn: AsyncConnection,
request_id: Optional[int],
max_message_size: int = MAX_MESSAGE_SIZE,
enable_pending: bool = False,
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
timeout: Optional[Union[float, int]]
Expand Down Expand Up @@ -721,6 +729,8 @@ async def async_receive_message(
if pending:
await asyncio.wait(pending)
if len(done) == 0:
if enable_pending:
conn.mark_pending(1)
raise socket.timeout("timed out")
if read_task in done:
data, op_code = read_task.result()
Expand All @@ -740,19 +750,24 @@ async def async_receive_message(


def receive_message(
conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE
conn: Connection,
request_id: Optional[int],
max_message_size: int = MAX_MESSAGE_SIZE,
enable_pending: bool = False,
) -> Union[_OpReply, _OpMsg]:
"""Receive a raw BSON message or raise socket.error."""
if _csot.get_timeout():
deadline = _csot.get_deadline()
else:
timeout = conn.conn.gettimeout()
timeout = conn.conn.gettimeout
if timeout:
deadline = time.monotonic() + timeout
else:
deadline = None
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline))
length, _, response_to, op_code = _UNPACK_HEADER(
receive_data(conn, 16, deadline, enable_pending)
)
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
Expand All @@ -767,10 +782,12 @@ def receive_message(
f"message size ({max_message_size!r})"
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
receive_data(conn, 9, deadline, enable_pending)
)
data = decompress(receive_data(conn, length - 25, deadline, enable_pending), compressor_id)
else:
data = receive_data(conn, length - 16, deadline)
data = receive_data(conn, length - 16, deadline, enable_pending)

try:
unpack_reply = _UNPACK_REPLY[op_code]
Expand Down
6 changes: 4 additions & 2 deletions pymongo/synchronous/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,10 @@ def command(
spec = orig = client._encrypter.encrypt(dbname, spec, codec_options)

# Support CSOT
applied_csot = False
if client:
conn.apply_timeout(client, spec)
res = conn.apply_timeout(client, spec)
applied_csot = bool(res)
_csot.apply_write_concern(spec, write_concern)

if use_op_msg:
Expand Down Expand Up @@ -195,7 +197,7 @@ def command(
reply = None
response_doc: _DocumentOut = {"ok": 1}
else:
reply = receive_message(conn, request_id)
reply = receive_message(conn, request_id, enable_pending=applied_csot)
conn.more_to_come = reply.more_to_come
unpacked_docs = reply.unpack_response(
codec_options=codec_options, user_fields=user_fields
Expand Down
Loading
Loading