Skip to content

Commit 9589b1f

Browse files
committed
PYTHON-4927 - Add missing CSOT prose tests
1 parent 2f1227c commit 9589b1f

19 files changed

+645
-15
lines changed

pymongo/_csot.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,24 @@
1616

1717
from __future__ import annotations
1818

19+
import contextlib
1920
import functools
2021
import inspect
2122
import time
2223
from collections import deque
2324
from contextlib import AbstractContextManager
2425
from contextvars import ContextVar, Token
25-
from typing import TYPE_CHECKING, Any, Callable, Deque, MutableMapping, Optional, TypeVar, cast
26+
from typing import (
27+
TYPE_CHECKING,
28+
Any,
29+
Callable,
30+
Deque,
31+
Generator,
32+
MutableMapping,
33+
Optional,
34+
TypeVar,
35+
cast,
36+
)
2637

2738
if TYPE_CHECKING:
2839
from pymongo.write_concern import WriteConcern
@@ -54,6 +65,13 @@ def remaining() -> Optional[float]:
5465
return DEADLINE.get() - time.monotonic()
5566

5667

68+
@contextlib.contextmanager
69+
def reset() -> Generator:
70+
deadline_token = DEADLINE.set(DEADLINE.get() + get_timeout()) # type: ignore[operator]
71+
yield
72+
DEADLINE.reset(deadline_token)
73+
74+
5775
def clamp_remaining(max_timeout: float) -> float:
5876
"""Return the remaining timeout clamped to a max value."""
5977
timeout = remaining()

pymongo/asynchronous/client_session.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,11 @@ def _max_time_expired_error(exc: PyMongoError) -> bool:
473473

474474
def _within_time_limit(start_time: float) -> bool:
475475
"""Are we within the with_transaction retry limit?"""
476-
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
476+
timeout = _csot.get_timeout()
477+
if timeout:
478+
return time.monotonic() - start_time < timeout
479+
else:
480+
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
477481

478482

479483
_T = TypeVar("_T")
@@ -512,6 +516,7 @@ def __init__(
512516
# Is this an implicitly created session?
513517
self._implicit = implicit
514518
self._transaction = _Transaction(None, client)
519+
self._timeout = client.options.timeout
515520

516521
async def end_session(self) -> None:
517522
"""Finish this session. If a transaction has started, abort it.
@@ -597,6 +602,7 @@ def _inherit_option(self, name: str, val: _T) -> _T:
597602
return parent_val
598603
return getattr(self.client, name)
599604

605+
@_csot.apply
600606
async def with_transaction(
601607
self,
602608
callback: Callable[[AsyncClientSession], Coroutine[Any, Any, _T]],
@@ -697,7 +703,8 @@ async def callback(session, custom_arg, custom_kwarg=None):
697703
ret = await callback(self)
698704
except Exception as exc:
699705
if self.in_transaction:
700-
await self.abort_transaction()
706+
with _csot.reset():
707+
await self.abort_transaction()
701708
if (
702709
isinstance(exc, PyMongoError)
703710
and exc.has_error_label("TransientTransactionError")
@@ -816,6 +823,7 @@ async def commit_transaction(self) -> None:
816823
finally:
817824
self._transaction.state = _TxnState.COMMITTED
818825

826+
@_csot.apply
819827
async def abort_transaction(self) -> None:
820828
"""Abort a multi-statement transaction.
821829

pymongo/asynchronous/topology.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,8 @@ def get_server_selection_timeout(self) -> float:
249249
timeout = _csot.remaining()
250250
if timeout is None:
251251
return self._settings.server_selection_timeout
252-
return timeout
252+
else:
253+
return min(timeout, self._settings.server_selection_timeout)
253254

254255
async def select_servers(
255256
self,

pymongo/synchronous/client_session.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,11 @@ def _max_time_expired_error(exc: PyMongoError) -> bool:
472472

473473
def _within_time_limit(start_time: float) -> bool:
474474
"""Are we within the with_transaction retry limit?"""
475-
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
475+
timeout = _csot.get_timeout()
476+
if timeout:
477+
return time.monotonic() - start_time < timeout
478+
else:
479+
return time.monotonic() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT
476480

477481

478482
_T = TypeVar("_T")
@@ -511,6 +515,7 @@ def __init__(
511515
# Is this an implicitly created session?
512516
self._implicit = implicit
513517
self._transaction = _Transaction(None, client)
518+
self._timeout = client.options.timeout
514519

515520
def end_session(self) -> None:
516521
"""Finish this session. If a transaction has started, abort it.
@@ -596,6 +601,7 @@ def _inherit_option(self, name: str, val: _T) -> _T:
596601
return parent_val
597602
return getattr(self.client, name)
598603

604+
@_csot.apply
599605
def with_transaction(
600606
self,
601607
callback: Callable[[ClientSession], _T],
@@ -694,7 +700,8 @@ def callback(session, custom_arg, custom_kwarg=None):
694700
ret = callback(self)
695701
except Exception as exc:
696702
if self.in_transaction:
697-
self.abort_transaction()
703+
with _csot.reset():
704+
self.abort_transaction()
698705
if (
699706
isinstance(exc, PyMongoError)
700707
and exc.has_error_label("TransientTransactionError")
@@ -813,6 +820,7 @@ def commit_transaction(self) -> None:
813820
finally:
814821
self._transaction.state = _TxnState.COMMITTED
815822

823+
@_csot.apply
816824
def abort_transaction(self) -> None:
817825
"""Abort a multi-statement transaction.
818826

pymongo/synchronous/topology.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,8 @@ def get_server_selection_timeout(self) -> float:
249249
timeout = _csot.remaining()
250250
if timeout is None:
251251
return self._settings.server_selection_timeout
252-
return timeout
252+
else:
253+
return min(timeout, self._settings.server_selection_timeout)
253254

254255
def select_servers(
255256
self,

test/asynchronous/test_client.py

+46-1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
from test.utils import (
6565
NTHREADS,
6666
CMAPListener,
67+
EventListener,
6768
FunctionCallRecorder,
6869
async_get_pool,
6970
async_wait_until,
@@ -114,7 +115,13 @@
114115
ServerSelectionTimeoutError,
115116
WriteConcernError,
116117
)
117-
from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent
118+
from pymongo.monitoring import (
119+
ConnectionClosedEvent,
120+
ConnectionCreatedEvent,
121+
ConnectionReadyEvent,
122+
ServerHeartbeatListener,
123+
ServerHeartbeatStartedEvent,
124+
)
118125
from pymongo.pool_options import _MAX_METADATA_SIZE, _METADATA, ENV_VAR_K8S, PoolOptions
119126
from pymongo.read_preferences import ReadPreference
120127
from pymongo.server_description import ServerDescription
@@ -2585,5 +2592,43 @@ async def test_direct_client_maintains_pool_to_arbiter(self):
25852592
self.assertEqual(listener.event_count(monitoring.PoolReadyEvent), 1)
25862593

25872594

2595+
# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#4-background-connection-pooling
2596+
class TestClientCSOTProse(AsyncIntegrationTest):
2597+
# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#timeoutms-is-refreshed-for-each-handshake-command
2598+
@async_client_context.require_auth
2599+
@async_client_context.require_version_min(4, 4, -1)
2600+
async def test_02_timeoutMS_refreshed_for_each_handshake_command(self):
2601+
listener = CMAPListener()
2602+
2603+
async with self.fail_point(
2604+
{
2605+
"mode": {"times": 1},
2606+
"data": {
2607+
"failCommands": ["hello", "isMaster", "saslContinue"],
2608+
"blockConnection": True,
2609+
"blockTimeMS": 15,
2610+
"appName": "refreshTimeoutBackgroundPoolTest",
2611+
},
2612+
}
2613+
):
2614+
_ = await self.async_single_client(
2615+
minPoolSize=1,
2616+
timeoutMS=20,
2617+
appname="refreshTimeoutBackgroundPoolTest",
2618+
event_listeners=[listener],
2619+
)
2620+
2621+
async def predicate():
2622+
return (
2623+
listener.event_count(ConnectionCreatedEvent) == 1
2624+
and listener.event_count(ConnectionReadyEvent) == 1
2625+
)
2626+
2627+
await async_wait_until(
2628+
predicate,
2629+
"didn't ever see a ConnectionCreatedEvent and a ConnectionReadyEvent",
2630+
)
2631+
2632+
25882633
if __name__ == "__main__":
25892634
unittest.main()

test/asynchronous/test_collection.py

+26
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
InvalidDocument,
6565
InvalidName,
6666
InvalidOperation,
67+
NetworkTimeout,
6768
OperationFailure,
6869
WriteConcernError,
6970
)
@@ -2277,6 +2278,31 @@ async def afind(*args, **kwargs):
22772278
for helper, args in helpers:
22782279
await helper(*args, let={}) # type: ignore
22792280

2281+
# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#1-multi-batch-inserts
2282+
@async_client_context.require_standalone
2283+
@async_client_context.require_version_min(4, 4, -1)
2284+
async def test_01_multi_batch_inserts(self):
2285+
client = await self.async_single_client(read_preference=ReadPreference.PRIMARY_PREFERRED)
2286+
await client.db.coll.drop()
2287+
2288+
async with self.fail_point(
2289+
{
2290+
"mode": {"times": 2},
2291+
"data": {"failCommands": ["insert"], "blockConnection": True, "blockTimeMS": 1010},
2292+
}
2293+
):
2294+
listener = OvertCommandListener()
2295+
client2 = await self.async_single_client(
2296+
timeoutMS=2000,
2297+
read_preference=ReadPreference.PRIMARY_PREFERRED,
2298+
event_listeners=[listener],
2299+
)
2300+
docs = [{"a": "b" * 1000000} for _ in range(50)]
2301+
with self.assertRaises(NetworkTimeout):
2302+
await client2.db.coll.insert_many(docs)
2303+
2304+
self.assertEqual(2, len(listener.started_events))
2305+
22802306

22812307
if __name__ == "__main__":
22822308
unittest.main()

test/asynchronous/test_encryption.py

+103
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
EncryptedCollectionError,
8787
EncryptionError,
8888
InvalidOperation,
89+
NetworkTimeout,
8990
OperationFailure,
9091
ServerSelectionTimeoutError,
9192
WriteError,
@@ -3133,5 +3134,107 @@ async def test_explicit_session_errors_when_unsupported(self):
31333134
await self.mongocryptd_client.db.test.insert_one({"x": 1}, session=s)
31343135

31353136

3137+
# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#3-clientencryption
3138+
class TestCSOTProse(AsyncEncryptionIntegrationTest):
3139+
mongocryptd_client: AsyncMongoClient
3140+
MONGOCRYPTD_PORT = 27020
3141+
LOCAL_MASTERKEY = Binary(
3142+
base64.b64decode(
3143+
b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"
3144+
),
3145+
UUID_SUBTYPE,
3146+
)
3147+
3148+
async def asyncSetUp(self) -> None:
3149+
self.listener = OvertCommandListener()
3150+
self.client = await self.async_single_client(
3151+
read_preference=ReadPreference.PRIMARY_PREFERRED, event_listeners=[self.listener]
3152+
)
3153+
await self.client.keyvault.datakeys.drop()
3154+
self.key_vault_client = await self.async_rs_or_single_client(
3155+
timeoutMS=50, event_listeners=[self.listener]
3156+
)
3157+
self.client_encryption = self.create_client_encryption(
3158+
key_vault_namespace="keyvault.datakeys",
3159+
kms_providers={"local": {"key": self.LOCAL_MASTERKEY}},
3160+
key_vault_client=self.key_vault_client,
3161+
codec_options=OPTS,
3162+
)
3163+
3164+
# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#createdatakey
3165+
async def test_01_create_data_key(self):
3166+
async with self.fail_point(
3167+
{
3168+
"mode": {"times": 1},
3169+
"data": {"failCommands": ["insert"], "blockConnection": True, "blockTimeMS": 100},
3170+
}
3171+
):
3172+
self.listener.reset()
3173+
with self.assertRaisesRegex(EncryptionError, "timed out"):
3174+
await self.client_encryption.create_data_key("local")
3175+
3176+
events = self.listener.started_events
3177+
self.assertEqual(1, len(events))
3178+
self.assertEqual("insert", events[0].command_name)
3179+
3180+
# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#encrypt
3181+
async def test_02_encrypt(self):
3182+
data_key_id = await self.client_encryption.create_data_key("local")
3183+
self.assertEqual(4, data_key_id.subtype)
3184+
async with self.fail_point(
3185+
{
3186+
"mode": {"times": 1},
3187+
"data": {"failCommands": ["find"], "blockConnection": True, "blockTimeMS": 100},
3188+
}
3189+
):
3190+
self.listener.reset()
3191+
with self.assertRaisesRegex(EncryptionError, "timed out"):
3192+
await self.client_encryption.encrypt(
3193+
"hello",
3194+
Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic,
3195+
key_id=data_key_id,
3196+
)
3197+
3198+
events = self.listener.started_events
3199+
self.assertEqual(1, len(events))
3200+
self.assertEqual("find", events[0].command_name)
3201+
3202+
# https://github.com/mongodb/specifications/blob/master/source/client-side-operations-timeout/tests/README.md#decrypt
3203+
async def test_03_decrypt(self):
3204+
data_key_id = await self.client_encryption.create_data_key("local")
3205+
self.assertEqual(4, data_key_id.subtype)
3206+
3207+
encrypted = await self.client_encryption.encrypt(
3208+
"hello", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=data_key_id
3209+
)
3210+
self.assertEqual(6, encrypted.subtype)
3211+
3212+
await self.key_vault_client.close()
3213+
self.key_vault_client = await self.async_rs_or_single_client(
3214+
timeoutMS=50, event_listeners=[self.listener]
3215+
)
3216+
await self.client_encryption.close()
3217+
self.client_encryption = self.create_client_encryption(
3218+
key_vault_namespace="keyvault.datakeys",
3219+
kms_providers={"local": {"key": self.LOCAL_MASTERKEY}},
3220+
key_vault_client=self.key_vault_client,
3221+
codec_options=OPTS,
3222+
)
3223+
3224+
async with self.fail_point(
3225+
{
3226+
"mode": {"times": 1},
3227+
"data": {"failCommands": ["find"], "blockConnection": True, "blockTimeMS": 100},
3228+
}
3229+
):
3230+
self.listener.reset()
3231+
with self.assertRaisesRegex(EncryptionError, "timed out"):
3232+
await self.client_encryption.decrypt(encrypted)
3233+
3234+
events = self.listener.started_events
3235+
self.assertEqual(1, len(events))
3236+
self.assertEqual("find", events[0].command_name)
3237+
3238+
31363239
if __name__ == "__main__":
31373240
unittest.main()

test/asynchronous/test_session.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from pymongo.asynchronous.cursor import AsyncCursor
4949
from pymongo.asynchronous.helpers import anext
5050
from pymongo.common import _MAX_END_SESSIONS
51-
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
51+
from pymongo.errors import ConfigurationError, InvalidOperation, NetworkTimeout, OperationFailure
5252
from pymongo.operations import IndexModel, InsertOne, UpdateOne
5353
from pymongo.read_concern import ReadConcern
5454

0 commit comments

Comments
 (0)