diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py
index 5d3cfcd832..15d42cfeb2 100644
--- a/pymongo/mongo_client.py
+++ b/pymongo/mongo_client.py
@@ -1277,6 +1277,7 @@ def _select_server(
         server_selector: Callable[[Selection], Selection],
         session: Optional[ClientSession],
         address: Optional[_Address] = None,
+        deprioritized_servers: Optional[list[Server]] = None,
     ) -> Server:
         """Select a server to run an operation on this client.
 
@@ -1300,7 +1301,9 @@
                 if not server:
                     raise AutoReconnect("server %s:%s no longer available" % address)  # noqa: UP031
             else:
-                server = topology.select_server(server_selector)
+                server = topology.select_server(
+                    server_selector, deprioritized_servers=deprioritized_servers
+                )
             return server
         except PyMongoError as exc:
             # Server selection errors in a transaction are transient.
@@ -2291,6 +2294,7 @@ def __init__(
         )
         self._address = address
         self._server: Server = None  # type: ignore
+        self._deprioritized_servers: list[Server] = []
 
     def run(self) -> T:
         """Runs the supplied func() and attempts a retry
@@ -2359,6 +2363,9 @@
                     if self._last_error is None:
                         self._last_error = exc
 
+                if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
+                    self._deprioritized_servers.append(self._server)
+
     def _is_not_eligible_for_retry(self) -> bool:
         """Checks if the exchange is not eligible for retry"""
         return not self._retryable or (self._is_retrying() and not self._multiple_retries)
@@ -2397,7 +2404,10 @@ def _get_server(self) -> Server:
             Abstraction to connect to server
         """
         return self._client._select_server(
-            self._server_selector, self._session, address=self._address
+            self._server_selector,
+            self._session,
+            address=self._address,
+            deprioritized_servers=self._deprioritized_servers,
         )
 
     def _write(self) -> T:
diff --git a/pymongo/topology.py b/pymongo/topology.py
index 786be3ec93..092c7d92af 100644
--- a/pymongo/topology.py
+++ b/pymongo/topology.py
@@ -282,8 +282,10 @@ def _select_server(
         selector: Callable[[Selection], Selection],
         server_selection_timeout: Optional[float] = None,
         address: Optional[_Address] = None,
+        deprioritized_servers: Optional[list[Server]] = None,
     ) -> Server:
         servers = self.select_servers(selector, server_selection_timeout, address)
+        servers = _filter_servers(servers, deprioritized_servers)
         if len(servers) == 1:
             return servers[0]
         server1, server2 = random.sample(servers, 2)
@@ -297,9 +299,12 @@ def select_server(
         selector: Callable[[Selection], Selection],
         server_selection_timeout: Optional[float] = None,
         address: Optional[_Address] = None,
+        deprioritized_servers: Optional[list[Server]] = None,
     ) -> Server:
         """Like select_servers, but choose a random server if several match."""
-        server = self._select_server(selector, server_selection_timeout, address)
+        server = self._select_server(
+            selector, server_selection_timeout, address, deprioritized_servers
+        )
         if _csot.get_timeout():
             _csot.set_rtt(server.description.min_round_trip_time)
         return server
@@ -931,3 +936,16 @@ def _is_stale_server_description(current_sd: ServerDescription, new_sd: ServerDe
     if current_tv["processId"] != new_tv["processId"]:
         return False
     return current_tv["counter"] > new_tv["counter"]
+
+
+def _filter_servers(
+    candidates: list[Server], deprioritized_servers: Optional[list[Server]] = None
+) -> list[Server]:
+    """Filter out deprioritized servers from a list of server candidates."""
+    if not deprioritized_servers:
+        return candidates
+
+    filtered = [server for server in candidates if server not in deprioritized_servers]
+
+    # If not possible to pick a prioritized server, return the original list
+    return filtered or candidates
diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py
index 8779ea1ed8..e3028688d7 100644
--- a/test/test_retryable_reads.py
+++ b/test/test_retryable_reads.py
@@ -20,6 +20,9 @@
 import sys
 import threading
 
+from bson import SON
+from pymongo.errors import AutoReconnect
+
 sys.path[0:0] = [""]
 
 from test import (
@@ -31,9 +34,12 @@
 )
 from test.utils import (
     CMAPListener,
+    EventListener,
     OvertCommandListener,
     SpecTestCreator,
+    rs_client,
     rs_or_single_client,
+    set_fail_point,
 )
 from test.utils_spec_runner import SpecRunner
 
@@ -221,5 +227,48 @@ def test_pool_paused_error_is_retryable(self):
         self.assertEqual(1, len(failed), msg)
 
+
+class TestRetryableReads(IntegrationTest):
+    @client_context.require_multiple_mongoses
+    @client_context.require_failCommand_fail_point
+    def test_retryable_reads_in_sharded_cluster_multiple_available(self):
+        fail_command = {
+            "configureFailPoint": "failCommand",
+            "mode": {"times": 1},
+            "data": {
+                "failCommands": ["find"],
+                "closeConnection": True,
+                "appName": "retryableReadTest",
+            },
+        }
+
+        mongos_clients = []
+
+        for mongos in client_context.mongos_seeds().split(","):
+            client = rs_or_single_client(mongos)
+            set_fail_point(client, fail_command)
+            self.addCleanup(client.close)
+            mongos_clients.append(client)
+
+        listener = OvertCommandListener()
+        client = rs_or_single_client(
+            client_context.mongos_seeds(),
+            appName="retryableReadTest",
+            event_listeners=[listener],
+            retryReads=True,
+        )
+
+        with self.fail_point(fail_command):
+            with self.assertRaises(AutoReconnect):
+                client.t.t.find_one({})
+
+        # Disable failpoints on each mongos
+        for client in mongos_clients:
+            fail_command["mode"] = "off"
+            set_fail_point(client, fail_command)
+
+        self.assertEqual(len(listener.failed_events), 2)
+        self.assertEqual(len(listener.succeeded_events), 0)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py
index 2da6f53f4b..98bf0e5c94 100644
--- a/test/test_retryable_writes.py
+++ b/test/test_retryable_writes.py
@@ -31,6 +31,7 @@
     OvertCommandListener,
     SpecTestCreator,
     rs_or_single_client,
+    set_fail_point,
 )
 from test.utils_spec_runner import SpecRunner
 from test.version import Version
@@ -40,6 +41,7 @@
 from bson.raw_bson import RawBSONDocument
 from bson.son import SON
 from pymongo.errors import (
+    AutoReconnect,
     ConnectionFailure,
     OperationFailure,
     ServerSelectionTimeoutError,
@@ -469,6 +471,46 @@ def test_batch_splitting_retry_fails(self):
         self.assertEqual(final_txn, expected_txn)
         self.assertEqual(coll.find_one(projection={"_id": True}), {"_id": 1})
 
+    @client_context.require_multiple_mongoses
+    @client_context.require_failCommand_fail_point
+    def test_retryable_writes_in_sharded_cluster_multiple_available(self):
+        fail_command = {
+            "configureFailPoint": "failCommand",
+            "mode": {"times": 1},
+            "data": {
+                "failCommands": ["insert"],
+                "closeConnection": True,
+                "appName": "retryableWriteTest",
+            },
+        }
+
+        mongos_clients = []
+
+        for mongos in client_context.mongos_seeds().split(","):
+            client = rs_or_single_client(mongos)
+            set_fail_point(client, fail_command)
+            self.addCleanup(client.close)
+            mongos_clients.append(client)
+
+        listener = OvertCommandListener()
+        client = rs_or_single_client(
+            client_context.mongos_seeds(),
+            appName="retryableWriteTest",
+            event_listeners=[listener],
+            retryWrites=True,
+        )
+
+        with self.assertRaises(AutoReconnect):
+            client.t.t.insert_one({"x": 1})
+
+        # Disable failpoints on each mongos
+        for client in mongos_clients:
+            fail_command["mode"] = "off"
+            set_fail_point(client, fail_command)
+
+        self.assertEqual(len(listener.failed_events), 2)
+        self.assertEqual(len(listener.succeeded_events), 0)
+
 
 class TestWriteConcernError(IntegrationTest):
     RUN_ON_LOAD_BALANCER = True
diff --git a/test/test_topology.py b/test/test_topology.py
index 88c99d2a28..1da42a100b 100644
--- a/test/test_topology.py
+++ b/test/test_topology.py
@@ -30,11 +30,12 @@
 from pymongo.monitor import Monitor
 from pymongo.pool import PoolOptions
 from pymongo.read_preferences import ReadPreference, Secondary
+from pymongo.server import Server
 from pymongo.server_description import ServerDescription
 from pymongo.server_selectors import any_server_selector, writable_server_selector
 from pymongo.server_type import SERVER_TYPE
 from pymongo.settings import TopologySettings
-from pymongo.topology import Topology, _ErrorContext
+from pymongo.topology import Topology, _ErrorContext, _filter_servers
 from pymongo.topology_description import TOPOLOGY_TYPE
 
@@ -681,6 +682,23 @@ def test_unexpected_load_balancer(self):
         self.assertNotIn(("a", 27017), t.description.server_descriptions())
         self.assertEqual(t.description.topology_type_name, "Unknown")
 
+    def test_filtered_server_selection(self):
+        s1 = Server(ServerDescription(("localhost", 27017)), pool=object(), monitor=object())  # type: ignore[arg-type]
+        s2 = Server(ServerDescription(("localhost2", 27017)), pool=object(), monitor=object())  # type: ignore[arg-type]
+        servers = [s1, s2]
+
+        result = _filter_servers(servers, deprioritized_servers=[s2])
+        self.assertEqual(result, [s1])
+
+        result = _filter_servers(servers, deprioritized_servers=[s1, s2])
+        self.assertEqual(result, servers)
+
+        result = _filter_servers(servers, deprioritized_servers=[])
+        self.assertEqual(result, servers)
+
+        result = _filter_servers(servers)
+        self.assertEqual(result, servers)
+
 
 def wait_for_primary(topology):
     """Wait for a Topology to discover a writable server.
diff --git a/test/utils.py b/test/utils.py
index c8f9197c64..209d022a5b 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -1153,3 +1153,9 @@ def prepare_spec_arguments(spec, arguments, opname, entity_map, with_txn_callbac
                 raise AssertionError(f"Unsupported cursorType: {cursor_type}")
             else:
                 arguments[c2s] = arguments.pop(arg_name)
+
+
+def set_fail_point(client, command_args):
+    cmd = SON([("configureFailPoint", "failCommand")])
+    cmd.update(command_args)
+    client.admin.command(cmd)