-
-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathhooks.py
149 lines (118 loc) · 4.75 KB
/
hooks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from __future__ import annotations
import asyncio
from typing import Any, Awaitable, Callable, DefaultDict, Sequence, Union, cast
from channels.db import database_sync_to_async as _database_sync_to_async
from django.db.models.base import Model
from django.db.models.query import QuerySet
from idom import use_callback, use_ref
from idom.backend.types import Location
from idom.core.hooks import Context, create_context, use_context, use_effect, use_state
from django_idom.types import IdomWebsocket, Mutation, Query, _Params, _Result
# ``channels`` ships ``database_sync_to_async`` untyped; cast it so type
# checkers know that decorating a function yields an awaitable-returning callable.
_AsyncWrapper = Callable[..., Callable[..., Awaitable[Any]]]
database_sync_to_async = cast(_AsyncWrapper, _database_sync_to_async)

# Context carrying the active IdomWebsocket; its default value is ``None``,
# which ``use_websocket`` turns into a RuntimeError.
WebsocketContext: Context[Union[IdomWebsocket, None]] = create_context(None)

# Query function -> set of ``refetch`` callbacks registered by ``use_query``.
# ``use_mutation`` looks up entries here to re-run the queries it names.
_REFETCH_CALLBACKS: DefaultDict[Callable[..., Any], set[Callable[[], None]]] = (
    DefaultDict(set)
)
def use_location() -> Location:
    """Get the current route as a ``Location``.

    Returns:
        ``Location(path, search)`` built from the ASGI scope, where ``search``
        is ``"?<query string>"``, or ``""`` when there is no query string.
    """
    # TODO: Use the browser's current page, rather than the WS route
    scope = use_scope()
    # The ASGI WebSocket scope marks ``query_string`` as optional (missing or
    # None means empty), so don't assume the key is present.
    search = (scope.get("query_string") or b"").decode()
    return Location(scope["path"], f"?{search}" if search else "")
def use_scope() -> dict[str, Any]:
    """Return the ASGI scope dictionary of the websocket serving this component.

    Raises:
        RuntimeError: Propagated from ``use_websocket`` when no websocket is
            available in context.
    """
    websocket = use_websocket()
    return websocket.scope
def use_websocket() -> IdomWebsocket:
    """Return the IdomWebsocket provided through ``WebsocketContext``.

    Raises:
        RuntimeError: If the context holds no websocket (its value is ``None``).
    """
    current = use_context(WebsocketContext)
    if current is not None:
        return current
    raise RuntimeError("No websocket. Are you running with a Django server?")
def use_query(
    query: Callable[_Params, Union[_Result, None]],
    *args: _Params.args,
    **kwargs: _Params.kwargs,
) -> Query[Union[_Result, None]]:
    """Run ``query(*args, **kwargs)`` through ``database_sync_to_async`` and track it.

    The query executes after the first render and again whenever ``refetch``
    is called, either directly or via a ``use_mutation`` that names this
    query function. Changing ``args``/``kwargs`` alone does not re-run it.
    Deferred model fields on the result are loaded inside the same wrapped
    call (presumably so later attribute access outside that thread does not
    hit the database).

    Args:
        query: Function that fetches the data; must be the same object on
            every render of the component.
        *args: Positional arguments forwarded to ``query``.
        **kwargs: Keyword arguments forwarded to ``query``.

    Returns:
        ``Query(data, loading, error, refetch)``. Any exception raised by the
        query or while loading deferred fields is exposed as ``error`` with
        ``data`` reset to ``None``.

    Raises:
        ValueError: If a different ``query`` function is passed on a later render.
    """
    # Pin the query function from the first render; swapping it later is an error.
    first_query = use_ref(query)
    if first_query.current is not query:
        raise ValueError(f"Query function changed from {first_query.current} to {query}.")

    needs_run, set_needs_run = use_state(True)
    data, set_data = use_state(cast(Union[_Result, None], None))
    loading, set_loading = use_state(True)
    error, set_error = use_state(cast(Union[Exception, None], None))

    @use_callback
    def refetch() -> None:
        # Flag a re-run; the execution effect below acts on it after re-render.
        set_needs_run(True)
        set_loading(True)
        set_error(None)

    @use_effect(dependencies=[])
    def register_refetch() -> Callable[[], None]:
        # Callbacks are tracked globally so a mutation naming ``query`` re-runs
        # every mounted use of it; the cleanup unregisters this one.
        _REFETCH_CALLBACKS[query].add(refetch)

        def unregister() -> None:
            _REFETCH_CALLBACKS[query].remove(refetch)

        return unregister

    @use_effect(dependencies=None)
    @database_sync_to_async
    def run_query() -> None:
        # This effect fires after every render; only do work when flagged.
        if not needs_run:
            return
        try:
            result = query(*args, **kwargs)
            _fetch_deferred(result)
        except Exception as exc:
            set_data(None)
            set_loading(False)
            set_error(exc)
            return
        finally:
            set_needs_run(False)
        set_data(result)
        set_loading(False)
        set_error(None)

    return Query(data, loading, error, refetch)
# Strong references to in-flight mutation tasks. The event loop keeps only
# weak references to tasks, so an unreferenced task may be collected mid-run.
_MUTATION_TASKS: set[asyncio.Future[Any]] = set()


def use_mutation(
    mutate: Callable[_Params, Union[bool, None]],
    refetch: Union[Callable[..., Any], Sequence[Callable[..., Any]]],
) -> Mutation[_Params]:
    """Create a mutation whose ``call`` runs ``mutate`` via ``database_sync_to_async``.

    Args:
        mutate: Function performing the write. Returning ``False`` skips the
            refetch; any other return value (including ``None``) triggers it.
        refetch: A query function, or a sequence of them. After a successful
            mutation, every ``refetch`` callback registered by ``use_query``
            for these functions is invoked.

    Returns:
        ``Mutation(call, loading, error, reset)``. ``call(*args, **kwargs)``
        sets ``loading`` immediately and schedules ``mutate`` in the background.
        An exception from ``mutate`` is exposed as ``error`` instead of being
        raised. ``reset`` clears ``loading`` and ``error``.
    """
    loading, set_loading = use_state(False)
    error, set_error = use_state(cast(Union[Exception, None], None))

    @use_callback
    def call(*args: _Params.args, **kwargs: _Params.kwargs) -> None:
        set_loading(True)

        @database_sync_to_async
        def execute_mutation() -> None:
            try:
                should_refetch = mutate(*args, **kwargs)
            except Exception as e:
                set_loading(False)
                set_error(e)
            else:
                set_loading(False)
                set_error(None)
                if should_refetch is not False:
                    for query in (refetch,) if callable(refetch) else refetch:
                        # Iterate a snapshot: this runs in a worker thread while
                        # the event loop may add/remove callbacks as components
                        # mount and unmount.
                        for callback in tuple(_REFETCH_CALLBACKS.get(query) or ()):
                            callback()

        task = asyncio.ensure_future(execute_mutation())
        # Keep the task alive until it finishes (see ``_MUTATION_TASKS``).
        _MUTATION_TASKS.add(task)
        task.add_done_callback(_MUTATION_TASKS.discard)

    @use_callback
    def reset() -> None:
        set_loading(False)
        set_error(None)

    return Mutation(call, loading, error, reset)
def _fetch_deferred(data: Any) -> None:
    """Load deferred fields on a query result.

    Args:
        data: A ``QuerySet`` (each model in it is processed), a single
            ``Model``, or ``None``. ``None`` is a valid query result (e.g.
            ``.first()`` with no match) and is left untouched.

    Raises:
        ValueError: If ``data`` is any other type.
    """
    # Query functions are typed as returning ``Union[_Result, None]``, so an
    # empty result must not be reported as an error.
    if data is None:
        return
    # https://github.com/typeddjango/django-stubs/issues/704
    if isinstance(data, QuerySet):  # type: ignore[misc]
        for model in data:
            _fetch_deferred_model_fields(model)
    elif isinstance(data, Model):
        _fetch_deferred_model_fields(data)
    else:
        raise ValueError(f"Expected a Model or QuerySet, got {data!r}")
def _fetch_deferred_model_fields(model: Any) -> None:
    """Force-load each deferred field of ``model``, recursing into Model values.

    Reading a deferred attribute makes Django fetch it. When the loaded value
    is itself a ``Model`` instance, its deferred fields are loaded too.
    """
    # NOTE(review): Django's get_deferred_fields() reports attnames (e.g.
    # ``author_id``), so the Model branch below may rarely fire -- confirm.
    pending = model.get_deferred_fields()
    for attr_name in pending:
        loaded = getattr(model, attr_name)
        if not isinstance(loaded, Model):
            continue
        _fetch_deferred_model_fields(loaded)