Skip to content

PYTHON-4864 - Create async version of SpecRunnerThread #2094

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions test/asynchronous/unified_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
client_knobs,
unittest,
)
from test.asynchronous.utils_spec_runner import SpecRunnerTask
from test.unified_format_shared import (
KMS_TLS_OPTS,
PLACEHOLDER_MAP,
Expand All @@ -58,7 +59,6 @@
snake_to_camel,
wait_until,
)
from test.utils_spec_runner import SpecRunnerThread
from test.version import Version
from typing import Any, Dict, List, Mapping, Optional

Expand Down Expand Up @@ -382,8 +382,8 @@ async def drop(self: AsyncGridFSBucket, *args: Any, **kwargs: Any) -> None:
return
elif entity_type == "thread":
name = spec["id"]
thread = SpecRunnerThread(name)
thread.start()
thread = SpecRunnerTask(name)
await thread.start()
self[name] = thread
return

Expand Down Expand Up @@ -1177,16 +1177,16 @@ def primary_changed() -> bool:

wait_until(primary_changed, "change primary", timeout=timeout)

def _testOperation_runOnThread(self, spec):
async def _testOperation_runOnThread(self, spec):
"""Run the 'runOnThread' operation."""
thread = self.entity_map[spec["thread"]]
thread.schedule(lambda: self.run_entity_operation(spec["operation"]))
await thread.schedule(functools.partial(self.run_entity_operation, spec["operation"]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we running any async unified tests that use runOnThread?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SDAM unified tests are the only ones that use runOnThread. Those are currently slated to be converted to async, yes.


def _testOperation_waitForThread(self, spec):
async def _testOperation_waitForThread(self, spec):
"""Run the 'waitForThread' operation."""
thread = self.entity_map[spec["thread"]]
thread.stop()
thread.join(10)
await thread.stop()
await thread.join(10)
if thread.exc:
raise thread.exc
self.assertFalse(thread.is_alive(), "Thread {} is still running".format(spec["thread"]))
Expand Down
44 changes: 33 additions & 11 deletions test/asynchronous/utils_spec_runner.py
Copy link
Contributor

@sleepyStick sleepyStick Jan 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not loving the duplicated code between sync and async but i'm guessing its because the async version needs some more methods? If so, then i understand and can live with it >.<

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The async version doesn't implement threading.Thread, but it still needs to match the same API as the synchronous version. Let me see if I can reduce some of the duplication though.

Copy link
Contributor Author

@NoahStapp NoahStapp Jan 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did some refactoring, much less duplication now. Great call-out!

Original file line number Diff line number Diff line change
Expand Up @@ -47,46 +47,68 @@
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
from pymongo.lock import _async_create_condition, _async_create_lock
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import BulkWriteResult, _WriteResult
from pymongo.write_concern import WriteConcern

_IS_SYNC = False

if _IS_SYNC:
PARENT = threading.Thread
else:
PARENT = object

class SpecRunnerThread(threading.Thread):

class SpecRunnerTask(PARENT):
def __init__(self, name):
super().__init__()
self.name = name
self.exc = None
self.daemon = True
self.cond = threading.Condition()
self.cond = _async_create_condition(_async_create_lock())
self.ops = []
self.stopped = False
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stopped is never set to True by this class.

self.task = None

if not _IS_SYNC:

async def start(self):
self.task = asyncio.create_task(self.run(), name=self.name)

def schedule(self, work):
async def join(self, timeout: float | None = 0): # type: ignore[override]
if self.task is not None:
await asyncio.wait([self.task], timeout=timeout)

def is_alive(self):
return not self.stopped

async def schedule(self, work):
self.ops.append(work)
with self.cond:
async with self.cond:
self.cond.notify()

def stop(self):
async def stop(self):
self.stopped = True
with self.cond:
async with self.cond:
self.cond.notify()

def run(self):
async def run(self):
while not self.stopped or self.ops:
if not self.ops:
with self.cond:
self.cond.wait(10)
async with self.cond:
if _IS_SYNC:
await self.cond.wait(10) # type: ignore[call-arg]
else:
await asyncio.wait_for(self.cond.wait(), timeout=10) # type: ignore[arg-type]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use our _async_cond_wait compat function here to avoid the branching.

if self.ops:
try:
work = self.ops.pop(0)
work()
await work()
except Exception as exc:
self.exc = exc
self.stop()
await self.stop()


class AsyncSpecTestCreator:
Expand Down
2 changes: 1 addition & 1 deletion test/unified_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,7 +1167,7 @@ def primary_changed() -> bool:
def _testOperation_runOnThread(self, spec):
"""Run the 'runOnThread' operation."""
thread = self.entity_map[spec["thread"]]
thread.schedule(lambda: self.run_entity_operation(spec["operation"]))
thread.schedule(functools.partial(self.run_entity_operation, spec["operation"]))

def _testOperation_waitForThread(self, spec):
"""Run the 'waitForThread' operation."""
Expand Down
28 changes: 25 additions & 3 deletions test/utils_spec_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from gridfs import GridFSBucket
from gridfs.synchronous.grid_file import GridFSBucket
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
from pymongo.lock import _create_condition, _create_lock
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import BulkWriteResult, _WriteResult
Expand All @@ -54,16 +55,34 @@

_IS_SYNC = True

if _IS_SYNC:
PARENT = threading.Thread
else:
PARENT = object

class SpecRunnerThread(threading.Thread):

class SpecRunnerThread(PARENT):
def __init__(self, name):
super().__init__()
self.name = name
self.exc = None
self.daemon = True
self.cond = threading.Condition()
self.cond = _create_condition(_create_lock())
self.ops = []
self.stopped = False
self.task = None

if not _IS_SYNC:

def start(self):
self.task = asyncio.create_task(self.run(), name=self.name)

def join(self, timeout: float | None = 0): # type: ignore[override]
if self.task is not None:
asyncio.wait([self.task], timeout=timeout)

def is_alive(self):
return not self.stopped

def schedule(self, work):
self.ops.append(work)
Expand All @@ -79,7 +98,10 @@ def run(self):
while not self.stopped or self.ops:
if not self.ops:
with self.cond:
self.cond.wait(10)
if _IS_SYNC:
self.cond.wait(10) # type: ignore[call-arg]
else:
asyncio.wait_for(self.cond.wait(), timeout=10) # type: ignore[arg-type]
if self.ops:
try:
work = self.ops.pop(0)
Expand Down
1 change: 1 addition & 0 deletions tools/synchro.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
"_async_create_lock": "_create_lock",
"_async_create_condition": "_create_condition",
"_async_cond_wait": "_cond_wait",
"SpecRunnerTask": "SpecRunnerThread",
}

docstring_replacements: dict[tuple[str, str], str] = {
Expand Down
Loading