Skip to content

Commit a1a2109

Browse files
PYTHON-5089 Convert test.test_mongos_load_balancing to async (#2107)
Co-authored-by: Noah Stapp <[email protected]>
1 parent 25c9b90 commit a1a2109

File tree

3 files changed

+223
-17
lines changed

3 files changed

+223
-17
lines changed
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# Copyright 2015-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Test AsyncMongoClient's mongos load balancing using a mock."""
16+
from __future__ import annotations
17+
18+
import asyncio
19+
import sys
20+
import threading
21+
from test.asynchronous.helpers import ConcurrentRunner
22+
23+
from pymongo.operations import _Op
24+
25+
sys.path[0:0] = [""]
26+
27+
from test.asynchronous import AsyncMockClientTest, async_client_context, connected, unittest
28+
from test.asynchronous.pymongo_mocks import AsyncMockClient
29+
from test.utils import async_wait_until
30+
31+
from pymongo.errors import AutoReconnect, InvalidOperation
32+
from pymongo.server_selectors import writable_server_selector
33+
from pymongo.topology_description import TOPOLOGY_TYPE
34+
35+
_IS_SYNC = False
36+
37+
38+
class SimpleOp(ConcurrentRunner):
39+
def __init__(self, client):
40+
super().__init__()
41+
self.client = client
42+
self.passed = False
43+
44+
async def run(self):
45+
await self.client.db.command("ping")
46+
self.passed = True # No exception raised.
47+
48+
49+
async def do_simple_op(client, ntasks):
50+
tasks = [SimpleOp(client) for _ in range(ntasks)]
51+
for t in tasks:
52+
await t.start()
53+
54+
for t in tasks:
55+
await t.join()
56+
57+
for t in tasks:
58+
assert t.passed
59+
60+
61+
async def writable_addresses(topology):
62+
return {
63+
server.description.address
64+
for server in await topology.select_servers(writable_server_selector, _Op.TEST)
65+
}
66+
67+
68+
class TestMongosLoadBalancing(AsyncMockClientTest):
69+
@async_client_context.require_connection
70+
@async_client_context.require_no_load_balancer
71+
async def asyncSetUp(self):
72+
await super().asyncSetUp()
73+
74+
def mock_client(self, **kwargs):
75+
mock_client = AsyncMockClient(
76+
standalones=[],
77+
members=[],
78+
mongoses=["a:1", "b:2", "c:3"],
79+
host="a:1,b:2,c:3",
80+
connect=False,
81+
**kwargs,
82+
)
83+
self.addAsyncCleanup(mock_client.aclose)
84+
85+
# Latencies in seconds.
86+
mock_client.mock_rtts["a:1"] = 0.020
87+
mock_client.mock_rtts["b:2"] = 0.025
88+
mock_client.mock_rtts["c:3"] = 0.045
89+
return mock_client
90+
91+
async def test_lazy_connect(self):
92+
# While connected() ensures we can trigger connection from the main
93+
# thread and wait for the monitors, this test triggers connection from
94+
# several threads at once to check for data races.
95+
nthreads = 10
96+
client = self.mock_client()
97+
self.assertEqual(0, len(client.nodes))
98+
99+
# Trigger initial connection.
100+
await do_simple_op(client, nthreads)
101+
await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses")
102+
103+
async def test_failover(self):
104+
ntasks = 10
105+
client = await connected(self.mock_client(localThresholdMS=0.001))
106+
await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses")
107+
108+
# Our chosen mongos goes down.
109+
client.kill_host("a:1")
110+
111+
# Trigger failover to higher-latency nodes. AutoReconnect should be
112+
# raised at most once in each thread.
113+
passed = []
114+
115+
async def f():
116+
try:
117+
await client.db.command("ping")
118+
except AutoReconnect:
119+
# Second attempt succeeds.
120+
await client.db.command("ping")
121+
122+
passed.append(True)
123+
124+
tasks = [ConcurrentRunner(target=f) for _ in range(ntasks)]
125+
for t in tasks:
126+
await t.start()
127+
128+
for t in tasks:
129+
await t.join()
130+
131+
self.assertEqual(ntasks, len(passed))
132+
133+
# Down host removed from list.
134+
self.assertEqual(2, len(client.nodes))
135+
136+
async def test_local_threshold(self):
137+
client = await connected(self.mock_client(localThresholdMS=30))
138+
self.assertEqual(30, client.options.local_threshold_ms)
139+
await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses")
140+
topology = client._topology
141+
142+
# All are within a 30-ms latency window, see self.mock_client().
143+
self.assertEqual({("a", 1), ("b", 2), ("c", 3)}, await writable_addresses(topology))
144+
145+
# No error
146+
await client.admin.command("ping")
147+
148+
client = await connected(self.mock_client(localThresholdMS=0))
149+
self.assertEqual(0, client.options.local_threshold_ms)
150+
# No error
151+
await client.db.command("ping")
152+
# Our chosen mongos goes down.
153+
client.kill_host("{}:{}".format(*next(iter(client.nodes))))
154+
try:
155+
await client.db.command("ping")
156+
except:
157+
pass
158+
159+
# We eventually connect to a new mongos.
160+
async def connect_to_new_mongos():
161+
try:
162+
return await client.db.command("ping")
163+
except AutoReconnect:
164+
pass
165+
166+
await async_wait_until(connect_to_new_mongos, "connect to a new mongos")
167+
168+
async def test_load_balancing(self):
169+
# Although the server selection JSON tests already prove that
170+
# select_servers works for sharded topologies, here we do an end-to-end
171+
# test of discovering servers' round trip times and configuring
172+
# localThresholdMS.
173+
client = await connected(self.mock_client())
174+
await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses")
175+
176+
# Prohibited for topology type Sharded.
177+
with self.assertRaises(InvalidOperation):
178+
await client.address
179+
180+
topology = client._topology
181+
self.assertEqual(TOPOLOGY_TYPE.Sharded, topology.description.topology_type)
182+
183+
# a and b are within the 15-ms latency window, see self.mock_client().
184+
self.assertEqual({("a", 1), ("b", 2)}, await writable_addresses(topology))
185+
186+
client.mock_rtts["a:1"] = 0.045
187+
188+
# Discover only b is within latency window.
189+
async def predicate():
190+
return {("b", 2)} == await writable_addresses(topology)
191+
192+
await async_wait_until(
193+
predicate,
194+
'discover server "a" is too far',
195+
)
196+
197+
198+
if __name__ == "__main__":
199+
unittest.main()

test/test_mongos_load_balancing.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
"""Test MongoClient's mongos load balancing using a mock."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import sys
1920
import threading
21+
from test.helpers import ConcurrentRunner
2022

2123
from pymongo.operations import _Op
2224

@@ -30,14 +32,10 @@
3032
from pymongo.server_selectors import writable_server_selector
3133
from pymongo.topology_description import TOPOLOGY_TYPE
3234

35+
_IS_SYNC = True
3336

34-
@client_context.require_connection
35-
@client_context.require_no_load_balancer
36-
def setUpModule():
37-
pass
3837

39-
40-
class SimpleOp(threading.Thread):
38+
class SimpleOp(ConcurrentRunner):
4139
def __init__(self, client):
4240
super().__init__()
4341
self.client = client
@@ -48,15 +46,15 @@ def run(self):
4846
self.passed = True # No exception raised.
4947

5048

51-
def do_simple_op(client, nthreads):
52-
threads = [SimpleOp(client) for _ in range(nthreads)]
53-
for t in threads:
49+
def do_simple_op(client, ntasks):
50+
tasks = [SimpleOp(client) for _ in range(ntasks)]
51+
for t in tasks:
5452
t.start()
5553

56-
for t in threads:
54+
for t in tasks:
5755
t.join()
5856

59-
for t in threads:
57+
for t in tasks:
6058
assert t.passed
6159

6260

@@ -68,6 +66,11 @@ def writable_addresses(topology):
6866

6967

7068
class TestMongosLoadBalancing(MockClientTest):
69+
@client_context.require_connection
70+
@client_context.require_no_load_balancer
71+
def setUp(self):
72+
super().setUp()
73+
7174
def mock_client(self, **kwargs):
7275
mock_client = MockClient(
7376
standalones=[],
@@ -98,7 +101,7 @@ def test_lazy_connect(self):
98101
wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses")
99102

100103
def test_failover(self):
101-
nthreads = 10
104+
ntasks = 10
102105
client = connected(self.mock_client(localThresholdMS=0.001))
103106
wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses")
104107

@@ -118,14 +121,14 @@ def f():
118121

119122
passed.append(True)
120123

121-
threads = [threading.Thread(target=f) for _ in range(nthreads)]
122-
for t in threads:
124+
tasks = [ConcurrentRunner(target=f) for _ in range(ntasks)]
125+
for t in tasks:
123126
t.start()
124127

125-
for t in threads:
128+
for t in tasks:
126129
t.join()
127130

128-
self.assertEqual(nthreads, len(passed))
131+
self.assertEqual(ntasks, len(passed))
129132

130133
# Down host removed from list.
131134
self.assertEqual(2, len(client.nodes))
@@ -183,8 +186,11 @@ def test_load_balancing(self):
183186
client.mock_rtts["a:1"] = 0.045
184187

185188
# Discover only b is within latency window.
189+
def predicate():
190+
return {("b", 2)} == writable_addresses(topology)
191+
186192
wait_until(
187-
lambda: {("b", 2)} == writable_addresses(topology),
193+
predicate,
188194
'discover server "a" is too far',
189195
)
190196

tools/synchro.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ def async_only_test(f: str) -> bool:
221221
"test_logger.py",
222222
"test_max_staleness.py",
223223
"test_monitoring.py",
224+
"test_mongos_load_balancing.py",
224225
"test_on_demand_csfle.py",
225226
"test_raw_bson.py",
226227
"test_read_concern.py",

0 commit comments

Comments
 (0)