Skip to content

Commit eff3718

Browse files
committed
always poll async
if we poll sync then the server might not ever get a hold of the event loop. this was causing a problem for the sanic server in a test
1 parent 38d0dd4 commit eff3718

File tree

10 files changed

+87
-81
lines changed

10 files changed

+87
-81
lines changed

requirements/pkg-deps.txt

-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,3 @@ jsonpatch >=1.32
55
fastjsonschema >=2.14.5
66
requests >=2
77
colorlog >=6
8-
werkzeug >=2

src/idom/server/tornado.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from .utils import CLIENT_BUILD_DIR, safe_client_build_dir_path
2727

2828

29-
ConnectionContext: type[Context[HTTPServerRequest | None]] = create_context(
29+
ConnectionContext: type[Context[Connection | None]] = create_context(
3030
None, "ConnectionContext"
3131
)
3232

@@ -207,7 +207,7 @@ class ModelStreamHandler(WebSocketHandler):
207207
def initialize(self, component_constructor: ComponentConstructor) -> None:
208208
self._component_constructor = component_constructor
209209

210-
async def open(self, path: str = "") -> None:
210+
async def open(self, path: str = "", *args: Any, **kwargs: Any) -> None:
211211
message_queue: "AsyncQueue[str]" = AsyncQueue()
212212

213213
async def send(value: VdomJsonPatch) -> None:

src/idom/server/utils.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22

33
import asyncio
44
import logging
5+
import os
56
import socket
67
from contextlib import closing
78
from importlib import import_module
89
from pathlib import Path
910
from typing import Any, Iterator
1011

11-
from werkzeug.security import safe_join
12-
1312
import idom
1413
from idom.config import IDOM_WEB_MODULES_DIR
1514
from idom.types import RootComponentConstructor
@@ -72,11 +71,16 @@ def safe_web_modules_dir_path(path: str) -> Path:
7271
return traversal_safe_path(IDOM_WEB_MODULES_DIR.current, *path.split("/"))
7372

7473

75-
def traversal_safe_path(root: Path, *unsafe_parts: str | Path) -> Path:
76-
"""Sanitize user given path using ``werkzeug.security.safe_join``"""
77-
path = safe_join(str(root.resolve()), *unsafe_parts) # type: ignore
78-
if path is None:
79-
raise ValueError("Unsafe path") # pragma: no cover
74+
def traversal_safe_path(root: Path, *unsafe: str | Path) -> Path:
75+
"""Raise a ``ValueError`` if the ``unsafe`` path resolves outside the root dir."""
76+
root = root.resolve()
77+
# resolve relative paths and symlinks
78+
path = root.joinpath(*unsafe).resolve()
79+
80+
if os.path.commonprefix([root, path]) != str(root):
81+
# If the common prefix is not root directory we resolved outside the root dir
82+
raise ValueError("Unsafe path")
83+
8084
return Path(path)
8185

8286

src/idom/testing/common.py

+42-52
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import asyncio
4+
import inspect
35
import shutil
46
import time
57
from functools import wraps
@@ -26,9 +28,7 @@ def clear_idom_web_modules_dir() -> None:
2628
_RC = TypeVar("_RC", covariant=True)
2729

2830

29-
class _UntilFunc(Protocol[_RC]):
30-
def __call__(self, condition: Callable[[_RC], bool], timeout: float = ...) -> Any:
31-
...
31+
_DEFAULT_POLL_DELAY = 0.1
3232

3333

3434
class poll(Generic[_R]): # noqa: N801
@@ -40,64 +40,54 @@ def __init__(
4040
*args: _P.args,
4141
**kwargs: _P.kwargs,
4242
) -> None:
43-
self.until: _UntilFunc[_R]
44-
"""Check that the coroutines result meets a condition within the timeout"""
43+
coro: Callable[_P, Awaitable[_R]]
44+
if not inspect.iscoroutinefunction(function):
45+
46+
async def coro(*args: _P.args, **kwargs: _P.kwargs) -> _R:
47+
return cast(_R, function(*args, **kwargs))
4548

46-
if iscoroutinefunction(function):
47-
coro_function = cast(Callable[_P, Awaitable[_R]], function)
48-
49-
async def coro_until(
50-
condition: Callable[[_R], bool],
51-
timeout: float = IDOM_TESTING_DEFAULT_TIMEOUT.current,
52-
) -> None:
53-
started_at = time.time()
54-
while True:
55-
result = await coro_function(*args, **kwargs)
56-
if condition(result):
57-
break
58-
elif (time.time() - started_at) > timeout: # pragma: no cover
59-
raise TimeoutError(
60-
f"Condition not met within {timeout} "
61-
f"seconds - last value was {result!r}"
62-
)
63-
64-
self.until = coro_until
6549
else:
66-
sync_function = cast(Callable[_P, _R], function)
67-
68-
def sync_until(
69-
condition: Callable[[_R], bool] | Any,
70-
timeout: float = IDOM_TESTING_DEFAULT_TIMEOUT.current,
71-
) -> None:
72-
started_at = time.time()
73-
while True:
74-
result = sync_function(*args, **kwargs)
75-
if condition(result):
76-
break
77-
elif (time.time() - started_at) > timeout: # pragma: no cover
78-
raise TimeoutError(
79-
f"Condition not met within {timeout} "
80-
f"seconds - last value was {result!r}"
81-
)
82-
83-
self.until = sync_until
84-
85-
def until_is(
50+
coro = cast(Callable[_P, Awaitable[_R]], function)
51+
self._func = coro
52+
self._args = args
53+
self._kwargs = kwargs
54+
55+
async def until(
8656
self,
87-
right: Any,
57+
condition: Callable[[_R], bool],
8858
timeout: float = IDOM_TESTING_DEFAULT_TIMEOUT.current,
89-
) -> Any:
59+
delay: float = _DEFAULT_POLL_DELAY,
60+
) -> None:
61+
"""Check that the coroutines result meets a condition within the timeout"""
62+
started_at = time.time()
63+
while True:
64+
await asyncio.sleep(delay)
65+
result = await self._func(*self._args, **self._kwargs)
66+
if condition(result):
67+
break
68+
elif (time.time() - started_at) > timeout: # pragma: no cover
69+
raise TimeoutError(
70+
f"Condition not met within {timeout} "
71+
f"seconds - last value was {result!r}"
72+
)
73+
74+
async def until_is(
75+
self,
76+
right: _R,
77+
timeout: float = IDOM_TESTING_DEFAULT_TIMEOUT.current,
78+
delay: float = _DEFAULT_POLL_DELAY,
79+
) -> None:
9080
"""Wait until the result is identical to the given value"""
91-
return self.until(lambda left: left is right, timeout)
81+
return await self.until(lambda left: left is right, timeout, delay)
9282

93-
def until_equals(
83+
async def until_equals(
9484
self,
95-
right: Any,
85+
right: _R,
9686
timeout: float = IDOM_TESTING_DEFAULT_TIMEOUT.current,
97-
) -> Any:
87+
delay: float = _DEFAULT_POLL_DELAY,
88+
) -> None:
9889
"""Wait until the result is equal to the given value"""
99-
# not really sure why I need a type ignore comment here
100-
return self.until(lambda left: left == right, timeout) # type: ignore
90+
return await self.until(lambda left: left == right, timeout, delay)
10191

10292

10393
class HookCatcher:

tests/test_core/test_events.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -221,4 +221,4 @@ def outer_click_is_not_triggered(event):
221221
inner = await display.page.wait_for_selector("#inner")
222222
await inner.click()
223223

224-
poll(lambda: clicked.current).until_is(True)
224+
await poll(lambda: clicked.current).until_is(True)

tests/test_core/test_hooks.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -204,18 +204,18 @@ def TestComponent():
204204

205205
await client_r_1_button.click()
206206

207-
poll_event_count.until_equals(1)
208-
poll_render_count.until_equals(1)
207+
await poll_event_count.until_equals(1)
208+
await poll_render_count.until_equals(1)
209209

210210
await client_r_2_button.click()
211211

212-
poll_event_count.until_equals(2)
213-
poll_render_count.until_equals(2)
212+
await poll_event_count.until_equals(2)
213+
await poll_render_count.until_equals(2)
214214

215215
await client_r_2_button.click()
216216

217-
poll_event_count.until_equals(3)
218-
poll_render_count.until_equals(2)
217+
await poll_event_count.until_equals(3)
218+
await poll_render_count.until_equals(2)
219219

220220

221221
async def test_simple_input_with_use_state(display: DisplayFixture):

tests/test_server/test_common.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,11 @@ def ShowScope():
8080

8181
async def test_use_location(display: DisplayFixture):
8282
location = idom.Ref()
83-
poll_location = poll(lambda: location.current)
83+
84+
@poll
85+
async def poll_location():
86+
"""This needs to be async to allow the server to respond"""
87+
return location.current
8488

8589
@idom.component
8690
def ShowRoute():
@@ -89,7 +93,7 @@ def ShowRoute():
8993

9094
await display.show(ShowRoute)
9195

92-
poll_location.until_equals(Location("/", ""))
96+
await poll_location.until_equals(Location("/", ""))
9397

9498
for loc in [
9599
Location("/something"),
@@ -100,4 +104,4 @@ def ShowRoute():
100104
Location("/another/something/file.txt", "?key1=value1&key2=value2"),
101105
]:
102106
await display.goto(loc.pathname + loc.search)
103-
poll_location.until_equals(loc)
107+
await poll_location.until_equals(loc)

tests/test_web/test_module.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def ShowSimpleButton():
110110

111111
button = await display.page.wait_for_selector("#my-button")
112112
await button.click()
113-
poll(lambda: is_clicked.current).until_is(True)
113+
await poll(lambda: is_clicked.current).until_is(True)
114114

115115

116116
def test_module_from_file_source_conflict(tmp_path):

tests/test_widgets.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,12 @@ def SomeComponent():
118118

119119
poll_value = poll(lambda: value.current)
120120

121-
poll_value.until_equals("hello")
121+
await poll_value.until_equals("hello")
122122

123123
await input_2.focus()
124124
await input_2.type(" world", delay=20)
125125

126-
poll_value.until_equals("hello world")
126+
await poll_value.until_equals("hello world")
127127

128128

129129
async def test_use_linked_inputs_on_change_with_cast(display: DisplayFixture):
@@ -145,12 +145,12 @@ def SomeComponent():
145145

146146
poll_value = poll(lambda: value.current)
147147

148-
poll_value.until_equals(1)
148+
await poll_value.until_equals(1)
149149

150150
await input_2.focus()
151151
await input_2.type("2")
152152

153-
poll_value.until_equals(12)
153+
await poll_value.until_equals(12)
154154

155155

156156
async def test_use_linked_inputs_ignore_empty(display: DisplayFixture):
@@ -174,13 +174,13 @@ def SomeComponent():
174174

175175
poll_value = poll(lambda: value.current)
176176

177-
poll_value.until_equals("1")
177+
await poll_value.until_equals("1")
178178

179179
await input_2.focus()
180180
await input_2.press("Backspace")
181181

182-
poll_value.until_equals("1")
182+
await poll_value.until_equals("1")
183183

184184
await input_2.type("2")
185185

186-
poll_value.until_equals("2")
186+
await poll_value.until_equals("2")

tests/tooling/loop.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import asyncio
22
import sys
33
import threading
4+
import time
45
from asyncio import wait_for
56
from contextlib import contextmanager
67
from typing import Iterator
78

9+
from idom.config import IDOM_TESTING_DEFAULT_TIMEOUT
810
from idom.testing import poll
911

1012

@@ -37,7 +39,14 @@ def open_event_loop(as_current: bool = True) -> Iterator[asyncio.AbstractEventLo
3739
finally:
3840
if as_current:
3941
asyncio.set_event_loop(None)
40-
poll(loop.is_running).until_is(False)
42+
start = time.time()
43+
while loop.is_running():
44+
if (time.time() - start) > IDOM_TESTING_DEFAULT_TIMEOUT.current:
45+
raise TimeoutError(
46+
"Failed to stop loop after "
47+
f"{IDOM_TESTING_DEFAULT_TIMEOUT.current} seconds"
48+
)
49+
time.sleep(0.1)
4150
loop.close()
4251

4352

0 commit comments

Comments
 (0)