Skip to content

Commit 73967b4

Browse files
committed
New ensure_async util function
1 parent 3f834ca commit 73967b4

File tree

2 files changed

+23
-18
lines changed

2 files changed

+23
-18
lines changed

src/reactpy_django/forms/components.py

+5-17
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import asyncio
43
from pathlib import Path
54
from typing import TYPE_CHECKING, Any, Callable, Union, cast
65
from uuid import uuid4
@@ -20,6 +19,7 @@
2019
)
2120
from reactpy_django.forms.utils import convert_boolean_fields, convert_multiple_choice_fields
2221
from reactpy_django.types import AsyncFormEvent, FormEventData, SyncFormEvent
22+
from reactpy_django.utils import ensure_async
2323

2424
if TYPE_CHECKING:
2525
from collections.abc import Sequence
@@ -80,15 +80,9 @@ async def render_form():
8080
await database_sync_to_async(initialized_form.full_clean)()
8181
success = not initialized_form.errors.as_data()
8282
if success and on_success:
83-
if asyncio.iscoroutinefunction(on_success):
84-
await on_success(form_event)
85-
else:
86-
on_success(form_event)
83+
await ensure_async(on_success)(form_event)
8784
if not success and on_error:
88-
if asyncio.iscoroutinefunction(on_error):
89-
await on_error(form_event)
90-
else:
91-
on_error(form_event)
85+
await ensure_async(on_error)(form_event)
9286
if success and auto_save and isinstance(initialized_form, ModelForm):
9387
await database_sync_to_async(initialized_form.save)()
9488
set_submitted_data(None)
@@ -109,21 +103,15 @@ async def on_submit_callback(new_data: dict[str, Any]):
109103
new_form_event = FormEventData(
110104
form=initialized_form, submitted_data=new_data, set_submitted_data=set_submitted_data
111105
)
112-
if asyncio.iscoroutinefunction(on_receive_data):
113-
await on_receive_data(new_form_event)
114-
else:
115-
on_receive_data(new_form_event)
106+
await ensure_async(on_receive_data)(new_form_event)
116107

117108
if submitted_data != new_data:
118109
set_submitted_data(new_data)
119110

120111
async def _on_change(_event):
121112
"""Event that exist solely to allow the user to detect form changes."""
122113
if on_change:
123-
if asyncio.iscoroutinefunction(on_change):
124-
await on_change(form_event)
125-
else:
126-
on_change(form_event)
114+
await ensure_async(on_change)(form_event)
127115

128116
if not rendered_form:
129117
return None

src/reactpy_django/utils.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from fnmatch import fnmatch
1414
from importlib import import_module
1515
from pathlib import Path
16-
from typing import TYPE_CHECKING, Any, Callable
16+
from typing import TYPE_CHECKING, Any, Awaitable, Callable
1717
from uuid import UUID, uuid4
1818

1919
import dill
@@ -549,3 +549,20 @@ async def __aiter__(self):
549549
finally:
550550
if file_opened:
551551
file_handle.close()
552+
553+
554+
def ensure_async(
555+
func: Callable[FuncParams, Inferred], *, thread_sensitive: bool = True
556+
) -> Callable[FuncParams, Awaitable[Inferred]]:
557+
"""Ensure the provided function is always an async coroutine. If the provided function is
558+
not async, it will be adapted."""
559+
560+
@wraps(func)
561+
def wrapper(*args, **kwargs):
562+
return (
563+
func(*args, **kwargs)
564+
if inspect.iscoroutinefunction(func)
565+
else database_sync_to_async(func, thread_sensitive=thread_sensitive)(*args, **kwargs)
566+
)
567+
568+
return wrapper

0 commit comments

Comments
 (0)