Skip to content

Commit 25c9b90

Browse files
authored
PYTHON-5099 Convert test.test_sdam_monitoring_spec to async (#2117)
1 parent 3dd44e6 commit 25c9b90

File tree

3 files changed

+397
-13
lines changed

3 files changed

+397
-13
lines changed
Lines changed: 374 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,374 @@
1+
# Copyright 2016 MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Run the sdam monitoring spec tests."""
16+
from __future__ import annotations
17+
18+
import asyncio
19+
import json
20+
import os
21+
import sys
22+
import time
23+
from pathlib import Path
24+
25+
sys.path[0:0] = [""]
26+
27+
from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs, unittest
28+
from test.utils import (
29+
ServerAndTopologyEventListener,
30+
async_wait_until,
31+
server_name_to_type,
32+
)
33+
34+
from bson.json_util import object_hook
35+
from pymongo import AsyncMongoClient, monitoring
36+
from pymongo.asynchronous.collection import AsyncCollection
37+
from pymongo.asynchronous.monitor import Monitor
38+
from pymongo.common import clean_node
39+
from pymongo.errors import ConnectionFailure, NotPrimaryError
40+
from pymongo.hello import Hello
41+
from pymongo.server_description import ServerDescription
42+
from pymongo.topology_description import TOPOLOGY_TYPE
43+
44+
_IS_SYNC = False
45+
46+
# Location of JSON test specifications.
47+
if _IS_SYNC:
48+
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "sdam_monitoring")
49+
else:
50+
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "sdam_monitoring")
51+
52+
53+
def compare_server_descriptions(expected, actual):
54+
if (expected["address"] != "{}:{}".format(*actual.address)) or (
55+
server_name_to_type(expected["type"]) != actual.server_type
56+
):
57+
return False
58+
expected_hosts = set(expected["arbiters"] + expected["passives"] + expected["hosts"])
59+
return expected_hosts == {"{}:{}".format(*s) for s in actual.all_hosts}
60+
61+
62+
def compare_topology_descriptions(expected, actual):
63+
if TOPOLOGY_TYPE.__getattribute__(expected["topologyType"]) != actual.topology_type:
64+
return False
65+
expected = expected["servers"]
66+
actual = actual.server_descriptions()
67+
if len(expected) != len(actual):
68+
return False
69+
for exp_server in expected:
70+
for _address, actual_server in actual.items():
71+
if compare_server_descriptions(exp_server, actual_server):
72+
break
73+
else:
74+
return False
75+
return True
76+
77+
78+
def compare_events(expected_dict, actual):
79+
if not expected_dict:
80+
return False, "Error: Bad expected value in YAML test"
81+
if not actual:
82+
return False, "Error: Event published was None"
83+
84+
expected_type, expected = list(expected_dict.items())[0]
85+
86+
if expected_type == "server_opening_event":
87+
if not isinstance(actual, monitoring.ServerOpeningEvent):
88+
return False, "Expected ServerOpeningEvent, got %s" % (actual.__class__)
89+
if expected["address"] != "{}:{}".format(*actual.server_address):
90+
return (
91+
False,
92+
"ServerOpeningEvent published with wrong address (expected" " {}, got {}".format(
93+
expected["address"], actual.server_address
94+
),
95+
)
96+
97+
elif expected_type == "server_description_changed_event":
98+
if not isinstance(actual, monitoring.ServerDescriptionChangedEvent):
99+
return (False, "Expected ServerDescriptionChangedEvent, got %s" % (actual.__class__))
100+
if expected["address"] != "{}:{}".format(*actual.server_address):
101+
return (
102+
False,
103+
"ServerDescriptionChangedEvent has wrong address" " (expected {}, got {}".format(
104+
expected["address"], actual.server_address
105+
),
106+
)
107+
108+
if not compare_server_descriptions(expected["newDescription"], actual.new_description):
109+
return (False, "New ServerDescription incorrect in ServerDescriptionChangedEvent")
110+
if not compare_server_descriptions(
111+
expected["previousDescription"], actual.previous_description
112+
):
113+
return (
114+
False,
115+
"Previous ServerDescription incorrect in ServerDescriptionChangedEvent",
116+
)
117+
118+
elif expected_type == "server_closed_event":
119+
if not isinstance(actual, monitoring.ServerClosedEvent):
120+
return False, "Expected ServerClosedEvent, got %s" % (actual.__class__)
121+
if expected["address"] != "{}:{}".format(*actual.server_address):
122+
return (
123+
False,
124+
"ServerClosedEvent published with wrong address" " (expected {}, got {}".format(
125+
expected["address"], actual.server_address
126+
),
127+
)
128+
129+
elif expected_type == "topology_opening_event":
130+
if not isinstance(actual, monitoring.TopologyOpenedEvent):
131+
return False, "Expected TopologyOpenedEvent, got %s" % (actual.__class__)
132+
133+
elif expected_type == "topology_description_changed_event":
134+
if not isinstance(actual, monitoring.TopologyDescriptionChangedEvent):
135+
return (
136+
False,
137+
"Expected TopologyDescriptionChangedEvent, got %s" % (actual.__class__),
138+
)
139+
if not compare_topology_descriptions(expected["newDescription"], actual.new_description):
140+
return (
141+
False,
142+
"New TopologyDescription incorrect in TopologyDescriptionChangedEvent",
143+
)
144+
if not compare_topology_descriptions(
145+
expected["previousDescription"], actual.previous_description
146+
):
147+
return (
148+
False,
149+
"Previous TopologyDescription incorrect in TopologyDescriptionChangedEvent",
150+
)
151+
152+
elif expected_type == "topology_await aclosed_event":
153+
if not isinstance(actual, monitoring.TopologyClosedEvent):
154+
return False, "Expected TopologyClosedEvent, got %s" % (actual.__class__)
155+
156+
else:
157+
return False, f"Incorrect event: expected {expected_type}, actual {actual}"
158+
159+
return True, ""
160+
161+
162+
def compare_multiple_events(i, expected_results, actual_results):
163+
events_in_a_row = []
164+
j = i
165+
while j < len(expected_results) and isinstance(actual_results[j], actual_results[i].__class__):
166+
events_in_a_row.append(actual_results[j])
167+
j += 1
168+
message = ""
169+
for event in events_in_a_row:
170+
for k in range(i, j):
171+
passed, message = compare_events(expected_results[k], event)
172+
if passed:
173+
expected_results[k] = None
174+
break
175+
else:
176+
return i, False, message
177+
return j, True, ""
178+
179+
180+
class TestAllScenarios(AsyncIntegrationTest):
181+
async def asyncSetUp(self):
182+
await super().asyncSetUp()
183+
self.all_listener = ServerAndTopologyEventListener()
184+
185+
186+
def create_test(scenario_def):
187+
async def run_scenario(self):
188+
with client_knobs(events_queue_frequency=0.05, min_heartbeat_interval=0.05):
189+
await _run_scenario(self)
190+
191+
async def _run_scenario(self):
192+
class NoopMonitor(Monitor):
193+
"""Override the _run method to do nothing."""
194+
195+
async def _run(self):
196+
await asyncio.sleep(0.05)
197+
198+
m = AsyncMongoClient(
199+
host=scenario_def["uri"],
200+
port=27017,
201+
event_listeners=[self.all_listener],
202+
_monitor_class=NoopMonitor,
203+
)
204+
topology = await m._get_topology()
205+
206+
try:
207+
for phase in scenario_def["phases"]:
208+
for source, response in phase.get("responses", []):
209+
source_address = clean_node(source)
210+
await topology.on_change(
211+
ServerDescription(
212+
address=source_address, hello=Hello(response), round_trip_time=0
213+
)
214+
)
215+
216+
expected_results = phase["outcome"]["events"]
217+
expected_len = len(expected_results)
218+
await async_wait_until(
219+
lambda: len(self.all_listener.results) >= expected_len,
220+
"publish all events",
221+
timeout=15,
222+
)
223+
224+
# Wait some time to catch possible lagging extra events.
225+
await async_wait_until(lambda: topology._events.empty(), "publish lagging events")
226+
227+
i = 0
228+
while i < expected_len:
229+
result = (
230+
self.all_listener.results[i] if len(self.all_listener.results) > i else None
231+
)
232+
# The order of ServerOpening/ClosedEvents doesn't matter
233+
if isinstance(
234+
result, (monitoring.ServerOpeningEvent, monitoring.ServerClosedEvent)
235+
):
236+
i, passed, message = compare_multiple_events(
237+
i, expected_results, self.all_listener.results
238+
)
239+
self.assertTrue(passed, message)
240+
else:
241+
self.assertTrue(*compare_events(expected_results[i], result))
242+
i += 1
243+
244+
# Assert no extra events.
245+
extra_events = self.all_listener.results[expected_len:]
246+
if extra_events:
247+
self.fail(f"Extra events {extra_events!r}")
248+
249+
self.all_listener.reset()
250+
finally:
251+
await m.close()
252+
253+
return run_scenario
254+
255+
256+
def create_tests():
257+
for dirpath, _, filenames in os.walk(TEST_PATH):
258+
for filename in filenames:
259+
with open(os.path.join(dirpath, filename)) as scenario_stream:
260+
scenario_def = json.load(scenario_stream, object_hook=object_hook)
261+
# Construct test from scenario.
262+
new_test = create_test(scenario_def)
263+
test_name = f"test_{os.path.splitext(filename)[0]}"
264+
new_test.__name__ = test_name
265+
setattr(TestAllScenarios, new_test.__name__, new_test)
266+
267+
268+
create_tests()
269+
270+
271+
class TestSdamMonitoring(AsyncIntegrationTest):
272+
knobs: client_knobs
273+
listener: ServerAndTopologyEventListener
274+
test_client: AsyncMongoClient
275+
coll: AsyncCollection
276+
277+
@classmethod
278+
def setUpClass(cls):
279+
# Speed up the tests by decreasing the event publish frequency.
280+
cls.knobs = client_knobs(
281+
events_queue_frequency=0.1, heartbeat_frequency=0.1, min_heartbeat_interval=0.1
282+
)
283+
cls.knobs.enable()
284+
cls.listener = ServerAndTopologyEventListener()
285+
286+
@classmethod
287+
def tearDownClass(cls):
288+
cls.knobs.disable()
289+
290+
@async_client_context.require_failCommand_fail_point
291+
async def asyncSetUp(self):
292+
await super().asyncSetUp()
293+
294+
retry_writes = async_client_context.supports_transactions()
295+
self.test_client = await self.async_rs_or_single_client(
296+
event_listeners=[self.listener], retryWrites=retry_writes
297+
)
298+
self.coll = self.test_client[self.client.db.name].test
299+
await self.coll.insert_one({})
300+
self.listener.reset()
301+
302+
async def asyncTearDown(self):
303+
await super().asyncTearDown()
304+
305+
async def _test_app_error(self, fail_command_opts, expected_error):
306+
address = await self.test_client.address
307+
308+
# Test that an application error causes a ServerDescriptionChangedEvent
309+
# to be published.
310+
data = {"failCommands": ["insert"]}
311+
data.update(fail_command_opts)
312+
fail_insert = {
313+
"configureFailPoint": "failCommand",
314+
"mode": {"times": 1},
315+
"data": data,
316+
}
317+
async with self.fail_point(fail_insert):
318+
if self.test_client.options.retry_writes:
319+
await self.coll.insert_one({})
320+
else:
321+
with self.assertRaises(expected_error):
322+
await self.coll.insert_one({})
323+
await self.coll.insert_one({})
324+
325+
def marked_unknown(event):
326+
return (
327+
isinstance(event, monitoring.ServerDescriptionChangedEvent)
328+
and event.server_address == address
329+
and not event.new_description.is_server_type_known
330+
)
331+
332+
def discovered_node(event):
333+
return (
334+
isinstance(event, monitoring.ServerDescriptionChangedEvent)
335+
and event.server_address == address
336+
and not event.previous_description.is_server_type_known
337+
and event.new_description.is_server_type_known
338+
)
339+
340+
def marked_unknown_and_rediscovered():
341+
return (
342+
len(self.listener.matching(marked_unknown)) >= 1
343+
and len(self.listener.matching(discovered_node)) >= 1
344+
)
345+
346+
# Topology events are not published synchronously
347+
await async_wait_until(marked_unknown_and_rediscovered, "rediscover node")
348+
349+
# Expect a single ServerDescriptionChangedEvent for the network error.
350+
marked_unknown_events = self.listener.matching(marked_unknown)
351+
self.assertEqual(len(marked_unknown_events), 1, marked_unknown_events)
352+
self.assertIsInstance(marked_unknown_events[0].new_description.error, expected_error)
353+
354+
async def test_network_error_publishes_events(self):
355+
await self._test_app_error({"closeConnection": True}, ConnectionFailure)
356+
357+
# In 4.4+, not primary errors from failCommand don't cause SDAM state
358+
# changes because topologyVersion is not incremented.
359+
@async_client_context.require_version_max(4, 3)
360+
async def test_not_primary_error_publishes_events(self):
361+
await self._test_app_error(
362+
{"errorCode": 10107, "closeConnection": False, "errorLabels": ["RetryableWriteError"]},
363+
NotPrimaryError,
364+
)
365+
366+
async def test_shutdown_error_publishes_events(self):
367+
await self._test_app_error(
368+
{"errorCode": 91, "closeConnection": False, "errorLabels": ["RetryableWriteError"]},
369+
NotPrimaryError,
370+
)
371+
372+
373+
if __name__ == "__main__":
374+
unittest.main()

0 commit comments

Comments
 (0)