diff --git a/doc/changelog.rst b/doc/changelog.rst index fce9daea08..237296e1f6 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -9,6 +9,12 @@ Version 4.12.1 is a bug fix release. - Fixed a bug that could raise ``UnboundLocalError`` when creating asynchronous connections over SSL. - Fixed a bug causing SRV hostname validation to fail when resolver and resolved hostnames are identical with three domain levels. - Fixed a bug that caused direct use of ``pymongo.uri_parser`` to raise an ``AttributeError``. +- Fixed a bug where clients created with connect=False and a "mongodb+srv://" connection string + could cause public ``pymongo.MongoClient`` and ``pymongo.AsyncMongoClient`` attributes (topology_description, + nodes, address, primary, secondaries, arbiters) to incorrectly return a Database, leading to type + errors such as: "NotImplementedError: Database objects do not implement truth value testing or bool()". +- Fixed a bug where MongoDB cluster topology changes could cause asynchronous operations to take much longer to complete + due to holding the Topology lock while closing stale connections. Issues Resolved ............... diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index a67cc5f3c8..8b18ab927b 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio import collections import contextlib import logging @@ -860,8 +861,14 @@ async def _reset( # PoolClosedEvent but that reset() SHOULD close sockets *after* # publishing the PoolClearedEvent. 
if close: - for conn in sockets: - await conn.close_conn(ConnectionClosedReason.POOL_CLOSED) + if not _IS_SYNC: + await asyncio.gather( + *[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], + return_exceptions=True, + ) + else: + for conn in sockets: + await conn.close_conn(ConnectionClosedReason.POOL_CLOSED) if self.enabled_for_cmap: assert listeners is not None listeners.publish_pool_closed(self.address) @@ -891,8 +898,14 @@ async def _reset( serverPort=self.address[1], serviceId=service_id, ) - for conn in sockets: - await conn.close_conn(ConnectionClosedReason.STALE) + if not _IS_SYNC: + await asyncio.gather( + *[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], + return_exceptions=True, + ) + else: + for conn in sockets: + await conn.close_conn(ConnectionClosedReason.STALE) async def update_is_writable(self, is_writable: Optional[bool]) -> None: """Updates the is_writable attribute on all sockets currently in the @@ -938,8 +951,14 @@ async def remove_stale_sockets(self, reference_generation: int) -> None: and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds ): close_conns.append(self.conns.pop()) - for conn in close_conns: - await conn.close_conn(ConnectionClosedReason.IDLE) + if not _IS_SYNC: + await asyncio.gather( + *[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns], + return_exceptions=True, + ) + else: + for conn in close_conns: + await conn.close_conn(ConnectionClosedReason.IDLE) while True: async with self.size_cond: diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 32776bf7b9..438dd1e352 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -529,12 +529,6 @@ async def _process_change( if not _IS_SYNC: self._monitor_tasks.append(self._srv_monitor) - # Clear the pool from a failed heartbeat. 
- if reset_pool: - server = self._servers.get(server_description.address) - if server: - await server.pool.reset(interrupt_connections=interrupt_connections) - # Wake anything waiting in select_servers(). self._condition.notify_all() @@ -557,6 +551,11 @@ async def on_change( # that didn't include this server. if self._opened and self._description.has_server(server_description.address): await self._process_change(server_description, reset_pool, interrupt_connections) + # Clear the pool from a failed heartbeat, done outside the lock to avoid blocking on connection close. + if reset_pool: + server = self._servers.get(server_description.address) + if server: + await server.pool.reset(interrupt_connections=interrupt_connections) async def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None: """Process a new seedlist on an opened topology. diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 224834af31..b3eec64f27 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio import collections import contextlib import logging @@ -858,8 +859,14 @@ def _reset( # PoolClosedEvent but that reset() SHOULD close sockets *after* # publishing the PoolClearedEvent. 
if close: - for conn in sockets: - conn.close_conn(ConnectionClosedReason.POOL_CLOSED) + if not _IS_SYNC: + asyncio.gather( + *[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets], + return_exceptions=True, + ) + else: + for conn in sockets: + conn.close_conn(ConnectionClosedReason.POOL_CLOSED) if self.enabled_for_cmap: assert listeners is not None listeners.publish_pool_closed(self.address) @@ -889,8 +896,14 @@ def _reset( serverPort=self.address[1], serviceId=service_id, ) - for conn in sockets: - conn.close_conn(ConnectionClosedReason.STALE) + if not _IS_SYNC: + asyncio.gather( + *[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets], + return_exceptions=True, + ) + else: + for conn in sockets: + conn.close_conn(ConnectionClosedReason.STALE) def update_is_writable(self, is_writable: Optional[bool]) -> None: """Updates the is_writable attribute on all sockets currently in the @@ -934,8 +947,14 @@ def remove_stale_sockets(self, reference_generation: int) -> None: and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds ): close_conns.append(self.conns.pop()) - for conn in close_conns: - conn.close_conn(ConnectionClosedReason.IDLE) + if not _IS_SYNC: + asyncio.gather( + *[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns], + return_exceptions=True, + ) + else: + for conn in close_conns: + conn.close_conn(ConnectionClosedReason.IDLE) while True: with self.size_cond: diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index df23bff28c..1e99adf726 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -529,12 +529,6 @@ def _process_change( if not _IS_SYNC: self._monitor_tasks.append(self._srv_monitor) - # Clear the pool from a failed heartbeat. - if reset_pool: - server = self._servers.get(server_description.address) - if server: - server.pool.reset(interrupt_connections=interrupt_connections) - # Wake anything waiting in select_servers(). 
self._condition.notify_all() @@ -557,6 +551,11 @@ def on_change( # that didn't include this server. if self._opened and self._description.has_server(server_description.address): self._process_change(server_description, reset_pool, interrupt_connections) + # Clear the pool from a failed heartbeat, done outside the lock to avoid blocking on connection close. + if reset_pool: + server = self._servers.get(server_description.address) + if server: + server.pool.reset(interrupt_connections=interrupt_connections) def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None: """Process a new seedlist on an opened topology. diff --git a/test/__init__.py b/test/__init__.py index d8686e3257..ae5d60a384 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -823,6 +823,14 @@ def require_sync(self, func): lambda: _IS_SYNC, "This test only works with the synchronous API", func=func ) + def require_async(self, func): + """Run a test only if using the asynchronous API.""" # unasync: off + return self._require( + lambda: not _IS_SYNC, + "This test only works with the asynchronous API", # unasync: off + func=func, + ) + def mongos_seeds(self): return ",".join("{}:{}".format(*address) for address in self.mongoses) diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 9e9cb9316d..b772da3126 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -825,6 +825,14 @@ def require_sync(self, func): lambda: _IS_SYNC, "This test only works with the synchronous API", func=func ) + def require_async(self, func): + """Run a test only if using the asynchronous API.""" # unasync: off + return self._require( + lambda: not _IS_SYNC, + "This test only works with the asynchronous API", # unasync: off + func=func, + ) + def mongos_seeds(self): return ",".join("{}:{}".format(*address) for address in self.mongoses) diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index 
fa62b25dd1..cf26faf248 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -20,10 +20,15 @@ import socketserver import sys import threading +import time from asyncio import StreamReader, StreamWriter from pathlib import Path from test.asynchronous.helpers import ConcurrentRunner +from pymongo.asynchronous.pool import AsyncConnection +from pymongo.operations import _Op +from pymongo.server_selectors import writable_server_selector + sys.path[0:0] = [""] from test.asynchronous import ( @@ -370,6 +375,74 @@ async def test_pool_unpause(self): await listener.async_wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1) await listener.async_wait_for_event(monitoring.PoolReadyEvent, 1) + @async_client_context.require_failCommand_appName + @async_client_context.require_test_commands + @async_client_context.require_async + async def test_connection_close_does_not_block_other_operations(self): + listener = CMAPHeartbeatListener() + client = await self.async_single_client( + appName="SDAMConnectionCloseTest", + event_listeners=[listener], + heartbeatFrequencyMS=500, + minPoolSize=10, + ) + server = await (await client._get_topology()).select_server( + writable_server_selector, _Op.TEST + ) + await async_wait_until( + lambda: len(server._pool.conns) == 10, + "pool initialized with 10 connections", + ) + + await client.db.test.insert_one({"x": 1}) + close_delay = 0.1 + latencies = [] + should_exit = [] + + async def run_task(): + while True: + start_time = time.monotonic() + await client.db.test.find_one({}) + elapsed = time.monotonic() - start_time + latencies.append(elapsed) + if should_exit: + break + await asyncio.sleep(0.001) + + task = ConcurrentRunner(target=run_task) + await task.start() + original_close = AsyncConnection.close_conn + try: + # Artificially delay the close operation to simulate a slow close + async def mock_close(self, reason): + await asyncio.sleep(close_delay) + await 
original_close(self, reason) + + AsyncConnection.close_conn = mock_close + + fail_hello = { + "mode": {"times": 4}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "errorCode": 91, + "appName": "SDAMConnectionCloseTest", + }, + } + async with self.fail_point(fail_hello): + # Wait for server heartbeat to fail + await listener.async_wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1) + # Wait until all idle connections are closed to simulate real-world conditions + await listener.async_wait_for_event(monitoring.ConnectionClosedEvent, 10) + # Wait for one more find to complete after the pool has been reset, then shutdown the task + n = len(latencies) + await async_wait_until(lambda: len(latencies) >= n + 1, "run one more find") + should_exit.append(True) + await task.join() + # No operation latency should significantly exceed close_delay + self.assertLessEqual(max(latencies), close_delay * 5.0) + finally: + AsyncConnection.close_conn = original_close + class TestServerMonitoringMode(AsyncIntegrationTest): @async_client_context.require_no_serverless diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 07720473ca..9d6c945707 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -20,10 +20,15 @@ import socketserver import sys import threading +import time from asyncio import StreamReader, StreamWriter from pathlib import Path from test.helpers import ConcurrentRunner +from pymongo.operations import _Op +from pymongo.server_selectors import writable_server_selector +from pymongo.synchronous.pool import Connection + sys.path[0:0] = [""] from test import ( @@ -370,6 +375,72 @@ def test_pool_unpause(self): listener.wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1) listener.wait_for_event(monitoring.PoolReadyEvent, 1) + @client_context.require_failCommand_appName + @client_context.require_test_commands + @client_context.require_async + def 
test_connection_close_does_not_block_other_operations(self): + listener = CMAPHeartbeatListener() + client = self.single_client( + appName="SDAMConnectionCloseTest", + event_listeners=[listener], + heartbeatFrequencyMS=500, + minPoolSize=10, + ) + server = (client._get_topology()).select_server(writable_server_selector, _Op.TEST) + wait_until( + lambda: len(server._pool.conns) == 10, + "pool initialized with 10 connections", + ) + + client.db.test.insert_one({"x": 1}) + close_delay = 0.1 + latencies = [] + should_exit = [] + + def run_task(): + while True: + start_time = time.monotonic() + client.db.test.find_one({}) + elapsed = time.monotonic() - start_time + latencies.append(elapsed) + if should_exit: + break + time.sleep(0.001) + + task = ConcurrentRunner(target=run_task) + task.start() + original_close = Connection.close_conn + try: + # Artificially delay the close operation to simulate a slow close + def mock_close(self, reason): + time.sleep(close_delay) + original_close(self, reason) + + Connection.close_conn = mock_close + + fail_hello = { + "mode": {"times": 4}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "errorCode": 91, + "appName": "SDAMConnectionCloseTest", + }, + } + with self.fail_point(fail_hello): + # Wait for server heartbeat to fail + listener.wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1) + # Wait until all idle connections are closed to simulate real-world conditions + listener.wait_for_event(monitoring.ConnectionClosedEvent, 10) + # Wait for one more find to complete after the pool has been reset, then shutdown the task + n = len(latencies) + wait_until(lambda: len(latencies) >= n + 1, "run one more find") + should_exit.append(True) + task.join() + # No operation latency should significantly exceed close_delay + self.assertLessEqual(max(latencies), close_delay * 5.0) + finally: + Connection.close_conn = original_close + class TestServerMonitoringMode(IntegrationTest): @client_context.require_no_serverless 
diff --git a/tools/synchro.py b/tools/synchro.py index f6176e2038..bfe8f71125 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -288,7 +288,8 @@ def process_files( if file in docstring_translate_files: lines = translate_docstrings(lines) if file in sync_test_files: - translate_imports(lines) + lines = translate_imports(lines) + lines = process_ignores(lines) f.seek(0) f.writelines(lines) f.truncate() @@ -390,6 +391,14 @@ def translate_docstrings(lines: list[str]) -> list[str]: return [line for line in lines if line != "DOCSTRING_REMOVED"] +def process_ignores(lines: list[str]) -> list[str]: + for i in range(len(lines)): + for k, v in replacements.items(): + if "unasync: off" in lines[i] and v in lines[i]: + lines[i] = lines[i].replace(v, k) + return lines + + def unasync_directory(files: list[str], src: str, dest: str, replacements: dict[str, str]) -> None: unasync_files( files,