From 69d657ed60053623a8193d4cb0194ec2c26492fd Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 29 Jan 2025 15:08:18 -0500 Subject: [PATCH 1/7] PYTHON-4864 - Create async version of SpecRunnerThread --- test/asynchronous/unified_format.py | 23 ++++-- test/asynchronous/utils_spec_runner.py | 109 +++++++++++++++++-------- test/unified_format.py | 9 +- test/utils_spec_runner.py | 109 +++++++++++++++++-------- tools/synchro.py | 2 + 5 files changed, 177 insertions(+), 75 deletions(-) diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 52d964eb3e..6963945b46 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -35,6 +35,7 @@ client_knobs, unittest, ) +from test.asynchronous.utils_spec_runner import SpecRunnerTask from test.unified_format_shared import ( KMS_TLS_OPTS, PLACEHOLDER_MAP, @@ -58,7 +59,6 @@ snake_to_camel, wait_until, ) -from test.utils_spec_runner import SpecRunnerThread from test.version import Version from typing import Any, Dict, List, Mapping, Optional @@ -382,8 +382,8 @@ async def drop(self: AsyncGridFSBucket, *args: Any, **kwargs: Any) -> None: return elif entity_type == "thread": name = spec["id"] - thread = SpecRunnerThread(name) - thread.start() + thread = SpecRunnerTask(name) + await thread.start() self[name] = thread return @@ -1177,16 +1177,23 @@ def primary_changed() -> bool: wait_until(primary_changed, "change primary", timeout=timeout) - def _testOperation_runOnThread(self, spec): + async def _testOperation_runOnThread(self, spec): """Run the 'runOnThread' operation.""" thread = self.entity_map[spec["thread"]] - thread.schedule(lambda: self.run_entity_operation(spec["operation"])) + if _IS_SYNC: + await thread.schedule(lambda: self.run_entity_operation(spec["operation"])) + else: + + async def op(): + await self.run_entity_operation(spec["operation"]) + + await thread.schedule(op) - def _testOperation_waitForThread(self, spec): + async def _testOperation_waitForThread(self, spec): """Run the 'waitForThread' operation.""" thread = self.entity_map[spec["thread"]] - thread.stop() - thread.join(10) + await thread.stop() + await thread.join(10) if thread.exc: raise thread.exc self.assertFalse(thread.is_alive(), "Thread {} is still running".format(spec["thread"])) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index b79e5258b5..e59ecd9b94 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -54,39 +54,82 @@ _IS_SYNC = False - -class SpecRunnerThread(threading.Thread): - def __init__(self, name): - super().__init__() - self.name = name - self.exc = None - self.daemon = True - self.cond = threading.Condition() - self.ops = [] - self.stopped = False - - def schedule(self, work): - self.ops.append(work) - with self.cond: - self.cond.notify() - - def stop(self): - self.stopped = True - with self.cond: - self.cond.notify() - - def run(self): - while not self.stopped or self.ops: - if not self.ops: - with self.cond: - self.cond.wait(10) - if self.ops: - try: - work = self.ops.pop(0) - work() - except Exception as exc: - self.exc = exc - self.stop() +if _IS_SYNC: + + class SpecRunnerThread(threading.Thread): + def __init__(self, name): + super().__init__() + self.name = name + self.exc = None + self.daemon = True + self.cond = threading.Condition() + self.ops = [] + self.stopped = False + + def schedule(self, work): + self.ops.append(work) + with self.cond: + self.cond.notify() + + def stop(self): + self.stopped = True + with self.cond: + self.cond.notify() + + def run(self): + while not self.stopped or self.ops: + if not self.ops: + with self.cond: + self.cond.wait(10) + if self.ops: + try: + work = self.ops.pop(0) + work() + except Exception as exc: + self.exc = exc + self.stop() +else: + + class SpecRunnerTask: + def __init__(self, name): + self.name = name + self.exc = None + self.cond = asyncio.Condition() + self.ops = [] + self.stopped = False + self.task = None + + async def schedule(self, work): + self.ops.append(work) + async with self.cond: + self.cond.notify() + + async def stop(self): + self.stopped = True + async with self.cond: + self.cond.notify() + + async def start(self): + self.task = asyncio.create_task(self.run(), name=self.name) + + async def join(self, timeout: int = 0): + await asyncio.wait([self.task], timeout=timeout) + + def is_alive(self): + return not self.stopped + + async def run(self): + while not self.stopped or self.ops: + if not self.ops: + async with self.cond: + await asyncio.wait_for(self.cond.wait(), timeout=10) + if self.ops: + try: + work = self.ops.pop(0) + await work() + except Exception as exc: + self.exc = exc + await self.stop() class AsyncSpecTestCreator: diff --git a/test/unified_format.py b/test/unified_format.py index 372eb8abba..28369a5e87 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -1167,7 +1167,14 @@ def primary_changed() -> bool: def _testOperation_runOnThread(self, spec): """Run the 'runOnThread' operation.""" thread = self.entity_map[spec["thread"]] - thread.schedule(lambda: self.run_entity_operation(spec["operation"])) + if _IS_SYNC: + thread.schedule(lambda: self.run_entity_operation(spec["operation"])) + else: + + def op(): + self.run_entity_operation(spec["operation"]) + + thread.schedule(op) def _testOperation_waitForThread(self, spec): """Run the 'waitForThread' operation.""" diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 4508502cd0..4b24c5c2e8 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -54,39 +54,82 @@ _IS_SYNC = True - -class SpecRunnerThread(threading.Thread): - def __init__(self, name): - super().__init__() - self.name = name - self.exc = None - self.daemon = True - self.cond = threading.Condition() - self.ops = [] - self.stopped = False - - def schedule(self, work): - self.ops.append(work) - with self.cond: - self.cond.notify() - - def stop(self): - self.stopped = True - with self.cond: - self.cond.notify() - - def run(self): - while not self.stopped or self.ops: - if not self.ops: - with self.cond: - self.cond.wait(10) - if self.ops: - try: - work = self.ops.pop(0) - work() - except Exception as exc: - self.exc = exc - self.stop() +if _IS_SYNC: + + class SpecRunnerThread(threading.Thread): + def __init__(self, name): + super().__init__() + self.name = name + self.exc = None + self.daemon = True + self.cond = threading.Condition() + self.ops = [] + self.stopped = False + + def schedule(self, work): + self.ops.append(work) + with self.cond: + self.cond.notify() + + def stop(self): + self.stopped = True + with self.cond: + self.cond.notify() + + def run(self): + while not self.stopped or self.ops: + if not self.ops: + with self.cond: + self.cond.wait(10) + if self.ops: + try: + work = self.ops.pop(0) + work() + except Exception as exc: + self.exc = exc + self.stop() +else: + + class SpecRunnerThread: + def __init__(self, name): + self.name = name + self.exc = None + self.cond = asyncio.Condition() + self.ops = [] + self.stopped = False + self.task = None + + def schedule(self, work): + self.ops.append(work) + with self.cond: + self.cond.notify() + + def stop(self): + self.stopped = True + with self.cond: + self.cond.notify() + + def start(self): + self.task = asyncio.create_task(self.run(), name=self.name) + + def join(self, timeout: int = 0): + asyncio.wait([self.task], timeout=timeout) + + def is_alive(self): + return not self.stopped + + def run(self): + while not self.stopped or self.ops: + if not self.ops: + with self.cond: + asyncio.wait_for(self.cond.wait(), timeout=10) + if self.ops: + try: + work = self.ops.pop(0) + work() + except Exception as exc: + self.exc = exc + self.stop() class SpecTestCreator: diff --git a/tools/synchro.py b/tools/synchro.py index 897e5e8018..6444a06922 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -119,6 +119,8 @@ "_async_create_lock": "_create_lock", "_async_create_condition": "_create_condition", "_async_cond_wait": "_cond_wait", + "AsyncDummyMonitor": "DummyMonitor", + "SpecRunnerTask": "SpecRunnerThread", } docstring_replacements: dict[tuple[str, str], str] = { From 05bb77c3baeade1fa6b449c133b7503da222b33c Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 29 Jan 2025 15:20:59 -0500 Subject: [PATCH 2/7] Fix typing --- test/asynchronous/utils_spec_runner.py | 3 ++- test/utils_spec_runner.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index e59ecd9b94..9bcc5df57a 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -113,7 +113,8 @@ async def start(self): self.task = asyncio.create_task(self.run(), name=self.name) async def join(self, timeout: int = 0): - await asyncio.wait([self.task], timeout=timeout) + if self.task is not None: + await asyncio.wait([self.task], timeout=timeout) def is_alive(self): return not self.stopped diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 4b24c5c2e8..7ef34b7138 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -113,7 +113,8 @@ def start(self): self.task = asyncio.create_task(self.run(), name=self.name) def join(self, timeout: int = 0): - asyncio.wait([self.task], timeout=timeout) + if self.task is not None: + asyncio.wait([self.task], timeout=timeout) def is_alive(self): return not self.stopped From 1687553ab446064a8b56ed01f4c73b83759c6b28 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 29 Jan 2025 16:56:24 -0500 Subject: [PATCH 3/7] Consolidate SpecRunnerTask classes --- test/asynchronous/unified_format.py | 9 +-- test/asynchronous/utils_spec_runner.py | 106 ++++++++++--------------- test/unified_format.py | 9 +-- test/utils_spec_runner.py | 106 ++++++++++--------------- tools/synchro.py | 1 - 5 files changed, 86 insertions(+), 145 deletions(-) diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 6963945b46..a7a6364497 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -1180,14 +1180,7 @@ def primary_changed() -> bool: async def _testOperation_runOnThread(self, spec): """Run the 'runOnThread' operation.""" thread = self.entity_map[spec["thread"]] - if _IS_SYNC: - await thread.schedule(lambda: self.run_entity_operation(spec["operation"])) - else: - - async def op(): - await self.run_entity_operation(spec["operation"]) - - await thread.schedule(op) + await thread.schedule(functools.partial(self.run_entity_operation, spec["operation"])) async def _testOperation_waitForThread(self, spec): """Run the 'waitForThread' operation.""" diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 9bcc5df57a..e031f0adcc 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -47,6 +47,7 @@ from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.cursor import AsyncCursor from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError +from pymongo.lock import _async_create_condition, _async_create_lock from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.results import BulkWriteResult, _WriteResult @@ -55,82 +56,59 @@ _IS_SYNC = False if _IS_SYNC: - - class SpecRunnerThread(threading.Thread): - def __init__(self, name): - super().__init__() - self.name = name - self.exc = None - self.daemon = True - self.cond = threading.Condition() - self.ops = [] - self.stopped = False - - def schedule(self, work): - self.ops.append(work) - with self.cond: - self.cond.notify() - - def stop(self): - self.stopped = True - with self.cond: - self.cond.notify() - - def run(self): - while not self.stopped or self.ops: - if not self.ops: - with self.cond: - self.cond.wait(10) - if self.ops: - try: - work = self.ops.pop(0) - work() - except Exception as exc: - self.exc = exc - self.stop() + PARENT = threading.Thread else: + PARENT = object + - class SpecRunnerTask: - def __init__(self, name): - self.name = name - self.exc = None - self.cond = asyncio.Condition() - self.ops = [] - self.stopped = False - self.task = None - - async def schedule(self, work): - self.ops.append(work) - async with self.cond: - self.cond.notify() - - async def stop(self): - self.stopped = True - async with self.cond: - self.cond.notify() +class SpecRunnerTask(PARENT): + def __init__(self, name): + super().__init__() + self.name = name + self.exc = None + self.daemon = True + self.cond = _async_create_condition(_async_create_lock()) + self.ops = [] + self.stopped = False + self.task = None + + if not _IS_SYNC: async def start(self): self.task = asyncio.create_task(self.run(), name=self.name) - async def join(self, timeout: int = 0): + async def join(self, timeout: float | None = 0): # type: ignore[override] if self.task is not None: await asyncio.wait([self.task], timeout=timeout) def is_alive(self): return not self.stopped - async def run(self): - while not self.stopped or self.ops: - if not self.ops: - async with self.cond: - await asyncio.wait_for(self.cond.wait(), timeout=10) - if self.ops: - try: - work = self.ops.pop(0) - await work() - except Exception as exc: - self.exc = exc - await self.stop() + async def schedule(self, work): + self.ops.append(work) + async with self.cond: + self.cond.notify() + + async def stop(self): + self.stopped = True + async with self.cond: + self.cond.notify() + + async def run(self): + while not self.stopped or self.ops: + if not self.ops: + async with self.cond: + if _IS_SYNC: + await self.cond.wait(10) # type: ignore[call-arg] + else: + await asyncio.wait_for(self.cond.wait(), timeout=10) # type: ignore[arg-type] + if self.ops: + try: + work = self.ops.pop(0) + await work() + except Exception as exc: + self.exc = exc + await self.stop() class AsyncSpecTestCreator: diff --git a/test/unified_format.py b/test/unified_format.py index 28369a5e87..84f1553c53 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -1167,14 +1167,7 @@ def primary_changed() -> bool: def _testOperation_runOnThread(self, spec): """Run the 'runOnThread' operation.""" thread = self.entity_map[spec["thread"]] - if _IS_SYNC: - thread.schedule(lambda: self.run_entity_operation(spec["operation"])) - else: - - def op(): - self.run_entity_operation(spec["operation"]) - - thread.schedule(op) + thread.schedule(functools.partial(self.run_entity_operation, spec["operation"])) def _testOperation_waitForThread(self, spec): """Run the 'waitForThread' operation.""" diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 7ef34b7138..64c0232d43 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -44,6 +44,7 @@ from gridfs import GridFSBucket from gridfs.synchronous.grid_file import GridFSBucket from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError +from pymongo.lock import _create_condition, _create_lock from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.results import BulkWriteResult, _WriteResult @@ -55,82 +56,59 @@ _IS_SYNC = True if _IS_SYNC: - - class SpecRunnerThread(threading.Thread): - def __init__(self, name): - super().__init__() - self.name = name - self.exc = None - self.daemon = True - self.cond = threading.Condition() - self.ops = [] - self.stopped = False - - def schedule(self, work): - self.ops.append(work) - with self.cond: - self.cond.notify() - - def stop(self): - self.stopped = True - with self.cond: - self.cond.notify() - - def run(self): - while not self.stopped or self.ops: - if not self.ops: - with self.cond: - self.cond.wait(10) - if self.ops: - try: - work = self.ops.pop(0) - work() - except Exception as exc: - self.exc = exc - self.stop() + PARENT = threading.Thread else: + PARENT = object + - class SpecRunnerThread: - def __init__(self, name): - self.name = name - self.exc = None - self.cond = asyncio.Condition() - self.ops = [] - self.stopped = False - self.task = None - - def schedule(self, work): - self.ops.append(work) - with self.cond: - self.cond.notify() - - def stop(self): - self.stopped = True - with self.cond: - self.cond.notify() +class SpecRunnerThread(PARENT): + def __init__(self, name): + super().__init__() + self.name = name + self.exc = None + self.daemon = True + self.cond = _create_condition(_create_lock()) + self.ops = [] + self.stopped = False + self.task = None + + if not _IS_SYNC: def start(self): self.task = asyncio.create_task(self.run(), name=self.name) - def join(self, timeout: int = 0): + def join(self, timeout: float | None = 0): # type: ignore[override] if self.task is not None: asyncio.wait([self.task], timeout=timeout) def is_alive(self): return not self.stopped - def run(self): - while not self.stopped or self.ops: - if not self.ops: - with self.cond: - asyncio.wait_for(self.cond.wait(), timeout=10) - if self.ops: - try: - work = self.ops.pop(0) - work() - except Exception as exc: - self.exc = exc - self.stop() + def schedule(self, work): + self.ops.append(work) + with self.cond: + self.cond.notify() + + def stop(self): + self.stopped = True + with self.cond: + self.cond.notify() + + def run(self): + while not self.stopped or self.ops: + if not self.ops: + with self.cond: + if _IS_SYNC: + self.cond.wait(10) # type: ignore[call-arg] + else: + asyncio.wait_for(self.cond.wait(), timeout=10) # type: ignore[arg-type] + if self.ops: + try: + work = self.ops.pop(0) + work() + except Exception as exc: + self.exc = exc + self.stop() class SpecTestCreator: diff --git a/tools/synchro.py b/tools/synchro.py index 6444a06922..833ebb0330 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -119,7 +119,6 @@ "_async_create_lock": "_create_lock", "_async_create_condition": "_create_condition", "_async_cond_wait": "_cond_wait", - "AsyncDummyMonitor": "DummyMonitor", "SpecRunnerTask": "SpecRunnerThread", } From e3f2f26630ed2e14f7814447c29eb83d18b59000 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 31 Jan 2025 13:18:14 -0500 Subject: [PATCH 4/7] Consolidate condition waits --- test/asynchronous/utils_spec_runner.py | 7 ++----- test/utils_spec_runner.py | 7 ++----- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index e031f0adcc..104ea8d5be 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -47,7 +47,7 @@ from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.cursor import AsyncCursor from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError -from pymongo.lock import _async_create_condition, _async_create_lock +from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.results import BulkWriteResult, _WriteResult @@ -98,10 +98,7 @@ async def run(self): while not self.stopped or self.ops: if not self.ops: async with self.cond: - if _IS_SYNC: - await self.cond.wait(10) # type: ignore[call-arg] - else: - await asyncio.wait_for(self.cond.wait(), timeout=10) # type: ignore[arg-type] + await _async_cond_wait(self.cond, 10) if self.ops: try: work = self.ops.pop(0) diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 64c0232d43..c2bae05ba5 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -44,7 +44,7 @@ from gridfs import GridFSBucket from gridfs.synchronous.grid_file import GridFSBucket from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError -from pymongo.lock import _create_condition, _create_lock +from pymongo.lock import _cond_wait, _create_condition, _create_lock from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.results import BulkWriteResult, _WriteResult @@ -98,10 +98,7 @@ def run(self): while not self.stopped or self.ops: if not self.ops: with self.cond: - if _IS_SYNC: - self.cond.wait(10) # type: ignore[call-arg] - else: - asyncio.wait_for(self.cond.wait(), timeout=10) # type: ignore[arg-type] + _cond_wait(self.cond, 10) if self.ops: try: work = self.ops.pop(0) From 3a678d93224b809a1b5a581a3f6fde23041b268c Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 31 Jan 2025 13:43:13 -0500 Subject: [PATCH 5/7] Abstract Thread/Task wrapper for general use --- test/asynchronous/utils_spec_runner.py | 20 +++++++++++++------- test/utils_spec_runner.py | 20 +++++++++++++------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 104ea8d5be..aa69c528ee 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -61,14 +61,11 @@ PARENT = object -class SpecRunnerTask(PARENT): - def __init__(self, name): - super().__init__() +class ConcurrentRunner(PARENT): + def __init__(self, name, *args, **kwargs): + if _IS_SYNC: + super().__init__(*args, **kwargs) self.name = name - self.exc = None - self.daemon = True - self.cond = _async_create_condition(_async_create_lock()) - self.ops = [] self.stopped = False self.task = None @@ -84,6 +81,15 @@ async def join(self, timeout: float | None = 0): # type: ignore[override] def is_alive(self): return not self.stopped + +class SpecRunnerTask(ConcurrentRunner): + def __init__(self, name): + super().__init__(name) + self.exc = None + self.daemon = True + self.cond = _async_create_condition(_async_create_lock()) + self.ops = [] + async def schedule(self, work): self.ops.append(work) async with self.cond: diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index c2bae05ba5..c78ce3e23a 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -61,14 +61,11 @@ PARENT = object -class SpecRunnerThread(PARENT): - def __init__(self, name): - super().__init__() +class ConcurrentRunner(PARENT): + def __init__(self, name, *args, **kwargs): + if _IS_SYNC: + super().__init__(*args, **kwargs) self.name = name - self.exc = None - self.daemon = True - self.cond = _create_condition(_create_lock()) - self.ops = [] self.stopped = False self.task = None @@ -84,6 +81,15 @@ def join(self, timeout: float | None = 0): # type: ignore[override] def is_alive(self): return not self.stopped + +class SpecRunnerThread(ConcurrentRunner): + def __init__(self, name): + super().__init__(name) + self.exc = None + self.daemon = True + self.cond = _create_condition(_create_lock()) + self.ops = [] + def schedule(self, work): self.ops.append(work) with self.cond: From 9fddf512a2d21c2d03539703110284f7e3cfe830 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 31 Jan 2025 13:54:16 -0500 Subject: [PATCH 6/7] Add default run method --- test/asynchronous/utils_spec_runner.py | 6 ++++++ test/utils_spec_runner.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index aa69c528ee..6e25a326db 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -68,6 +68,8 @@ def __init__(self, name, *args, **kwargs): self.name = name self.stopped = False self.task = None + if "target" in kwargs: + self.target = kwargs["target"] if not _IS_SYNC: @@ -81,6 +83,10 @@ async def join(self, timeout: float | None = 0): # type: ignore[override] def is_alive(self): return not self.stopped + async def run(self): + if self.target: + await self.target() + class SpecRunnerTask(ConcurrentRunner): def __init__(self, name): diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index c78ce3e23a..60f3f66a02 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -68,6 +68,8 @@ def __init__(self, name, *args, **kwargs): self.name = name self.stopped = False self.task = None + if "target" in kwargs: + self.target = kwargs["target"] if not _IS_SYNC: @@ -81,6 +83,10 @@ def join(self, timeout: float | None = 0): # type: ignore[override] def is_alive(self): return not self.stopped + def run(self): + if self.target: + self.target() + class SpecRunnerThread(ConcurrentRunner): def __init__(self, name): From 25e31db4e8da2e454755f60e4d30d30f32c38c40 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 3 Feb 2025 15:53:24 -0500 Subject: [PATCH 7/7] Move ConcurrentRunner into helpers.py --- test/asynchronous/helpers.py | 37 ++++++++++++++++++++++++++ test/asynchronous/utils_spec_runner.py | 34 +---------------------- test/helpers.py | 37 ++++++++++++++++++++++++++ test/utils_spec_runner.py | 34 +---------------------- 4 files changed, 76 insertions(+), 66 deletions(-) diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index b5fc5d8ac4..7758f281e1 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -15,6 +15,7 @@ """Shared constants and helper methods for pymongo, bson, and gridfs test suites.""" from __future__ import annotations +import asyncio import base64 import gc import multiprocessing @@ -30,6 +31,8 @@ import warnings from asyncio import iscoroutinefunction +from pymongo._asyncio_task import create_task + try: import ipaddress @@ -369,3 +372,37 @@ def disable(self): os.environ.pop("SSL_CERT_FILE") else: os.environ["SSL_CERT_FILE"] = self.original_certs + + +if _IS_SYNC: + PARENT = threading.Thread +else: + PARENT = object + + +class ConcurrentRunner(PARENT): + def __init__(self, name, *args, **kwargs): + if _IS_SYNC: + super().__init__(*args, **kwargs) + self.name = name + self.stopped = False + self.task = None + if "target" in kwargs: + self.target = kwargs["target"] + + if not _IS_SYNC: + + async def start(self): + self.task = create_task(self.run(), name=self.name) + + async def join(self, timeout: float | None = 0): # type: ignore[override] + if self.task is not None: + await asyncio.wait([self.task], timeout=timeout) + + def is_alive(self): + return not self.stopped + + async def run(self): + if self.target: + await self.target() + self.stopped = True diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 6e25a326db..d103374313 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -18,11 +18,11 @@ import asyncio import functools import os -import threading import unittest from asyncio import iscoroutinefunction from collections import abc from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs +from test.asynchronous.helpers import ConcurrentRunner from test.utils import ( CMAPListener, CompareType, @@ -55,38 +55,6 @@ _IS_SYNC = False -if _IS_SYNC: - PARENT = threading.Thread -else: - PARENT = object - - -class ConcurrentRunner(PARENT): - def __init__(self, name, *args, **kwargs): - if _IS_SYNC: - super().__init__(*args, **kwargs) - self.name = name - self.stopped = False - self.task = None - if "target" in kwargs: - self.target = kwargs["target"] - - if not _IS_SYNC: - - async def start(self): - self.task = asyncio.create_task(self.run(), name=self.name) - - async def join(self, timeout: float | None = 0): # type: ignore[override] - if self.task is not None: - await asyncio.wait([self.task], timeout=timeout) - - def is_alive(self): - return not self.stopped - - async def run(self): - if self.target: - await self.target() - class SpecRunnerTask(ConcurrentRunner): def __init__(self, name): diff --git a/test/helpers.py b/test/helpers.py index 11d5ab0374..bd9e23bba4 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -15,6 +15,7 @@ """Shared constants and helper methods for pymongo, bson, and gridfs test suites.""" from __future__ import annotations +import asyncio import base64 import gc import multiprocessing @@ -30,6 +31,8 @@ import warnings from asyncio import iscoroutinefunction +from pymongo._asyncio_task import create_task + try: import ipaddress @@ -369,3 +372,37 @@ def disable(self): os.environ.pop("SSL_CERT_FILE") else: os.environ["SSL_CERT_FILE"] = self.original_certs + + +if _IS_SYNC: + PARENT = threading.Thread +else: + PARENT = object + + +class ConcurrentRunner(PARENT): + def __init__(self, name, *args, **kwargs): + if _IS_SYNC: + super().__init__(*args, **kwargs) + self.name = name + self.stopped = False + self.task = None + if "target" in kwargs: + self.target = kwargs["target"] + + if not _IS_SYNC: + + def start(self): + self.task = create_task(self.run(), name=self.name) + + def join(self, timeout: float | None = 0): # type: ignore[override] + if self.task is not None: + asyncio.wait([self.task], timeout=timeout) + + def is_alive(self): + return not self.stopped + + def run(self): + if self.target: + self.target() + self.stopped = True diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 60f3f66a02..6a62112afb 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -18,11 +18,11 @@ import asyncio import functools import os -import threading import unittest from asyncio import iscoroutinefunction from collections import abc from test import IntegrationTest, client_context, client_knobs +from test.helpers import ConcurrentRunner from test.utils import ( CMAPListener, CompareType, @@ -55,38 +55,6 @@ _IS_SYNC = True -if _IS_SYNC: - PARENT = threading.Thread -else: - PARENT = object - - -class ConcurrentRunner(PARENT): - def __init__(self, name, *args, **kwargs): - if _IS_SYNC: - super().__init__(*args, **kwargs) - self.name = name - self.stopped = False - self.task = None - if "target" in kwargs: - self.target = kwargs["target"] - - if not _IS_SYNC: - - def start(self): - self.task = asyncio.create_task(self.run(), name=self.name) - - def join(self, timeout: float | None = 0): # type: ignore[override] - if self.task is not None: - asyncio.wait([self.task], timeout=timeout) - - def is_alive(self): - return not self.stopped - - def run(self): - if self.target: - self.target() - class SpecRunnerThread(ConcurrentRunner): def __init__(self, name):