From 8f6c8ec01da8ad4845d0a5493a33e5eca84bfcac Mon Sep 17 00:00:00 2001 From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com> Date: Mon, 4 Mar 2024 17:44:54 -0800 Subject: [PATCH 01/11] Fix static type hints for component decorator (#2) * Fix component decorator eating static type hints --- src/py/reactpy/reactpy/core/component.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/py/reactpy/reactpy/core/component.py b/src/py/reactpy/reactpy/core/component.py index f825aac71..e3c6b068d 100644 --- a/src/py/reactpy/reactpy/core/component.py +++ b/src/py/reactpy/reactpy/core/component.py @@ -2,14 +2,17 @@ import inspect from functools import wraps -from typing import Any, Callable +from typing import Any, Callable, TypeVar, ParamSpec from reactpy.core.types import ComponentType, VdomDict +T = TypeVar("T", bound=ComponentType | VdomDict | str | None) +P = ParamSpec("P") + def component( - function: Callable[..., ComponentType | VdomDict | str | None] -) -> Callable[..., Component]: + function: Callable[P, T], +) -> Callable[P, Component]: """A decorator for defining a new component. Parameters: @@ -25,7 +28,7 @@ def component( raise TypeError(msg) @wraps(function) - def constructor(*args: Any, key: Any | None = None, **kwargs: Any) -> Component: + def constructor(*args: P.args, key: Any | None = None, **kwargs: P.kwargs) -> Component: return Component(function, key, args, kwargs, sig) return constructor From efa547f4606a6d5f8447a901ac12f880a7e34f28 Mon Sep 17 00:00:00 2001 From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com> Date: Mon, 4 Mar 2024 18:09:10 -0800 Subject: [PATCH 02/11] Connection resume (#1) Add reconnection and client state side state capabilities --- src/js/package-lock.json | 4 +- .../@reactpy/client/src/components.tsx | 27 +- .../packages/@reactpy/client/src/messages.ts | 9 +- .../@reactpy/client/src/reactpy-client.ts | 336 +++++++++++++++--- src/py/reactpy/pyproject.toml | 3 + src/py/reactpy/reactpy/__init__.py | 10 +- src/py/reactpy/reactpy/backend/hooks.py | 17 +- src/py/reactpy/reactpy/backend/sanic.py | 61 ++-- .../reactpy/reactpy/core/_life_cycle_hook.py | 59 ++- src/py/reactpy/reactpy/core/_thread_local.py | 21 -- src/py/reactpy/reactpy/core/component.py | 51 ++- src/py/reactpy/reactpy/core/hooks.py | 103 +++++- src/py/reactpy/reactpy/core/layout.py | 243 +++++++++++-- src/py/reactpy/reactpy/core/serve.py | 179 ++++++++-- src/py/reactpy/reactpy/core/state_recovery.py | 276 ++++++++++++++ src/py/reactpy/reactpy/core/types.py | 37 ++ src/py/reactpy/reactpy/testing/common.py | 4 +- src/py/reactpy/reactpy/utils.py | 20 +- src/py/reactpy/tests/test_core/test_layout.py | 6 +- src/py/reactpy/tests/tooling/hooks.py | 4 +- 20 files changed, 1245 insertions(+), 225 deletions(-) delete mode 100644 src/py/reactpy/reactpy/core/_thread_local.py create mode 100644 src/py/reactpy/reactpy/core/state_recovery.py diff --git a/src/js/package-lock.json b/src/js/package-lock.json index 2edfdd260..2904bba0e 100644 --- a/src/js/package-lock.json +++ b/src/js/package-lock.json @@ -28,7 +28,7 @@ "@types/react": "^17.0", "@types/react-dom": "^17.0", "typescript": "^4.9.5", - "vite": "^3.1.8" + "vite": "^3.2.7" } }, "app/node_modules/@reactpy/client": { @@ -3955,7 +3955,7 @@ "@types/react-dom": "^17.0", "preact": "^10.7.0", "typescript": "^4.9.5", - "vite": "^3.1.8" + "vite": "^3.2.7" }, "dependencies": { "@reactpy/client": { diff --git a/src/js/packages/@reactpy/client/src/components.tsx b/src/js/packages/@reactpy/client/src/components.tsx index 728c4cec7..fd23d3a8a 100644 --- a/src/js/packages/@reactpy/client/src/components.tsx +++ b/src/js/packages/@reactpy/client/src/components.tsx @@ -29,12 +29,13 @@ export function Layout(props: { client: ReactPyClient }): JSX.Element { useEffect( () => - props.client.onMessage("layout-update", ({ path, model }) => { + props.client.onMessage("layout-update", ({ path, model, state_vars }) => { if (path === "") { Object.assign(currentModel, model); } else { setJsonPointer(currentModel, path, model); } + props.client.updateStateVars(state_vars); forceUpdate(); }), [currentModel, props.client], @@ -125,23 +126,15 @@ function ScriptElement({ model }: { model: ReactPyVdom }) { (value): value is string => typeof value == "string", )[0]; - let scriptElement: HTMLScriptElement; - if (model.attributes) { - scriptElement = document.createElement("script"); - for (const [k, v] of Object.entries(model.attributes)) { - scriptElement.setAttribute(k, v); - } - if (scriptContent) { - scriptElement.appendChild(document.createTextNode(scriptContent)); - } - ref.current.appendChild(scriptElement); - } else if (scriptContent) { - const scriptResult = eval(scriptContent); - if (typeof scriptResult == "function") { - return scriptResult(); - } + const scriptElement: HTMLScriptElement = document.createElement("script"); + for (const [k, v] of Object.entries(model.attributes || {})) { + scriptElement.setAttribute(k, v); + } + if (scriptContent) { + scriptElement.appendChild(document.createTextNode(scriptContent)); } - }, [model.key, ref.current]); + ref.current.appendChild(scriptElement); + }, [model.key]); return
; } diff --git a/src/js/packages/@reactpy/client/src/messages.ts b/src/js/packages/@reactpy/client/src/messages.ts index 34001dcb0..5fbfc24bf 100644 --- a/src/js/packages/@reactpy/client/src/messages.ts +++ b/src/js/packages/@reactpy/client/src/messages.ts @@ -12,6 +12,11 @@ export type LayoutEventMessage = { data: any; }; -export type IncomingMessage = LayoutUpdateMessage; -export type OutgoingMessage = LayoutEventMessage; +export type ReconnectingCheckMessage = { + type: "reconnecting-check"; + value: string; +} + +export type IncomingMessage = LayoutUpdateMessage | ReconnectingCheckMessage; +export type OutgoingMessage = LayoutEventMessage | ReconnectingCheckMessage; export type Message = IncomingMessage | OutgoingMessage; diff --git a/src/js/packages/@reactpy/client/src/reactpy-client.ts b/src/js/packages/@reactpy/client/src/reactpy-client.ts index 6f37b55a1..c5018e9a5 100644 --- a/src/js/packages/@reactpy/client/src/reactpy-client.ts +++ b/src/js/packages/@reactpy/client/src/reactpy-client.ts @@ -29,16 +29,26 @@ export interface ReactPyClient { * @returns A promise that resolves to the module. */ loadModule(moduleName: string): Promise; + + /** + * Update state vars from the server for reconnections + * @param givenStateVars State vars to store + */ + updateStateVars(givenStateVars: object): void; } export abstract class BaseReactPyClient implements ReactPyClient { private readonly handlers: { [key: string]: ((message: any) => void)[] } = {}; protected readonly ready: Promise; private resolveReady: (value: undefined) => void; + protected stateVars: object; + protected debugMessages: boolean; constructor() { - this.resolveReady = () => {}; + this.resolveReady = () => { }; this.ready = new Promise((resolve) => (this.resolveReady = resolve)); + this.stateVars = {}; + this.debugMessages = false; } onMessage(type: string, handler: (message: any) => void): () => void { @@ -52,6 +62,10 @@ export abstract class BaseReactPyClient implements ReactPyClient { abstract sendMessage(message: any): void; abstract loadModule(moduleName: string): Promise; + updateStateVars(givenStateVars: object): void { + this.stateVars = Object.assign(this.stateVars, givenStateVars); + } + /** * Handle an incoming message. * @@ -65,6 +79,10 @@ export abstract class BaseReactPyClient implements ReactPyClient { return; } + if (this.debugMessages) { + logger.log("Got message", message); + } + const messageHandlers: ((m: any) => void)[] | undefined = this.handlers[message.type]; if (!messageHandlers) { @@ -79,6 +97,9 @@ export abstract class BaseReactPyClient implements ReactPyClient { export type SimpleReactPyClientProps = { serverLocation?: LocationProps; reconnectOptions?: ReconnectProps; + idleDisconnectTimeSeconds?: number; + connectionTimeout?: number; + debugMessages?: boolean; }; /** @@ -117,14 +138,37 @@ type ReconnectProps = { maxRetries?: number; backoffRate?: number; intervalJitter?: number; + reconnectingCallback?: Function; + reconnectedCallback?: Function; +}; + +enum messageTypes { + isReady = "is-ready", + reconnectingCheck = "reconnecting-check", + clientState = "client-state", + stateUpdate = "state-update" }; export class SimpleReactPyClient extends BaseReactPyClient - implements ReactPyClient -{ + implements ReactPyClient { private readonly urls: ServerUrls; - private readonly socket: { current?: WebSocket }; + private socket!: { current?: WebSocket }; + private idleDisconnectTimeMillis: number; + private lastMessageTime: number; + private reconnectOptions: ReconnectProps | undefined; + private messageQueue: any[] = []; + private socketLoopIntervalId?: number | null; + private idleCheckIntervalId?: number | null; + private sleeping: boolean; + private isReconnecting: boolean; + private isReady: boolean; + private salt: string; + private shouldReconnect: boolean; + private connectionTimeout: number; + private reconnectingCallback: Function; + private reconnectedCallback: Function; + private didReconnectingCallback: boolean; constructor(props: SimpleReactPyClientProps) { super(); @@ -136,17 +180,238 @@ export class SimpleReactPyClient query: document.location.search, }, ); + this.idleDisconnectTimeMillis = (props.idleDisconnectTimeSeconds || 240) * 1000; + this.connectionTimeout = props.connectionTimeout || 5000; + this.lastMessageTime = Date.now() + this.reconnectOptions = props.reconnectOptions + this.debugMessages = props.debugMessages || false; + this.sleeping = false; + this.isReconnecting = false; + this.isReady = false + this.salt = ""; + this.shouldReconnect = false; + this.didReconnectingCallback = false; + this.reconnectingCallback = props.reconnectOptions?.reconnectingCallback || this.showReconnectingGrayout; + this.reconnectedCallback = props.reconnectOptions?.reconnectedCallback || this.hideReconnectingGrayout; + + this.onMessage(messageTypes.reconnectingCheck, () => { this.indicateReconnect() }) + this.onMessage(messageTypes.isReady, (msg) => { this.isReady = true; this.salt = msg.salt; }); + this.onMessage(messageTypes.clientState, () => { this.sendClientState() }); + this.onMessage(messageTypes.stateUpdate, (msg) => { this.updateClientState(msg.state_vars) }); + + this.reconnect() + + const reconnectOnUserAction = (ev: any) => { + if (!this.isReady && !this.isReconnecting) { + this.reconnect(); + } + } + + window.addEventListener('mousemove', reconnectOnUserAction); + window.addEventListener('scroll', reconnectOnUserAction); + } + + showReconnectingGrayout() { + const overlay = document.createElement('div'); + overlay.id = 'reactpy-reconnect-overlay'; + + const pipeContainer = document.createElement('div'); + const pipeSymbol = document.createElement('div'); + pipeSymbol.textContent = '|'; // Set the pipe symbol + + overlay.style.cssText = ` + position: fixed; + top: 0; + left: 0; + width: 100%; + height: 100%; + background-color: rgba(0,0,0,0.5); + display: flex; + justify-content: center; + align-items: center; + z-index: 1000; + `; + + pipeContainer.style.cssText = ` + display: flex; + justify-content: center; + align-items: center; + width: 40px; + height: 40px; + `; + + pipeSymbol.style.cssText = ` + font-size: 24px; + color: #FFF; + display: inline-block; + width: 100%; + height: 100%; + text-align: center; + transform-origin: center; + `; + + pipeContainer.appendChild(pipeSymbol); + overlay.appendChild(pipeContainer); + document.body.appendChild(overlay); + + // Create and start the spin animation + let angle = 0; + function spin() { + angle = (angle + 2) % 360; // Adjust rotation speed as needed + pipeSymbol.style.transform = `rotate(${angle}deg)`; + requestAnimationFrame(spin); + } + spin(); + } + + hideReconnectingGrayout() { + const overlay = document.getElementById('reactpy-reconnect-overlay'); + if (overlay && overlay.parentNode) { + overlay.parentNode.removeChild(overlay); + } + } - this.socket = createReconnectingWebSocket({ - readyPromise: this.ready, - url: this.urls.stream, - onMessage: async ({ data }) => this.handleIncoming(JSON.parse(data)), - ...props.reconnectOptions, + indicateReconnect(): void { + const isReconnecting = this.isReconnecting ? "yes" : "no"; + this.sendMessage({ "type": messageTypes.reconnectingCheck, "value": isReconnecting }, true) + } + + sendClientState(): void { + if (!this.socket) + return; + this.transmitMessage({ + "type": "client-state", + "value": this.stateVars, + "salt": this.salt }); } - sendMessage(message: any): void { - this.socket.current?.send(JSON.stringify(message)); + updateClientState(stateVars: object): void { + if (!this.socket) + return; + this.updateStateVars(stateVars) + } + + socketLoop(): void { + if (!this.socket) + return; + if (this.messageQueue.length > 0 && this.isReady && this.socket.current && this.socket.current.readyState === WebSocket.OPEN) { + const message = this.messageQueue.shift(); // Remove the first message from the queue + this.transmitMessage(message); + } + } + + transmitMessage(message: any): void { + if (this.socket && this.socket.current) { + if (this.debugMessages) { + logger.log("Sending message", message); + } + this.socket.current.send(JSON.stringify(message)); + } + } + + idleTimeoutCheck(): void { + if (!this.socket) + return; + if (Date.now() - this.lastMessageTime > this.idleDisconnectTimeMillis) { + if (this.socket.current && this.socket.current.readyState === WebSocket.OPEN) { + logger.warn("Closing socket connection due to idle activity"); + this.sleeping = true; + this.socket.current.close(); + } + } + } + + reconnect(onOpen?: () => void, interval: number = 750, connectionAttemptsRemaining: number = 20, lastAttempt: number = 0): void { + const intervalJitter = this.reconnectOptions?.intervalJitter || 0.5; + const backoffRate = this.reconnectOptions?.backoffRate || 1.2; + const maxInterval = this.reconnectOptions?.maxInterval || 20000; + const maxRetries = this.reconnectOptions?.maxRetries || 20; + + + if (connectionAttemptsRemaining <= 0) { + logger.warn("Giving up on reconnecting (hit retry limit)"); + this.shouldReconnect = false; + this.isReconnecting = false; + return + } + + if (this.shouldReconnect) { + // already reconnecting + return; + } + lastAttempt = lastAttempt || Date.now(); + this.shouldReconnect = true; + + window.setTimeout(() => { + if (!this.didReconnectingCallback && this.reconnectingCallback) { + this.didReconnectingCallback = true; + this.reconnectingCallback(); + } + + if (maxRetries < connectionAttemptsRemaining) + connectionAttemptsRemaining = maxRetries; + + this.socket = createWebSocket({ + connectionTimeout: this.connectionTimeout, + readyPromise: this.ready, + url: this.urls.stream, + onOpen: () => { + lastAttempt = Date.now(); + if (this.reconnectedCallback) { + this.reconnectedCallback(); + this.didReconnectingCallback = false; + } + if (onOpen) + onOpen(); + }, + onClose: () => { + // reset retry interval + if (Date.now() - lastAttempt > maxInterval * 2) { + interval = 750; + connectionAttemptsRemaining = maxRetries; + } + lastAttempt = Date.now() + this.shouldReconnect = false; + this.isReconnecting = true; + this.isReady = false; + if (this.socketLoopIntervalId) + clearInterval(this.socketLoopIntervalId); + if (this.idleCheckIntervalId) + clearInterval(this.idleCheckIntervalId); + if (!this.sleeping) { + const thisInterval = nextInterval(addJitter(interval, intervalJitter), backoffRate, maxInterval); + const newRetriesRemaining = connectionAttemptsRemaining - 1; + logger.log( + `reconnecting in ${(thisInterval / 1000).toPrecision(4)} seconds... (${newRetriesRemaining} retries remaining)`, + ); + this.reconnect(onOpen, thisInterval, newRetriesRemaining, lastAttempt); + } + }, + onMessage: async ({ data }) => { this.lastMessageTime = Date.now(); this.handleIncoming(JSON.parse(data)) }, + ...this.reconnectOptions, + }); + this.socketLoopIntervalId = window.setInterval(() => { this.socketLoop() }, 30); + this.idleCheckIntervalId = window.setInterval(() => { this.idleTimeoutCheck() }, 10000); + + }, interval) + } + + ensureConnected(): void { + if (this.socket.current?.readyState == WebSocket.CLOSED) { + this.reconnect(); + } + } + + sendMessage(message: any, immediate: boolean = false): void { + if (immediate) { + this.transmitMessage(message); + } else { + this.messageQueue.push(message); + } + this.lastMessageTime = Date.now() + this.sleeping = false; + this.ensureConnected(); } loadModule(moduleName: string): Promise { @@ -173,25 +438,16 @@ function getServerUrls(props: LocationProps): ServerUrls { return { base, modules, assets, stream }; } -function createReconnectingWebSocket( +function createWebSocket( props: { url: string; readyPromise: Promise; + connectionTimeout: number; onOpen?: () => void; onMessage: (message: MessageEvent) => void; onClose?: () => void; - } & ReconnectProps, + }, ) { - const { - maxInterval = 60000, - maxRetries = 50, - backoffRate = 1.1, - intervalJitter = 0.1, - } = props; - - const startInterval = 750; - let retries = 0; - let interval = startInterval; const closed = false; let everConnected = false; const socket: { current?: WebSocket } = {}; @@ -201,11 +457,18 @@ function createReconnectingWebSocket( return; } socket.current = new WebSocket(props.url); + const connectionTimeout = props.connectionTimeout; // Timeout in milliseconds + + const timeoutId = setTimeout(() => { + if (socket.current && socket.current.readyState !== WebSocket.OPEN) { + socket.current.close(); + console.error('Connection attempt timed out'); + } + }, connectionTimeout); socket.current.onopen = () => { + clearTimeout(timeoutId); everConnected = true; logger.log("client connected"); - interval = startInterval; - retries = 0; if (props.onOpen) { props.onOpen(); } @@ -214,25 +477,12 @@ function createReconnectingWebSocket( socket.current.onclose = () => { if (!everConnected) { logger.log("failed to connect"); - return; + } else { + logger.log("client disconnected"); } - - logger.log("client disconnected"); if (props.onClose) { props.onClose(); } - - if (retries >= maxRetries) { - return; - } - - const thisInterval = addJitter(interval, intervalJitter); - logger.log( - `reconnecting in ${(thisInterval / 1000).toPrecision(4)} seconds...`, - ); - setTimeout(connect, thisInterval); - interval = nextInterval(interval, backoffRate, maxInterval); - retries++; }; }; @@ -247,16 +497,16 @@ function nextInterval( maxInterval: number, ): number { return Math.min( - currentInterval * + (currentInterval * // increase interval by backoff rate - backoffRate, + backoffRate), // don't exceed max interval maxInterval, ); } function addJitter(interval: number, jitter: number): number { - return interval + (Math.random() * jitter * interval * 2 - jitter * interval); + return interval + (Math.random() * jitter * interval); } function rtrim(text: string, trim: string): string { diff --git a/src/py/reactpy/pyproject.toml b/src/py/reactpy/pyproject.toml index 309248507..c16e5f065 100644 --- a/src/py/reactpy/pyproject.toml +++ b/src/py/reactpy/pyproject.toml @@ -35,6 +35,9 @@ dependencies = [ "colorlog >=6", "asgiref >=3", "lxml >=4", + "pyotp", + "orjson", + "more-itertools", ] [project.optional-dependencies] all = ["reactpy[starlette,sanic,fastapi,flask,tornado,testing]"] diff --git a/src/py/reactpy/reactpy/__init__.py b/src/py/reactpy/reactpy/__init__.py index 49e357441..32b4712a1 100644 --- a/src/py/reactpy/reactpy/__init__.py +++ b/src/py/reactpy/reactpy/__init__.py @@ -1,10 +1,16 @@ from reactpy import backend, config, html, logging, sample, svg, types, web, widgets -from reactpy.backend.hooks import use_connection, use_location, use_scope +from reactpy.backend.hooks import ( + use_connection, + use_location, + use_reconnect_effect, + use_scope, +) from reactpy.backend.utils import run from reactpy.core import hooks from reactpy.core.component import component from reactpy.core.events import event from reactpy.core.hooks import ( + ReconnectingOnly, create_context, use_callback, use_context, @@ -33,6 +39,7 @@ "html", "Layout", "logging", + "ReconnectingOnly", "Ref", "run", "sample", @@ -46,6 +53,7 @@ "use_effect", "use_location", "use_memo", + "use_reconnect_effect", "use_reducer", "use_ref", "use_scope", diff --git a/src/py/reactpy/reactpy/backend/hooks.py b/src/py/reactpy/reactpy/backend/hooks.py index ee4ce1b5c..3424b9b86 100644 --- a/src/py/reactpy/reactpy/backend/hooks.py +++ b/src/py/reactpy/reactpy/backend/hooks.py @@ -1,10 +1,16 @@ from __future__ import annotations from collections.abc import MutableMapping -from typing import Any +from typing import Any, Callable from reactpy.backend.types import Connection, Location -from reactpy.core.hooks import create_context, use_context +from reactpy.core.hooks import ( + ReconnectingOnly, + _EffectApplyFunc, + create_context, + use_context, + use_effect, +) from reactpy.core.types import Context # backend implementations should establish this context at the root of an app @@ -28,3 +34,10 @@ def use_scope() -> MutableMapping[str, Any]: def use_location() -> Location: """Get the current :class:`~reactpy.backend.types.Connection`'s location.""" return use_connection().location + + +def use_reconnect_effect( + function: _EffectApplyFunc | None = None, +) -> Callable[[_EffectApplyFunc], None] | None: + """Apply an effect only on reconnection""" + return use_effect(function, ReconnectingOnly()) diff --git a/src/py/reactpy/reactpy/backend/sanic.py b/src/py/reactpy/reactpy/backend/sanic.py index 76eb0423e..bad90b072 100644 --- a/src/py/reactpy/reactpy/backend/sanic.py +++ b/src/py/reactpy/reactpy/backend/sanic.py @@ -7,6 +7,7 @@ from typing import Any from urllib import parse as urllib_parse from uuid import uuid4 +import orjson from sanic import Blueprint, Sanic, request, response from sanic.config import Config @@ -24,11 +25,15 @@ safe_web_modules_dir_path, serve_with_uvicorn, ) -from reactpy.backend.hooks import ConnectionContext from reactpy.backend.hooks import use_connection as _use_connection from reactpy.backend.types import Connection, Location -from reactpy.core.layout import Layout -from reactpy.core.serve import RecvCoroutine, SendCoroutine, Stop, serve_layout +from reactpy.core.serve import ( + RecvCoroutine, + SendCoroutine, + Stop, + WebsocketServer, +) +from reactpy.core.state_recovery import StateRecoveryManager from reactpy.core.types import RootComponentConstructor logger = logging.getLogger(__name__) @@ -51,6 +56,7 @@ def configure( app: Sanic[Any, Any], component: RootComponentConstructor, options: Options | None = None, + state_recovery_manager: StateRecoveryManager | None = None, ) -> None: """Configure an application instance to display the given component""" options = options or Options() @@ -59,7 +65,9 @@ def configure( api_bp = Blueprint(f"reactpy_api_{id(app)}", url_prefix=str(PATH_PREFIX)) _setup_common_routes(api_bp, spa_bp, options) - _setup_single_view_dispatcher_route(api_bp, component, options) + _setup_single_view_dispatcher_route( + api_bp, component, options, state_recovery_manager + ) app.blueprint([spa_bp, api_bp]) @@ -159,6 +167,7 @@ def _setup_single_view_dispatcher_route( api_blueprint: Blueprint, constructor: RootComponentConstructor, options: Options, + state_recovery_manager: StateRecoveryManager | None, ) -> None: async def model_stream( request: request.Request[Any, Any], @@ -171,27 +180,29 @@ async def model_stream( logger.warning("No scope. Sanic may not be running with an ASGI server") send, recv = _make_send_recv_callbacks(socket) - await serve_layout( - Layout( - ConnectionContext( - constructor(), - value=Connection( - scope=scope, - location=Location( - pathname=f"/{path[len(options.url_prefix):]}", - search=( - f"?{request.query_string}" - if request.query_string - else "" - ), - ), - carrier=_SanicCarrier(request, socket), - ), - ) + + server = WebsocketServer(send, recv, state_recovery_manager) + await server.handle_connection( + Connection( + scope=scope, + location=Location( + pathname=f"/{path[len(options.url_prefix):]}", + search=(f"?{request.query_string}" if request.query_string else ""), + ), + carrier=_SanicCarrier(request, socket), ), - send, - recv, + constructor, ) + # await serve_layout( + # Layout( + # ConnectionContext( + # constructor(), + # value=, + # ) + # ), + # send, + # recv, + # ) api_blueprint.add_websocket_route( model_stream, @@ -209,13 +220,13 @@ def _make_send_recv_callbacks( socket: WebSocketConnection, ) -> tuple[SendCoroutine, RecvCoroutine]: async def sock_send(value: Any) -> None: - await socket.send(json.dumps(value)) + await socket.send(orjson.dumps(value).decode("utf-8")) async def sock_recv() -> Any: data = await socket.recv() if data is None: raise Stop() - return json.loads(data) + return orjson.loads(data) return sock_send, sock_recv diff --git a/src/py/reactpy/reactpy/core/_life_cycle_hook.py b/src/py/reactpy/reactpy/core/_life_cycle_hook.py index 88d3386a8..c0e3dfc6f 100644 --- a/src/py/reactpy/reactpy/core/_life_cycle_hook.py +++ b/src/py/reactpy/reactpy/core/_life_cycle_hook.py @@ -2,12 +2,16 @@ import logging from asyncio import Event, Task, create_task, gather -from typing import Any, Callable, Protocol, TypeVar +from contextvars import ContextVar, Token +from typing import TYPE_CHECKING, Any, Callable, Protocol, TypeVar from anyio import Semaphore -from reactpy.core._thread_local import ThreadLocal from reactpy.core.types import ComponentType, Context, ContextProviderType +from reactpy.utils import Ref + +if TYPE_CHECKING: + from reactpy.core.hooks import _CurrentState T = TypeVar("T") @@ -18,12 +22,27 @@ async def __call__(self, stop: Event) -> None: ... logger = logging.getLogger(__name__) -_HOOK_STATE: ThreadLocal[list[LifeCycleHook]] = ThreadLocal(list) +_hook_state = ContextVar("_hook_state") + + +def create_hook_state(initial: list | None = None) -> Token[list]: + return _hook_state.set(initial or []) + + +def clear_hook_state(token: Token[list]) -> None: + hook_stack = _hook_state.get() + if hook_stack: + logger.warning("clear_hook_state: Hook stack was not empty") + _hook_state.reset(token) + + +def get_hook_state() -> list[LifeCycleHook]: + return _hook_state.get() -def current_hook() -> LifeCycleHook: +def get_current_hook() -> LifeCycleHook: """Get the current :class:`LifeCycleHook`""" - hook_stack = _HOOK_STATE.get() + hook_stack = _hook_state.get() if not hook_stack: msg = "No life cycle hook is active. Are you rendering in a layout?" raise RuntimeError(msg) @@ -117,6 +136,10 @@ async def my_effect(stop_event): "_scheduled_render", "_state", "component", + "reconnecting", + "client_state", + "_updated_states", + "_previous_states", ) component: ComponentType @@ -124,17 +147,35 @@ async def my_effect(stop_event): def __init__( self, schedule_render: Callable[[], None], + reconnecting: Ref, + client_state: dict[str, Any], + updated_states: dict[str, Any], + previous_states: dict[str, Any], ) -> None: self._context_providers: dict[Context[Any], ContextProviderType[Any]] = {} self._schedule_render_callback = schedule_render self._scheduled_render = False self._rendered_atleast_once = False self._current_state_index = 0 - self._state: tuple[Any, ...] = () + self._state: list = [] self._effect_funcs: list[EffectFunc] = [] self._effect_tasks: list[Task[None]] = [] self._effect_stops: list[Event] = [] self._render_access = Semaphore(1) # ensure only one render at a time + self.reconnecting = reconnecting + self.client_state = client_state or {} + self._updated_states = updated_states + self._previous_states = previous_states + + def add_state_update(self, updated_state: _CurrentState | Ref) -> None: + if ( + updated_state.key + and self._previous_states.get( + updated_state.key, "__missing_lifecycle_key_value__" + ) + is not updated_state.value + ): + self._updated_states[updated_state.key] = updated_state.value def schedule_render(self) -> None: if self._scheduled_render: @@ -157,7 +198,7 @@ def use_state(self, function: Callable[[], T]) -> T: if not self._rendered_atleast_once: # since we're not initialized yet we're just appending state result = function() - self._state += (result,) + self._state.append(result) else: # once finalized we iterate over each succesively used piece of state result = self._state[self._current_state_index] @@ -232,7 +273,7 @@ def set_current(self) -> None: This method is called by a layout before entering the render method of this hook's associated component. """ - hook_stack = _HOOK_STATE.get() + hook_stack = get_hook_state() if hook_stack: parent = hook_stack[-1] self._context_providers.update(parent._context_providers) @@ -240,5 +281,5 @@ def set_current(self) -> None: def unset_current(self) -> None: """Unset this hook as the active hook in this thread""" - if _HOOK_STATE.get().pop() is not self: + if get_hook_state().pop() is not self: raise RuntimeError("Hook stack is in an invalid state") # nocov diff --git a/src/py/reactpy/reactpy/core/_thread_local.py b/src/py/reactpy/reactpy/core/_thread_local.py deleted file mode 100644 index b3d6a14b0..000000000 --- a/src/py/reactpy/reactpy/core/_thread_local.py +++ /dev/null @@ -1,21 +0,0 @@ -from threading import Thread, current_thread -from typing import Callable, Generic, TypeVar -from weakref import WeakKeyDictionary - -_StateType = TypeVar("_StateType") - - -class ThreadLocal(Generic[_StateType]): - """Utility for managing per-thread state information""" - - def __init__(self, default: Callable[[], _StateType]): - self._default = default - self._state: WeakKeyDictionary[Thread, _StateType] = WeakKeyDictionary() - - def get(self) -> _StateType: - thread = current_thread() - if thread not in self._state: - state = self._state[thread] = self._default() - else: - state = self._state[thread] - return state diff --git a/src/py/reactpy/reactpy/core/component.py b/src/py/reactpy/reactpy/core/component.py index e3c6b068d..9d4955546 100644 --- a/src/py/reactpy/reactpy/core/component.py +++ b/src/py/reactpy/reactpy/core/component.py @@ -2,7 +2,7 @@ import inspect from functools import wraps -from typing import Any, Callable, TypeVar, ParamSpec +from typing import Any, Callable, ParamSpec, TypeVar from reactpy.core.types import ComponentType, VdomDict @@ -11,33 +11,52 @@ def component( - function: Callable[P, T], + function: Callable[P, T] | None = None, + *, + priority: int = 0, ) -> Callable[P, Component]: """A decorator for defining a new component. Parameters: - function: The component's :meth:`reactpy.core.proto.ComponentType.render` function. + priority: The rendering priority. Lower numbers are higher priority. """ - sig = inspect.signature(function) - if "key" in sig.parameters and sig.parameters["key"].kind in ( - inspect.Parameter.KEYWORD_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - ): - msg = f"Component render function {function} uses reserved parameter 'key'" - raise TypeError(msg) + def _component(function: Callable[P, T]) -> Callable[P, Component]: + sig = inspect.signature(function) - @wraps(function) - def constructor(*args: P.args, key: Any | None = None, **kwargs: P.kwargs) -> Component: - return Component(function, key, args, kwargs, sig) + if "key" in sig.parameters and sig.parameters["key"].kind in ( + inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + msg = f"Component render function {function} uses reserved parameter 'key'" + raise TypeError(msg) - return constructor + @wraps(function) + def constructor( + *args: P.args, key: Any | None = None, **kwargs: P.kwargs + ) -> Component: + return Component(function, key, args, kwargs, sig, priority) + + return constructor + + if function: + return _component(function) + return _component class Component: """An object for rending component models.""" - __slots__ = "__weakref__", "_func", "_args", "_kwargs", "_sig", "key", "type" + __slots__ = ( + "__weakref__", + "_func", + "_args", + "_kwargs", + "_sig", + "key", + "type", + "priority", + ) def __init__( self, @@ -46,12 +65,14 @@ def __init__( args: tuple[Any, ...], kwargs: dict[str, Any], sig: inspect.Signature, + priority: int = 0, ) -> None: self.key = key self.type = function self._args = args self._kwargs = kwargs self._sig = sig + self.priority = priority def render(self) -> ComponentType | VdomDict | str | None: return self.type(*self._args, **self._kwargs) diff --git a/src/py/reactpy/reactpy/core/hooks.py b/src/py/reactpy/reactpy/core/hooks.py index 640cbf14c..72f7c2c4e 100644 --- a/src/py/reactpy/reactpy/core/hooks.py +++ b/src/py/reactpy/reactpy/core/hooks.py @@ -1,7 +1,11 @@ from __future__ import annotations import asyncio +from functools import lru_cache +import hashlib +import sys from collections.abc import Coroutine, Sequence +from hashlib import md5 from logging import getLogger from types import FunctionType from typing import ( @@ -18,7 +22,8 @@ from typing_extensions import TypeAlias from reactpy.config import REACTPY_DEBUG_MODE -from reactpy.core._life_cycle_hook import current_hook +from reactpy.core._life_cycle_hook import get_current_hook +from reactpy.core.state_recovery import StateRecoveryFailureError from reactpy.core.types import Context, Key, State, VdomDict from reactpy.utils import Ref @@ -38,6 +43,13 @@ logger = getLogger(__name__) + +class ReconnectingOnly(list): + """ + Used to indicate that a hook should only be used during reconnection + """ + + _Type = TypeVar("_Type") @@ -49,7 +61,9 @@ def use_state(initial_value: Callable[[], _Type]) -> State[_Type]: ... def use_state(initial_value: _Type) -> State[_Type]: ... -def use_state(initial_value: _Type | Callable[[], _Type]) -> State[_Type]: +def use_state( + initial_value: _Type | Callable[[], _Type], *, server_only: bool = False +) -> State[_Type]: """See the full :ref:`Use State` docs for details Parameters: @@ -61,23 +75,62 @@ def use_state(initial_value: _Type | Callable[[], _Type]) -> State[_Type]: Returns: A tuple containing the current state and a function to update it. """ - current_state = _use_const(lambda: _CurrentState(initial_value)) + if server_only: + key = None + else: + hook = get_current_hook() + caller_info = get_caller_info() + key = get_state_key(caller_info) + if hook.reconnecting.current: + try: + initial_value = hook.client_state[key] + except KeyError as err: + raise StateRecoveryFailureError( + f"Missing expected key {key} on client" + ) from err + current_state = _use_const(lambda: _CurrentState(key, initial_value)) return State(current_state.value, current_state.dispatch) +def get_caller_info(): + # Get the current stack frame and then the frame above it + caller_frame = sys._getframe(2) + for i in range(50): + render_frame = sys._getframe(4 + i) + patch_path = render_frame.f_locals.get("patch_path_for_state") + if patch_path is not None: + break + # Extract the relevant information: file path and line number and hash it + return f"{caller_frame.f_code.co_filename} {caller_frame.f_lineno} {patch_path}" + + +__DEBUG_CALLER_INFO_TO_STATE_KEY = {} + + +@lru_cache(8192) +def get_state_key(caller_info: str) -> str: + result = hashlib.sha256(caller_info.encode("utf8")).hexdigest()[:20] + if __debug__: + __DEBUG_CALLER_INFO_TO_STATE_KEY[result] = caller_info + return result + + class _CurrentState(Generic[_Type]): - __slots__ = "value", "dispatch" + __slots__ = "key", "value", "dispatch" def __init__( self, + key: str | None, initial_value: _Type | Callable[[], _Type], ) -> None: + self.key = key if callable(initial_value): self.value = initial_value() else: self.value = initial_value - hook = current_hook() + hook = get_current_hook() + hook.add_state_update(self) def dispatch(new: _Type | Callable[[_Type], _Type]) -> None: if callable(new): @@ -86,6 +139,7 @@ def dispatch(new: _Type | Callable[[_Type], _Type]) -> None: next_value = new if not strictly_equal(next_value, self.value): self.value = next_value + hook.add_state_update(self) hook.schedule_render() self.dispatch = dispatch @@ -131,11 +185,17 @@ def use_effect( Returns: If not function is provided, a decorator. Otherwise ``None``. """ - hook = current_hook() - - dependencies = _try_to_infer_closure_values(function, dependencies) memoize = use_memo(dependencies=dependencies) last_clean_callback: Ref[_EffectCleanFunc | None] = use_ref(None) + hook = get_current_hook() + if hook.reconnecting.current: + if not isinstance(dependencies, ReconnectingOnly): + return + dependencies = None + else: + if isinstance(dependencies, ReconnectingOnly): + return + dependencies = _try_to_infer_closure_values(function, dependencies) def add_effect(function: _EffectApplyFunc) -> None: if not asyncio.iscoroutinefunction(function): @@ -204,7 +264,7 @@ def use_debug_value( if REACTPY_DEBUG_MODE.current and old.current != new: old.current = new - logger.debug(f"{current_hook().component} {new}") + logger.debug(f"{get_current_hook().component} {new}") def create_context(default_value: _Type) -> Context[_Type]: @@ -232,7 +292,7 @@ def use_context(context: Context[_Type]) -> _Type: See the full :ref:`Use Context` docs for more information. """ - hook = current_hook() + hook = get_current_hook() provider = hook.get_context_provider(context) if provider is None: @@ -255,14 +315,16 @@ def __init__( value: _Type, key: Key | None, type: Context[_Type], + priority: int = -1, ) -> None: self.children = children self.key = key self.type = type self.value = value + self.priority = priority def render(self) -> VdomDict: - current_hook().set_context_provider(self) + get_current_hook().set_context_provider(self) return {"tagName": "", "children": self.children} def __repr__(self) -> str: @@ -447,7 +509,7 @@ def empty(self) -> bool: return False -def use_ref(initial_value: _Type) -> Ref[_Type]: +def use_ref(initial_value: _Type, server_only: bool = True) -> Ref[_Type]: """See the full :ref:`Use State` docs for details Parameters: @@ -456,11 +518,24 @@ def use_ref(initial_value: _Type) -> Ref[_Type]: Returns: A :class:`Ref` object. """ - return _use_const(lambda: Ref(initial_value)) + if server_only: + key = None + else: + hook = get_current_hook() + caller_info = get_caller_info() + key = get_state_key(caller_info) + if hook.reconnecting.current: + try: + initial_value = hook.client_state[key] + except KeyError as err: + raise StateRecoveryFailureError( + f"Missing expected key {key} on client" + ) from err + return _use_const(lambda: Ref(initial_value, key)) def _use_const(function: Callable[[], _Type]) -> _Type: - return current_hook().use_state(function) + return get_current_hook().use_state(function) def _try_to_infer_closure_values( diff --git a/src/py/reactpy/reactpy/core/layout.py b/src/py/reactpy/reactpy/core/layout.py index 70bdbbbff..8db467e30 100644 --- a/src/py/reactpy/reactpy/core/layout.py +++ b/src/py/reactpy/reactpy/core/layout.py @@ -1,9 +1,12 @@ from __future__ import annotations import abc +import asyncio +import copy from asyncio import ( FIRST_COMPLETED, CancelledError, + PriorityQueue, Queue, Task, create_task, @@ -16,7 +19,9 @@ from logging import getLogger from typing import ( Any, + Awaitable, Callable, + Coroutine, Generic, NamedTuple, NewType, @@ -34,13 +39,22 @@ REACTPY_CHECK_VDOM_SPEC, REACTPY_DEBUG_MODE, ) -from reactpy.core._life_cycle_hook import LifeCycleHook +from reactpy.core._life_cycle_hook import ( + LifeCycleHook, + clear_hook_state, + create_hook_state, + get_hook_state, +) +from reactpy.core.component import Component +from reactpy.core.hooks import _ContextProvider +from reactpy.core.state_recovery import StateRecoverySerializer from reactpy.core.types import ( ComponentType, EventHandlerDict, Key, LayoutEventMessage, LayoutUpdateMessage, + StateUpdateMessage, VdomChild, VdomDict, VdomJson, @@ -62,34 +76,64 @@ class Layout: "_render_tasks_ready", "_root_life_cycle_state_id", "_model_states_by_life_cycle_state_id", + "reconnecting", + "client_state", + "_state_recovery_serializer", + "_state_var_lock", + "_hook_state_token", + "_previous_states", ) if not hasattr(abc.ABC, "__weakref__"): # nocov __slots__ += ("__weakref__",) - def __init__(self, root: ComponentType) -> None: + def __init__( + self, + root: ComponentType, + ) -> None: super().__init__() - if not isinstance(root, ComponentType): - msg = f"Expected a ComponentType, not {type(root)!r}." - raise TypeError(msg) + # slow + # if not isinstance(root, ComponentType): + # msg = f"Expected a ComponentType, not {type(root)!r}." + # raise TypeError(msg) self.root = root + self.reconnecting = Ref(False) + self._state_recovery_serializer = None + self.client_state = {} + self._previous_states = {} + + def set_recovery_serializer(self, serializer: StateRecoverySerializer) -> None: + self._state_recovery_serializer = serializer async def __aenter__(self) -> Layout: + return await self.start() + + async def start(self) -> Layout: + self._hook_state_token = create_hook_state() + # create attributes here to avoid access before entering context manager self._event_handlers: EventHandlerDict = {} self._render_tasks: set[Task[LayoutUpdateMessage]] = set() self._render_tasks_ready: Semaphore = Semaphore(0) self._rendering_queue: _ThreadSafeQueue[_LifeCycleStateId] = _ThreadSafeQueue() - root_model_state = _new_root_model_state(self.root, self._schedule_render_task) + root_model_state = _new_root_model_state( + self.root, + self._schedule_render_task, + self.reconnecting, + self.client_state, + self._previous_states, + ) self._root_life_cycle_state_id = root_id = root_model_state.life_cycle_state.id self._model_states_by_life_cycle_state_id = {root_id: root_model_state} - self._schedule_render_task(root_id) return self async def __aexit__(self, *exc: Any) -> None: + return await self.finish() + + async def finish(self) -> None: root_csid = self._root_life_cycle_state_id root_model_state = self._model_states_by_life_cycle_state_id[root_csid] @@ -108,6 +152,14 @@ async def __aexit__(self, *exc: Any) -> None: del self._root_life_cycle_state_id del self._model_states_by_life_cycle_state_id + clear_hook_state(self._hook_state_token) + + def start_rendering(self) -> None: + self._schedule_render_task(self._root_life_cycle_state_id) + + def start_rendering_for_reconnect(self) -> None: + self._rendering_queue.put(self._root_life_cycle_state_id) + async def deliver(self, event: LayoutEventMessage) -> None: """Dispatch an event to the targeted handler""" # It is possible for an element in the frontend to produce an event @@ -122,7 +174,7 @@ async def deliver(self, event: LayoutEventMessage) -> None: except Exception: logger.exception(f"Failed to execute event handler {handler}") else: - logger.info( + logger.warning( f"Ignored event - handler {event['target']!r} " "does not exist or its component unmounted" ) @@ -133,6 +185,32 @@ async def render(self) -> LayoutUpdateMessage: else: # nocov return await self._serial_render() + async def render_until_queue_empty(self) -> None: + model_state_id = await self._rendering_queue.get() + while True: + try: + model_state = self._model_states_by_life_cycle_state_id[model_state_id] + except KeyError: + logger.debug( + "Did not render component with model state ID " + f"{model_state_id!r} - component already unmounted" + ) + else: + await self._create_layout_update(model_state, get_hook_state()) + # this might seem counterintuitive. What's happening is that events can get kicked off + # and currently there's no (obvious) visibility on if we're waiting for them to finish + # so this will wait up to 0.15 * 5 = 750 ms to see if any renders come in before + # declaring it done. In the future, it would be better to just track the pending events + for _ in range(5): + try: + model_state_id = await self._rendering_queue.get_nowait() + except asyncio.QueueEmpty: + await asyncio.sleep(0.15) # make sure + else: + break + else: + return + async def _serial_render(self) -> LayoutUpdateMessage: # nocov """Await the next available render. This will block until a component is updated""" while True: @@ -145,7 +223,7 @@ async def _serial_render(self) -> LayoutUpdateMessage: # nocov f"{model_state_id!r} - component already unmounted" ) else: - return await self._create_layout_update(model_state) + return await self._create_layout_update(model_state, get_hook_state()) async def _concurrent_render(self) -> LayoutUpdateMessage: """Await the next available render. This will block until a component is updated""" @@ -156,8 +234,9 @@ async def _concurrent_render(self) -> LayoutUpdateMessage: return update_task.result() async def _create_layout_update( - self, old_state: _ModelState + self, old_state: _ModelState, incoming_hook_state: list ) -> LayoutUpdateMessage: + token = create_hook_state(copy.copy(incoming_hook_state)) new_state = _copy_component_model_state(old_state) component = new_state.life_cycle_state.component @@ -167,11 +246,21 @@ async def _create_layout_update( if REACTPY_CHECK_VDOM_SPEC.current: validate_vdom_json(new_state.model.current) - return { - "type": "layout-update", - "path": new_state.patch_path, - "model": new_state.model.current, - } + updated_states = new_state.life_cycle_state.hook._updated_states + state_vars = ( + (await self._state_recovery_serializer.serialize_state_vars(updated_states)) + if self._state_recovery_serializer + else {} + ) + self._previous_states.update(updated_states) + updated_states.clear() + clear_hook_state(token) + return LayoutUpdateMessage( + type="layout-update", + path=new_state.patch_path, + model=new_state.model.current, + state_vars=state_vars, + ) async def _render_component( self, @@ -188,6 +277,7 @@ async def _render_component( await life_cycle_hook.affect_component_will_render(component) exit_stack.push_async_callback(life_cycle_hook.affect_layout_did_render) try: + patch_path_for_state = new_state.patch_path # type: ignore # noqa raw_model = component.render() # wrap the model in a fragment (i.e. tagName="") to ensure components have # a separate node in the model state tree. This could be removed if this @@ -278,7 +368,11 @@ def _render_model_attributes( if event in old_state.targets_by_event: target = old_state.targets_by_event[event] else: - target = uuid4().hex if handler.target is None else handler.target + target = ( + new_state.patch_path + event + if handler.target is None + else handler.target + ) new_state.targets_by_event[event] = target self._event_handlers[target] = handler model_event_handlers[event] = { @@ -299,7 +393,11 @@ def _render_model_event_handlers_without_old_state( model_event_handlers = new_state.model.current["eventHandlers"] = {} for event, handler in handlers_by_event.items(): - target = uuid4().hex if handler.target is None else handler.target + target = ( + new_state.patch_path + event + if handler.target is None + else handler.target + ) new_state.targets_by_event[event] = target self._event_handlers[target] = handler model_event_handlers[event] = { @@ -385,6 +483,8 @@ async def _render_model_children( key, child, self._schedule_render_task, + self.reconnecting, + self.client_state, ) elif old_child_state.is_component_state and ( old_child_state.life_cycle_state.component.type != child.type @@ -397,6 +497,8 @@ async def _render_model_children( key, child, self._schedule_render_task, + self.reconnecting, + self.client_state, ) else: new_child_state = _update_component_model_state( @@ -405,6 +507,8 @@ async def _render_model_children( index, child, self._schedule_render_task, + self.reconnecting, + self.client_state, ) await self._render_component( exit_stack, old_child_state, new_child_state, child @@ -439,7 +543,13 @@ async def _render_model_children_without_old_state( new_state.children_by_key[key] = child_state elif child_type is _COMPONENT_TYPE: child_state = _make_component_model_state( - new_state, index, key, child, self._schedule_render_task + new_state, + index, + key, + child, + self._schedule_render_task, + self.reconnecting, + self.client_state, ) await self._render_component(exit_stack, None, child_state, child) else: @@ -455,14 +565,19 @@ async def _unmount_model_states(self, old_states: list[_ModelState]) -> None: if model_state.is_component_state: life_cycle_state = model_state.life_cycle_state - del self._model_states_by_life_cycle_state_id[life_cycle_state.id] - await life_cycle_state.hook.affect_component_will_unmount() + try: + del self._model_states_by_life_cycle_state_id[life_cycle_state.id] + await life_cycle_state.hook.affect_component_will_unmount() + except KeyError: + pass # sideeffect of reusing model states to_unmount.extend(model_state.children_by_key.values()) - def _schedule_render_task(self, lcs_id: _LifeCycleStateId) -> None: + def _schedule_render_task( + self, lcs_id: _LifeCycleStateId, priority: int = 0 + ) -> None: if not REACTPY_ASYNC_RENDERING.current: - self._rendering_queue.put(lcs_id) + self._rendering_queue.put(lcs_id, priority) return None try: model_state = self._model_states_by_life_cycle_state_id[lcs_id] @@ -480,7 +595,11 @@ def __repr__(self) -> str: def _new_root_model_state( - component: ComponentType, schedule_render: Callable[[_LifeCycleStateId], None] + component: ComponentType, + schedule_render: Callable[[_LifeCycleStateId], None], + reconnecting: bool, + client_state: dict[str, Any], + previous_states: dict[str, Any], ) -> _ModelState: return _ModelState( parent=None, @@ -490,7 +609,9 @@ def _new_root_model_state( patch_path="", children_by_key={}, targets_by_event={}, - life_cycle_state=_make_life_cycle_state(component, schedule_render), + life_cycle_state=_make_life_cycle_state( + component, schedule_render, reconnecting, client_state, {}, previous_states + ), ) @@ -499,8 +620,11 @@ def _make_component_model_state( index: int, key: Any, component: ComponentType, - schedule_render: Callable[[_LifeCycleStateId], None], + schedule_render: Callable[[_LifeCycleStateId, int], None], + reconnecting: bool, + client_state: dict[str, Any], ) -> _ModelState: + hook = (parent.life_cycle_state or parent.parent_life_cycle_state).hook return _ModelState( parent=parent, index=index, @@ -509,7 +633,14 @@ def _make_component_model_state( patch_path=f"{parent.patch_path}/children/{index}", children_by_key={}, targets_by_event={}, - life_cycle_state=_make_life_cycle_state(component, schedule_render), + life_cycle_state=_make_life_cycle_state( + component, + schedule_render, + reconnecting, + client_state, + hook._updated_states, + hook._previous_states, + ), ) @@ -537,8 +668,11 @@ def _update_component_model_state( new_parent: _ModelState, new_index: int, new_component: ComponentType, - schedule_render: Callable[[_LifeCycleStateId], None], + schedule_render: Callable[[_LifeCycleStateId, int], None], + reconnecting: bool, + client_state: dict[str, Any], ) -> _ModelState: + hook = (new_parent.life_cycle_state or new_parent.parent_life_cycle_state).hook return _ModelState( parent=new_parent, index=new_index, @@ -550,7 +684,14 @@ def _update_component_model_state( life_cycle_state=( _update_life_cycle_state(old_model_state.life_cycle_state, new_component) if old_model_state.is_component_state - else _make_life_cycle_state(new_component, schedule_render) + else _make_life_cycle_state( + new_component, + schedule_render, + reconnecting, + client_state, + hook._updated_states, + hook._previous_states, + ) ), ) @@ -568,6 +709,9 @@ def _make_element_model_state( patch_path=f"{parent.patch_path}/children/{index}", children_by_key={}, targets_by_event={}, + parent_life_cycle_state=( + parent.life_cycle_state or parent.parent_life_cycle_state + ), ) @@ -584,6 +728,9 @@ def _update_element_model_state( patch_path=old_model_state.patch_path, children_by_key={}, targets_by_event={}, + parent_life_cycle_state=( + new_parent.life_cycle_state or new_parent.parent_life_cycle_state + ), ) @@ -598,6 +745,7 @@ class _ModelState: "index", "key", "life_cycle_state", + "parent_life_cycle_state", "model", "patch_path", "targets_by_event", @@ -613,6 +761,7 @@ def __init__( children_by_key: dict[Key, _ModelState], targets_by_event: dict[str, str], life_cycle_state: _LifeCycleState | None = None, + parent_life_cycle_state: _LifeCycleState | None = None, ): self.index = index """The index of the element amongst its siblings""" @@ -639,13 +788,14 @@ def __init__( self._parent_ref = weakref(parent) """The parent model state""" - if life_cycle_state is not None: - self.life_cycle_state = life_cycle_state - """The state for the element's component (if it has one)""" + self.life_cycle_state = life_cycle_state + """The state for the element's component (if it has one)""" + + self.parent_life_cycle_state = parent_life_cycle_state @property def is_component_state(self) -> bool: - return hasattr(self, "life_cycle_state") + return self.life_cycle_state is not None @property def parent(self) -> _ModelState: @@ -663,12 +813,22 @@ def __repr__(self) -> str: # nocov def _make_life_cycle_state( component: ComponentType, - schedule_render: Callable[[_LifeCycleStateId], None], + schedule_render: Callable[[_LifeCycleStateId, int], None], + reconnecting: bool, + client_state: dict[str, Any], + updated_states: dict[str, Any], + previous_states: dict[str, Any], ) -> _LifeCycleState: life_cycle_state_id = _LifeCycleStateId(uuid4().hex) return _LifeCycleState( life_cycle_state_id, - LifeCycleHook(lambda: schedule_render(life_cycle_state_id)), + LifeCycleHook( + lambda: schedule_render(life_cycle_state_id, component.priority), + reconnecting=reconnecting, + client_state=client_state, + updated_states=updated_states, + previous_states=previous_states, + ), component, ) @@ -707,16 +867,21 @@ class _LifeCycleState(NamedTuple): class _ThreadSafeQueue(Generic[_Type]): def __init__(self) -> None: self._loop = get_running_loop() - self._queue: Queue[_Type] = Queue() + self._queue: PriorityQueue[_Type] = PriorityQueue() self._pending: set[_Type] = set() - def put(self, value: _Type) -> None: + def put(self, value: _Type, priority: int = 0) -> None: if value not in self._pending: self._pending.add(value) - self._loop.call_soon_threadsafe(self._queue.put_nowait, value) + self._loop.call_soon_threadsafe(self._queue.put_nowait, (priority, value)) async def get(self) -> _Type: - value = await self._queue.get() + priority, value = await self._queue.get() + self._pending.remove(value) + return value + + async def get_nowait(self) -> _Type: + priority, value = self._queue.get_nowait() self._pending.remove(value) return value @@ -729,7 +894,7 @@ def _get_children_info(children: list[VdomChild]) -> Sequence[_ChildInfo]: elif isinstance(child, dict): child_type = _DICT_TYPE key = child.get("key") - elif isinstance(child, ComponentType): + elif isinstance(child, (Component, _ContextProvider)): child_type = _COMPONENT_TYPE key = child.key else: diff --git a/src/py/reactpy/reactpy/core/serve.py b/src/py/reactpy/reactpy/core/serve.py index 3a540af59..b7832b565 100644 --- a/src/py/reactpy/reactpy/core/serve.py +++ b/src/py/reactpy/reactpy/core/serve.py @@ -1,5 +1,7 @@ from __future__ import annotations +import random +import string from collections.abc import Awaitable from logging import getLogger from typing import Callable @@ -8,16 +10,39 @@ from anyio import create_task_group from anyio.abc import TaskGroup +from reactpy.backend.hooks import ConnectionContext +from reactpy.backend.types import Connection from reactpy.config import REACTPY_DEBUG_MODE -from reactpy.core.types import LayoutEventMessage, LayoutType, LayoutUpdateMessage +from reactpy.core._life_cycle_hook import clear_hook_state, create_hook_state +from reactpy.core.layout import Layout +from reactpy.core.state_recovery import StateRecoveryFailureError, StateRecoveryManager +from reactpy.core.types import ( + ClientStateMessage, + IsReadyMessage, + LayoutEventMessage, + LayoutType, + LayoutUpdateMessage, + ReconnectingCheckMessage, + RootComponentConstructor, +) logger = getLogger(__name__) -SendCoroutine = Callable[[LayoutUpdateMessage], Awaitable[None]] +SendCoroutine = Callable[ + [ + LayoutUpdateMessage + | ReconnectingCheckMessage + | IsReadyMessage + | ClientStateMessage + ], + Awaitable[None], +] """Send model patches given by a dispatcher""" -RecvCoroutine = Callable[[], Awaitable[LayoutEventMessage]] +RecvCoroutine = Callable[ + [], Awaitable[LayoutEventMessage | ReconnectingCheckMessage | ClientStateMessage] +] """Called by a dispatcher to return a :class:`reactpy.core.layout.LayoutEventMessage` The event will then trigger an :class:`reactpy.core.proto.EventHandlerType` in a layout. @@ -40,44 +65,146 @@ async def serve_layout( recv: RecvCoroutine, ) -> None: """Run a dispatch loop for a single view instance""" - async with layout: - try: - async with create_task_group() as task_group: - task_group.start_soon(_single_outgoing_loop, layout, send) - task_group.start_soon(_single_incoming_loop, task_group, layout, recv) - except Stop: # nocov - warn( - "The Stop exception is deprecated and will be removed in a future version", - UserWarning, - stacklevel=1, - ) - logger.info(f"Stopped serving {layout}") + + try: + async with create_task_group() as task_group: + task_group.start_soon(_single_outgoing_loop, layout, send) + task_group.start_soon(_single_incoming_loop, task_group, layout, recv, send) + except Stop: # nocov + warn( + "The Stop exception is deprecated and will be removed in a future version", + UserWarning, + stacklevel=1, + ) + logger.info(f"Stopped serving {layout}") async def _single_outgoing_loop( layout: LayoutType[LayoutUpdateMessage, LayoutEventMessage], send: SendCoroutine ) -> None: while True: - update = await layout.render() + token = create_hook_state() try: - await send(update) - except Exception: # nocov - if not REACTPY_DEBUG_MODE.current: - msg = ( - "Failed to send update. More info may be available " - "if you enabling debug mode by setting " - "`reactpy.config.REACTPY_DEBUG_MODE.current = True`." - ) - logger.error(msg) - raise + update = await layout.render() + try: + await send(update) + except Exception: # nocov + if not REACTPY_DEBUG_MODE.current: + msg = ( + "Failed to send update. More info may be available " + "if you enabling debug mode by setting " + "`reactpy.config.REACTPY_DEBUG_MODE.current = True`." + ) + logger.error(msg) + raise + finally: + clear_hook_state(token) async def _single_incoming_loop( task_group: TaskGroup, layout: LayoutType[LayoutUpdateMessage, LayoutEventMessage], recv: RecvCoroutine, + send: SendCoroutine, ) -> None: while True: # We need to fire and forget here so that we avoid waiting on the completion # of this event handler before receiving and running the next one. task_group.start_soon(layout.deliver, await recv()) + + +class WebsocketServer: + def __init__( + self, + send: SendCoroutine, + recv: RecvCoroutine, + state_recovery_manager: StateRecoveryManager | None = None, + ) -> None: + self._send = send + self._recv = recv + self._state_recovery_manager = state_recovery_manager + self._salt: str | None = None + + async def handle_connection( + self, connection: Connection, constructor: RootComponentConstructor + ): + layout = Layout( + ConnectionContext( + constructor(), + value=connection, + ), + ) + async with layout: + await self._handshake(layout) + # salt may be set to client's old salt during handshake + if self._state_recovery_manager: + layout.set_recovery_serializer( + self._state_recovery_manager.create_serializer(self._salt) + ) + await serve_layout( + layout, + self._send, + self._recv, + ) + + async def _handshake(self, layout: Layout) -> None: + await self._send(ReconnectingCheckMessage(type="reconnecting-check")) + result = await self._recv() + self._salt = "".join(random.choices(string.ascii_letters + string.digits, k=8)) + if result["type"] == "reconnecting-check": + if result["value"] == "yes": + if self._state_recovery_manager is None: + logger.warning( + "Reconnection detected, but no state recovery manager provided" + ) + layout.start_rendering() + else: + logger.info("Handshake: Doing state rebuild for reconnection") + self._salt = await self._do_state_rebuild_for_reconnection(layout) + logger.info("Handshake: Completed doing state rebuild") + else: + logger.info("Handshake: new connection") + layout.start_rendering() + else: + logger.warning( + f"Unexpected type when expecting reconnecting-check: {result['type']}" + ) + await self._indicate_ready(), + + async def _indicate_ready(self) -> None: + await self._send(IsReadyMessage(type="is-ready", salt=self._salt)) + + async def _do_state_rebuild_for_reconnection(self, layout: Layout) -> str: + salt = self._salt + await self._send(ClientStateMessage(type="client-state")) + client_state_msg = await self._recv() + if client_state_msg["type"] != "client-state": + logger.warning( + f"Unexpected type when expecting client-state: {client_state_msg['type']}" + ) + return + state_vars = client_state_msg["value"] + try: + serializer = self._state_recovery_manager.create_serializer( + client_state_msg["salt"] + ) + client_state = serializer.deserialize_client_state(state_vars) + layout.reconnecting.set_current(True) + layout.client_state = client_state + except StateRecoveryFailureError: + logger.exception("State recovery failed") + layout.reconnecting.set_current(False) + layout.client_state = {} + else: + salt = client_state_msg["salt"] + try: + layout.start_rendering_for_reconnect() + await layout.render_until_queue_empty() + except StateRecoveryFailureError: + logger.warning("Client state non-recoverable. Starting fresh") + await layout.finish() + await layout.start() + layout.start_rendering() + layout.reconnecting.set_current(False) + layout.client_state = {} + return salt diff --git a/src/py/reactpy/reactpy/core/state_recovery.py b/src/py/reactpy/reactpy/core/state_recovery.py new file mode 100644 index 000000000..38be33786 --- /dev/null +++ b/src/py/reactpy/reactpy/core/state_recovery.py @@ -0,0 +1,276 @@ +import asyncio +import base64 +import datetime +import hashlib +import time +from collections.abc import Iterable +from decimal import Decimal +from logging import getLogger +from pathlib import Path +from typing import Any, Callable +from uuid import UUID + +import orjson +import pyotp +from more_itertools import chunked + +logger = getLogger(__name__) + + +class StateRecoveryFailureError(Exception): + """ + Raised when state recovery fails. + """ + + +class StateRecoveryManager: + def __init__( + self, + serializable_types: Iterable[type], + pepper: str, + otp_key: str | None = None, + otp_interval: int = (4 * 60 * 60), + otp_digits: int = 10, # 10 is the max allowed + otp_max_age: int = (48 * 60 * 60), + # OTP code is actually three codes, in the past and future concatenated + otp_mixer: float = (365 * 24 * 60 * 60 * 3), + max_num_state_objects: int = 512, + max_object_length: int = 40000, + default_serializer: Callable[[Any], bytes] | None = None, + deserializer_map: dict[type, Callable[[Any], Any]] | None = None, + ) -> None: + self._pepper = pepper + self._max_num_state_objects = max_num_state_objects + self._max_object_length = max_object_length + self._otp_key = base64.b32encode( + (otp_key or self._discover_otp_key()).encode("utf-8") + ) + self._totp = pyotp.TOTP(self._otp_key, digits=otp_digits, interval=otp_interval) + self._otp_max_age = otp_max_age + self._default_serializer = default_serializer + self._deserializer_map = deserializer_map or {} + self._otp_mixer = otp_mixer + + self._map_objects_to_ids( + [ + *list(serializable_types), + Decimal, + datetime.datetime, + datetime.date, + datetime.time, + ] + ) + + def _map_objects_to_ids(self, serializable_types: Iterable[type]) -> dict: + self._object_to_type_id = {} + self._type_id_to_object = {} + for idx, typ in enumerate( + (None, bool, str, int, float, list, tuple, UUID, *serializable_types) + ): + idx_as_bytes = str(idx).encode("utf-8") + self._object_to_type_id[typ] = idx_as_bytes + self._type_id_to_object[idx_as_bytes] = typ + + def _discover_otp_key(self) -> str: + """ + Generate an OTP key by looking at the parent directory of where + ReactPy is installed and taking down the names and creation times + of everything in there. + """ + hasher = hashlib.sha256() + parent_dir_of_root = Path(__file__).parent.parent.parent + for thing in parent_dir_of_root.iterdir(): + hasher.update((thing.name + str(thing.stat().st_ctime)).encode("utf-8")) + return hasher.hexdigest() + + def create_serializer( + self, salt: str, target_time: float | None = None + ) -> "StateRecoverySerializer": + return StateRecoverySerializer( + totp=self._totp, + target_time=target_time, + otp_max_age=self._otp_max_age, + otp_mixer=self._otp_mixer, + pepper=self._pepper, + salt=salt, + object_to_type_id=self._object_to_type_id, + type_id_to_object=self._type_id_to_object, + max_object_length=self._max_object_length, + max_num_state_objects=self._max_num_state_objects, + default_serializer=self._default_serializer, + deserializer_map=self._deserializer_map, + ) + + +class StateRecoverySerializer: + + def __init__( + self, + totp: pyotp.TOTP, + target_time: float | None, + otp_max_age: int, + otp_mixer: float, + pepper: str, + salt: str, + object_to_type_id: dict[Any, bytes], + type_id_to_object: dict[bytes, Any], + max_object_length: int, + max_num_state_objects: int, + default_serializer: Callable[[Any], bytes] | None = None, + deserializer_map: dict[type, Callable[[Any], Any]] | None = None, + ) -> None: + self._totp = totp + self._otp_mixer = otp_mixer + target_time = target_time or time.time() + self._target_time = target_time + otp_code = self._get_otp_code(target_time) + self._otp_max_age = otp_max_age + self._otp_code = otp_code.encode("utf-8") + self._pepper = pepper.encode("utf-8") + self._salt = salt.encode("utf-8") + self._object_to_type_id = object_to_type_id + self._type_id_to_object = type_id_to_object + self._max_object_length = max_object_length + self._max_num_state_objects = max_num_state_objects + self._default_serializer = default_serializer + self._deserializer_map = deserializer_map or {} + + def _get_otp_code(self, target_time: float) -> str: + at = self._totp.at + return f"{at(target_time)}{at(target_time - self._otp_mixer)}{at(target_time + self._otp_mixer)}" + + async def serialize_state_vars( + self, state_vars: dict[str, Any] + ) -> dict[str, tuple[str, str, str]]: + if len(state_vars) > self._max_num_state_objects: + logger.warning( + f"State is too large ({len(state_vars)}). State will not be sent" + ) + return {} + result = {} + for chunk in chunked(state_vars.items(), 50): + for key, value in chunk: + result[key] = self._serialize(key, value) + await asyncio.sleep(0) # relinquish CPU + return result + + def _serialize(self, key: str, obj: object) -> tuple[str, str, str]: + type_id = b"1" # bool + if obj is None: + return "0", "", "" + match obj: + case True: + result = b"true" + case False: + result = b"false" + case _: + obj_type = type(obj) + if obj_type in (list, tuple): + if len(obj) != 0: + obj_type = type(obj[0]) + for t in obj_type.__mro__: + type_id = self._object_to_type_id.get(t) + if type_id: + break + else: + raise ValueError( + f"Objects of type {obj_type} was not part of serializable_types" + ) + result = self._serialize_object(obj) + if len(result) > self._max_object_length: + raise ValueError( + f"Serialized object {obj} is too long (length: {len(result)})" + ) + signature = self._sign_serialization(key, type_id, result) + return ( + type_id.decode("utf-8"), + base64.urlsafe_b64encode(result).decode("utf-8"), + signature, + ) + + def deserialize_client_state( + self, state_vars: dict[str, tuple[str, str, str]] + ) -> None: + return { + key: self._deserialize(key, type_id.encode("utf-8"), data, signature) + for key, (type_id, data, signature) in state_vars.items() + } + + def _deserialize( + self, key: str, type_id: bytes, data: bytes, signature: str + ) -> Any: + if type_id == b"0": + return None + try: + typ = self._type_id_to_object[type_id] + except KeyError as err: + raise StateRecoveryFailureError(f"Unknown type id {type_id}") from err + + result = base64.urlsafe_b64decode(data) + expected_signature = self._sign_serialization(key, type_id, result) + if expected_signature != signature: + if not self._try_future_code(key, type_id, result, signature): + if not self._try_older_codes_and_see_if_one_checks_out( + key, type_id, result, signature + ): + raise StateRecoveryFailureError( + f"Signature mismatch for type id {type_id}" + ) + return self._deserialize_object(typ, result) + + def _try_future_code( + self, key: str, type_id: bytes, data: bytes, signature: str + ) -> bool: + future_time = self._target_time + self._totp.interval + otp_code = self._get_otp_code(future_time).encode("utf-8") + return self._sign_serialization(key, type_id, data, otp_code) == signature + + def _try_older_codes_and_see_if_one_checks_out( + self, key: str, type_id: bytes, data: bytes, signature: str + ) -> bool: + past_time = self._target_time + for _ in range(100): + past_time -= self._totp.interval + otp_code = self._get_otp_code(past_time).encode("utf-8") + if self._sign_serialization(key, type_id, data, otp_code) == signature: + return True + if past_time < self._target_time - self._otp_max_age: + return False + raise RuntimeError("Too many iterations: _try_older_codes_and_see_if_one_checks_out") + + def _sign_serialization( + self, key: str, type_id: bytes, data: bytes, otp_code: bytes | None = None + ) -> str: + hasher = hashlib.sha256() + hasher.update(type_id) + hasher.update(data) + hasher.update(self._pepper) + hasher.update(otp_code or self._otp_code) + hasher.update(self._salt) + hasher.update(key.encode("utf-8")) + return hasher.hexdigest() + + def _serialize_object(self, obj: Any) -> bytes: + return orjson.dumps(obj, default=self._default_serializer) + + def _do_deserialize( + self, typ: type, result: Any, custom_deserializer: Callable | None + ) -> Any: + if custom_deserializer: + return custom_deserializer(result) + if isinstance(result, str): + return typ(result) + if isinstance(result, dict): + return typ(**result) + return result + + def _deserialize_object(self, typ: Any, data: bytes) -> Any: + if typ is None and not data: + return None + result = orjson.loads(data) + custom_deserializer = self._deserializer_map.get(typ) + if type(result) in (list, tuple): + return [ + self._do_deserialize(typ, item, custom_deserializer) for item in result + ] + return self._do_deserialize(typ, result, custom_deserializer) diff --git a/src/py/reactpy/reactpy/core/types.py b/src/py/reactpy/reactpy/core/types.py index b451be30a..a4be74f61 100644 --- a/src/py/reactpy/reactpy/core/types.py +++ b/src/py/reactpy/reactpy/core/types.py @@ -57,6 +57,7 @@ class ComponentType(Protocol): This is used to see if two component instances share the same definition. """ + priority: int def render(self) -> VdomDict | ComponentType | str | None: """Render the component's view model.""" @@ -213,6 +214,42 @@ class LayoutUpdateMessage(TypedDict): """JSON Pointer path to the model element being updated""" model: VdomJson """The model to assign at the given JSON Pointer path""" + state_vars: dict[str, Any] + + +class StateUpdateMessage(TypedDict): + """A message describing an update to state variables""" + + type: Literal["state-update"] + """The type of message""" + state_vars: dict[str, Any] + + +class ReconnectingCheckMessage(TypedDict): + """A message describing an update to a layout""" + + type: Literal["reconnecting-check"] + """The type of message""" + value: Literal["yes", "no"] + + +class ClientStateMessage(TypedDict): + """A message requesting the current state of the client""" + + type: Literal["client-state"] + """The type of message""" + value: dict[str, Any] + """The client state""" + salt: str + """The salt provided to the user""" + + +class IsReadyMessage(TypedDict): + """Indicate server is ready for client events""" + + type: Literal["is-ready"] + + salt: str class LayoutEventMessage(TypedDict): diff --git a/src/py/reactpy/reactpy/testing/common.py b/src/py/reactpy/reactpy/testing/common.py index c1eb18ba5..84f3243ae 100644 --- a/src/py/reactpy/reactpy/testing/common.py +++ b/src/py/reactpy/reactpy/testing/common.py @@ -13,7 +13,7 @@ from typing_extensions import ParamSpec from reactpy.config import REACTPY_TESTING_DEFAULT_TIMEOUT, REACTPY_WEB_MODULES_DIR -from reactpy.core._life_cycle_hook import LifeCycleHook, current_hook +from reactpy.core._life_cycle_hook import LifeCycleHook, get_current_hook from reactpy.core.events import EventHandler, to_event_handler_function @@ -143,7 +143,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: if self is None: raise RuntimeError("Hook catcher has been garbage collected") - hook = current_hook() + hook = get_current_hook() if self.index_by_kwarg is not None: self.index[kwargs[self.index_by_kwarg]] = hook self.latest = hook diff --git a/src/py/reactpy/reactpy/utils.py b/src/py/reactpy/reactpy/utils.py index 5624846a4..06599c445 100644 --- a/src/py/reactpy/reactpy/utils.py +++ b/src/py/reactpy/reactpy/utils.py @@ -27,12 +27,26 @@ class Ref(Generic[_RefValue]): You can compare the contents for two ``Ref`` objects using the ``==`` operator. """ - __slots__ = ("current",) + __slots__ = ("current", "key", "_hook") + + def __init__( + self, initial_value: _RefValue = _UNDEFINED, key: str | None = None + ) -> None: + from reactpy.core._life_cycle_hook import get_current_hook - def __init__(self, initial_value: _RefValue = _UNDEFINED) -> None: if initial_value is not _UNDEFINED: self.current = initial_value """The present value""" + self.key = key + self._hook = None + if key: + hook = get_current_hook() + hook.add_state_update(self) + self._hook = hook + + @property + def value(self) -> _RefValue: + return self.current def set_current(self, new: _RefValue) -> _RefValue: """Set the current value and return what is now the old value @@ -41,6 +55,8 @@ def set_current(self, new: _RefValue) -> _RefValue: """ old = self.current self.current = new + if self.key: + self._hook.add_state_update(self) return old def __eq__(self, other: Any) -> bool: diff --git a/src/py/reactpy/tests/test_core/test_layout.py b/src/py/reactpy/tests/test_core/test_layout.py index cfb544758..276e36e2d 100644 --- a/src/py/reactpy/tests/test_core/test_layout.py +++ b/src/py/reactpy/tests/test_core/test_layout.py @@ -343,7 +343,7 @@ async def test_root_component_life_cycle_hook_is_garbage_collected(): def add_to_live_hooks(constructor): def wrapper(*args, **kwargs): result = constructor(*args, **kwargs) - hook = reactpy.hooks.current_hook() + hook = reactpy.hooks.get_current_hook() hook_id = id(hook) live_hooks.add(hook_id) finalize(hook, live_hooks.discard, hook_id) @@ -375,7 +375,7 @@ async def test_life_cycle_hooks_are_garbage_collected(): def add_to_live_hooks(constructor): def wrapper(*args, **kwargs): result = constructor(*args, **kwargs) - hook = reactpy.hooks.current_hook() + hook = reactpy.hooks.get_current_hook() hook_id = id(hook) live_hooks.add(hook_id) finalize(hook, live_hooks.discard, hook_id) @@ -625,7 +625,7 @@ def Outer(): @reactpy.component def Inner(finalizer_id): if finalizer_id not in registered_finalizers: - hook = reactpy.hooks.current_hook() + hook = reactpy.hooks.get_current_hook() finalize(hook, lambda: garbage_collect_items.append(finalizer_id)) registered_finalizers.add(finalizer_id) return reactpy.html.div(finalizer_id) diff --git a/src/py/reactpy/tests/tooling/hooks.py b/src/py/reactpy/tests/tooling/hooks.py index 1926a93bc..b60040495 100644 --- a/src/py/reactpy/tests/tooling/hooks.py +++ b/src/py/reactpy/tests/tooling/hooks.py @@ -1,8 +1,8 @@ -from reactpy.core.hooks import current_hook, use_state +from reactpy.core.hooks import get_current_hook, use_state def use_force_render(): - return current_hook().schedule_render + return get_current_hook().schedule_render def use_toggle(init=False): From a7374d0fd440bf948f2a5c93d18f4f140ffa09c9 Mon Sep 17 00:00:00 2001 From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com> Date: Mon, 4 Mar 2024 20:17:50 -0800 Subject: [PATCH 03/11] Follow up work (#3) * Delete commented out code * add serialization for timezone and timedelta * don't show reconnecting layer on first attempt * Only connect after layout update handlers are set up * perf tweak that apparently was never saved * deserialization as well for timezone * timezone and timedelta * alter z-index --- .../@reactpy/client/src/components.tsx | 3 +- .../@reactpy/client/src/reactpy-client.ts | 30 ++++++++++++++++--- src/py/reactpy/reactpy/backend/sanic.py | 10 ------- src/py/reactpy/reactpy/core/state_recovery.py | 25 ++++++++++++++-- src/py/reactpy/reactpy/core/vdom.py | 2 +- 5 files changed, 50 insertions(+), 20 deletions(-) diff --git a/src/js/packages/@reactpy/client/src/components.tsx b/src/js/packages/@reactpy/client/src/components.tsx index fd23d3a8a..0f1c1722d 100644 --- a/src/js/packages/@reactpy/client/src/components.tsx +++ b/src/js/packages/@reactpy/client/src/components.tsx @@ -29,13 +29,12 @@ export function Layout(props: { client: ReactPyClient }): JSX.Element { useEffect( () => - props.client.onMessage("layout-update", ({ path, model, state_vars }) => { + props.client.onLayoutUpdate((path: string, model: any) => { if (path === "") { Object.assign(currentModel, model); } else { setJsonPointer(currentModel, path, model); } - props.client.updateStateVars(state_vars); forceUpdate(); }), [currentModel, props.client], diff --git a/src/js/packages/@reactpy/client/src/reactpy-client.ts b/src/js/packages/@reactpy/client/src/reactpy-client.ts index c5018e9a5..c69db96bc 100644 --- a/src/js/packages/@reactpy/client/src/reactpy-client.ts +++ b/src/js/packages/@reactpy/client/src/reactpy-client.ts @@ -16,6 +16,8 @@ export interface ReactPyClient { */ onMessage(type: string, handler: (message: any) => void): () => void; + onLayoutUpdate(handler: (path: string, model: any) => void): void; + /** * Send a message to the server. * @@ -43,6 +45,7 @@ export abstract class BaseReactPyClient implements ReactPyClient { private resolveReady: (value: undefined) => void; protected stateVars: object; protected debugMessages: boolean; + protected layoutUpdateHandlers: Array<(path: string, model: any) => void> = []; constructor() { this.resolveReady = () => { }; @@ -59,6 +62,10 @@ export abstract class BaseReactPyClient implements ReactPyClient { }; } + onLayoutUpdate(handler: (path: string, model: any) => void): void { + this.layoutUpdateHandlers.push(handler); + } + abstract sendMessage(message: any): void; abstract loadModule(moduleName: string): Promise; @@ -146,7 +153,8 @@ enum messageTypes { isReady = "is-ready", reconnectingCheck = "reconnecting-check", clientState = "client-state", - stateUpdate = "state-update" + stateUpdate = "state-update", + layoutUpdate = "layout-update", }; export class SimpleReactPyClient @@ -198,6 +206,10 @@ export class SimpleReactPyClient this.onMessage(messageTypes.isReady, (msg) => { this.isReady = true; this.salt = msg.salt; }); this.onMessage(messageTypes.clientState, () => { this.sendClientState() }); this.onMessage(messageTypes.stateUpdate, (msg) => { this.updateClientState(msg.state_vars) }); + this.onMessage(messageTypes.layoutUpdate, (msg) => { + this.updateClientState(msg.state_vars); + this.invokeLayoutUpdateHandlers(msg.path, msg.model); + }) this.reconnect() @@ -211,7 +223,13 @@ export class SimpleReactPyClient window.addEventListener('scroll', reconnectOnUserAction); } - showReconnectingGrayout() { + protected invokeLayoutUpdateHandlers(path: string, model: any) { + this.layoutUpdateHandlers.forEach(func => { + func(path, model); + }); + } + + protected showReconnectingGrayout() { const overlay = document.createElement('div'); overlay.id = 'reactpy-reconnect-overlay'; @@ -229,7 +247,7 @@ export class SimpleReactPyClient display: flex; justify-content: center; align-items: center; - z-index: 1000; + z-index: 100000; `; pipeContainer.style.cssText = ` @@ -328,6 +346,10 @@ export class SimpleReactPyClient const maxInterval = this.reconnectOptions?.maxInterval || 20000; const maxRetries = this.reconnectOptions?.maxRetries || 20; + if (this.layoutUpdateHandlers.length == 0) { + setTimeout(() => { this.reconnect(onOpen, interval, connectionAttemptsRemaining, lastAttempt); }, 10); + return + } if (connectionAttemptsRemaining <= 0) { logger.warn("Giving up on reconnecting (hit retry limit)"); @@ -344,7 +366,7 @@ export class SimpleReactPyClient this.shouldReconnect = true; window.setTimeout(() => { - if (!this.didReconnectingCallback && this.reconnectingCallback) { + if (!this.didReconnectingCallback && this.reconnectingCallback && maxRetries != connectionAttemptsRemaining) { this.didReconnectingCallback = true; this.reconnectingCallback(); } diff --git a/src/py/reactpy/reactpy/backend/sanic.py b/src/py/reactpy/reactpy/backend/sanic.py index bad90b072..e648747fa 100644 --- a/src/py/reactpy/reactpy/backend/sanic.py +++ b/src/py/reactpy/reactpy/backend/sanic.py @@ -193,16 +193,6 @@ async def model_stream( ), constructor, ) - # await serve_layout( - # Layout( - # ConnectionContext( - # constructor(), - # value=, - # ) - # ), - # send, - # recv, - # ) api_blueprint.add_websocket_route( model_stream, diff --git a/src/py/reactpy/reactpy/core/state_recovery.py b/src/py/reactpy/reactpy/core/state_recovery.py index 38be33786..6979de7ff 100644 --- a/src/py/reactpy/reactpy/core/state_recovery.py +++ b/src/py/reactpy/reactpy/core/state_recovery.py @@ -1,5 +1,6 @@ import asyncio import base64 +from dataclasses import asdict, is_dataclass import datetime import hashlib import time @@ -58,6 +59,8 @@ def __init__( datetime.datetime, datetime.date, datetime.time, + datetime.timezone, + datetime.timedelta, ] ) @@ -65,7 +68,7 @@ def _map_objects_to_ids(self, serializable_types: Iterable[type]) -> dict: self._object_to_type_id = {} self._type_id_to_object = {} for idx, typ in enumerate( - (None, bool, str, int, float, list, tuple, UUID, *serializable_types) + (None, bool, str, int, float, list, tuple, UUID, datetime.timezone, datetime.timedelta, *serializable_types) ): idx_as_bytes = str(idx).encode("utf-8") self._object_to_type_id[typ] = idx_as_bytes @@ -132,8 +135,13 @@ def __init__( self._type_id_to_object = type_id_to_object self._max_object_length = max_object_length self._max_num_state_objects = max_num_state_objects - self._default_serializer = default_serializer - self._deserializer_map = deserializer_map or {} + self._provided_default_serializer = default_serializer + deserialization_map = { + datetime.timezone: lambda x: datetime.timezone( + datetime.timedelta(**x["offset"]), x["name"] + ), + } + self._deserializer_map = deserialization_map | (deserializer_map or {}) def _get_otp_code(self, target_time: float) -> str: at = self._totp.at @@ -253,6 +261,17 @@ def _sign_serialization( def _serialize_object(self, obj: Any) -> bytes: return orjson.dumps(obj, default=self._default_serializer) + def _default_serializer(self, obj: Any) -> bytes: + if isinstance(obj, datetime.timezone): + return {"name": obj.tzname(None), "offset": obj.utcoffset(None)} + if isinstance(obj, datetime.timedelta): + return {"days": obj.days, "seconds": obj.seconds, "microseconds": obj.microseconds} + if is_dataclass(obj): + return asdict(obj) + if self._provided_default_serializer: + return self._provided_default_serializer(obj) + raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") + def _do_deserialize( self, typ: type, result: Any, custom_deserializer: Callable | None ) -> Any: diff --git a/src/py/reactpy/reactpy/core/vdom.py b/src/py/reactpy/reactpy/core/vdom.py index e494b5269..2bb3120dd 100644 --- a/src/py/reactpy/reactpy/core/vdom.py +++ b/src/py/reactpy/reactpy/core/vdom.py @@ -328,7 +328,7 @@ def _validate_child_key_integrity(value: Any) -> None: ) else: for child in value: - if isinstance(child, ComponentType) and child.key is None: + if child.key is None and isinstance(child, ComponentType): warn(f"Key not specified for child in list {child}", UserWarning) elif isinstance(child, Mapping) and "key" not in child: # remove 'children' to reduce log spam From ef8537f451362809a80548fc96ef38ce50635923 Mon Sep 17 00:00:00 2001 From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com> Date: Tue, 5 Mar 2024 12:29:49 -0800 Subject: [PATCH 04/11] add location to serialization (#4) --- src/py/reactpy/reactpy/core/state_recovery.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/py/reactpy/reactpy/core/state_recovery.py b/src/py/reactpy/reactpy/core/state_recovery.py index 6979de7ff..59c8c89da 100644 --- a/src/py/reactpy/reactpy/core/state_recovery.py +++ b/src/py/reactpy/reactpy/core/state_recovery.py @@ -14,6 +14,7 @@ import orjson import pyotp from more_itertools import chunked +from reactpy.backend.types import Location logger = getLogger(__name__) @@ -61,6 +62,7 @@ def __init__( datetime.time, datetime.timezone, datetime.timedelta, + Location, ] ) From e23e43d6da24a3568f157853fd3ce81f96041961 Mon Sep 17 00:00:00 2001 From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com> Date: Wed, 6 Mar 2024 11:01:29 -0800 Subject: [PATCH 05/11] Fix client not reloading on hash mismatch (#5) --- src/py/reactpy/reactpy/core/serve.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/py/reactpy/reactpy/core/serve.py b/src/py/reactpy/reactpy/core/serve.py index b7832b565..eef5c5c96 100644 --- a/src/py/reactpy/reactpy/core/serve.py +++ b/src/py/reactpy/reactpy/core/serve.py @@ -191,20 +191,20 @@ async def _do_state_rebuild_for_reconnection(self, layout: Layout) -> str: client_state = serializer.deserialize_client_state(state_vars) layout.reconnecting.set_current(True) layout.client_state = client_state - except StateRecoveryFailureError: - logger.exception("State recovery failed") - layout.reconnecting.set_current(False) - layout.client_state = {} - else: + salt = client_state_msg["salt"] - try: layout.start_rendering_for_reconnect() await layout.render_until_queue_empty() except StateRecoveryFailureError: - logger.warning("Client state non-recoverable. Starting fresh") + logger.warning( + "State recovery failed (likely client from different version). Starting fresh" + ) await layout.finish() + layout.reconnecting.set_current(False) + layout.client_state = {} await layout.start() layout.start_rendering() + return salt layout.reconnecting.set_current(False) layout.client_state = {} return salt From d36a982bb144f4b03766c0fa14a5b8947acc2805 Mon Sep 17 00:00:00 2001 From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com> Date: Wed, 6 Mar 2024 15:35:29 -0800 Subject: [PATCH 06/11] Fix user mouse / scroll activity not actually triggering a reconnection (#6) * Fix for client not reconnecting on activity * Type improvements * comment improvements --- .../@reactpy/client/src/reactpy-client.ts | 38 +++++++++++-------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/src/js/packages/@reactpy/client/src/reactpy-client.ts b/src/js/packages/@reactpy/client/src/reactpy-client.ts index c69db96bc..005a11473 100644 --- a/src/js/packages/@reactpy/client/src/reactpy-client.ts +++ b/src/js/packages/@reactpy/client/src/reactpy-client.ts @@ -145,8 +145,8 @@ type ReconnectProps = { maxRetries?: number; backoffRate?: number; intervalJitter?: number; - reconnectingCallback?: Function; - reconnectedCallback?: Function; + reconnectingCallback?: () => void; + reconnectedCallback?: () => void; }; enum messageTypes { @@ -163,7 +163,7 @@ export class SimpleReactPyClient private readonly urls: ServerUrls; private socket!: { current?: WebSocket }; private idleDisconnectTimeMillis: number; - private lastMessageTime: number; + private lastActivityTime: number; private reconnectOptions: ReconnectProps | undefined; private messageQueue: any[] = []; private socketLoopIntervalId?: number | null; @@ -174,9 +174,10 @@ export class SimpleReactPyClient private salt: string; private shouldReconnect: boolean; private connectionTimeout: number; - private reconnectingCallback: Function; - private reconnectedCallback: Function; + private reconnectingCallback: () => void; + private reconnectedCallback: () => void; private didReconnectingCallback: boolean; + private willReconnect: boolean; constructor(props: SimpleReactPyClientProps) { super(); @@ -190,11 +191,12 @@ export class SimpleReactPyClient ); this.idleDisconnectTimeMillis = (props.idleDisconnectTimeSeconds || 240) * 1000; this.connectionTimeout = props.connectionTimeout || 5000; - this.lastMessageTime = Date.now() + this.lastActivityTime = Date.now() this.reconnectOptions = props.reconnectOptions this.debugMessages = props.debugMessages || false; this.sleeping = false; this.isReconnecting = false; + this.willReconnect = false; this.isReady = false this.salt = ""; this.shouldReconnect = false; @@ -209,18 +211,21 @@ export class SimpleReactPyClient this.onMessage(messageTypes.layoutUpdate, (msg) => { this.updateClientState(msg.state_vars); this.invokeLayoutUpdateHandlers(msg.path, msg.model); + this.willReconnect = true; // don't indicate a reconnect until at least one successful layout update }) this.reconnect() - const reconnectOnUserAction = (ev: any) => { + const handleUserAction = (ev: any) => { + this.lastActivityTime = Date.now(); if (!this.isReady && !this.isReconnecting) { + this.sleeping = false; this.reconnect(); } } - window.addEventListener('mousemove', reconnectOnUserAction); - window.addEventListener('scroll', reconnectOnUserAction); + window.addEventListener('mousemove', handleUserAction); + window.addEventListener('scroll', handleUserAction); } protected invokeLayoutUpdateHandlers(path: string, model: any) { @@ -290,7 +295,7 @@ export class SimpleReactPyClient } indicateReconnect(): void { - const isReconnecting = this.isReconnecting ? "yes" : "no"; + const isReconnecting = this.willReconnect ? "yes" : "no"; this.sendMessage({ "type": messageTypes.reconnectingCheck, "value": isReconnecting }, true) } @@ -324,6 +329,7 @@ export class SimpleReactPyClient if (this.debugMessages) { logger.log("Sending message", message); } + this.lastActivityTime = Date.now(); this.socket.current.send(JSON.stringify(message)); } } @@ -331,10 +337,11 @@ export class SimpleReactPyClient idleTimeoutCheck(): void { if (!this.socket) return; - if (Date.now() - this.lastMessageTime > this.idleDisconnectTimeMillis) { + if (Date.now() - this.lastActivityTime > this.idleDisconnectTimeMillis) { if (this.socket.current && this.socket.current.readyState === WebSocket.OPEN) { logger.warn("Closing socket connection due to idle activity"); this.sleeping = true; + this.isReconnecting = false; this.socket.current.close(); } } @@ -388,14 +395,15 @@ export class SimpleReactPyClient onOpen(); }, onClose: () => { - // reset retry interval + // reset retry interval on successful connection if (Date.now() - lastAttempt > maxInterval * 2) { interval = 750; connectionAttemptsRemaining = maxRetries; + } else if (!this.sleeping) { + this.isReconnecting = true; } lastAttempt = Date.now() this.shouldReconnect = false; - this.isReconnecting = true; this.isReady = false; if (this.socketLoopIntervalId) clearInterval(this.socketLoopIntervalId); @@ -410,7 +418,7 @@ export class SimpleReactPyClient this.reconnect(onOpen, thisInterval, newRetriesRemaining, lastAttempt); } }, - onMessage: async ({ data }) => { this.lastMessageTime = Date.now(); this.handleIncoming(JSON.parse(data)) }, + onMessage: async ({ data }) => { this.lastActivityTime = Date.now(); this.handleIncoming(JSON.parse(data)) }, ...this.reconnectOptions, }); this.socketLoopIntervalId = window.setInterval(() => { this.socketLoop() }, 30); @@ -431,7 +439,7 @@ export class SimpleReactPyClient } else { this.messageQueue.push(message); } - this.lastMessageTime = Date.now() + this.lastActivityTime = Date.now() this.sleeping = false; this.ensureConnected(); } From 6118197685ee7a454d8043656cd855533c5c3c9d Mon Sep 17 00:00:00 2001 From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com> Date: Thu, 7 Mar 2024 16:57:37 -0800 Subject: [PATCH 07/11] fix state recovery error not propagating (#7) Fix Exception catch-all in render swallowing state recovery errors Fix serialization errors not resulting in a state recovery error --- src/py/reactpy/reactpy/core/layout.py | 7 ++++++- src/py/reactpy/reactpy/core/serve.py | 5 ++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/py/reactpy/reactpy/core/layout.py b/src/py/reactpy/reactpy/core/layout.py index 8db467e30..f2274c6a8 100644 --- a/src/py/reactpy/reactpy/core/layout.py +++ b/src/py/reactpy/reactpy/core/layout.py @@ -47,7 +47,10 @@ ) from reactpy.core.component import Component from reactpy.core.hooks import _ContextProvider -from reactpy.core.state_recovery import StateRecoverySerializer +from reactpy.core.state_recovery import ( + StateRecoveryFailureError, + StateRecoverySerializer, +) from reactpy.core.types import ( ComponentType, EventHandlerDict, @@ -284,6 +287,8 @@ async def _render_component( # components are given a node in the tree some other way wrapper_model: VdomDict = {"tagName": "", "children": [raw_model]} await self._render_model(exit_stack, old_state, new_state, wrapper_model) + except StateRecoveryFailureError: + raise except Exception as error: logger.exception(f"Failed to render {component}") new_state.model.current = { diff --git a/src/py/reactpy/reactpy/core/serve.py b/src/py/reactpy/reactpy/core/serve.py index eef5c5c96..85799c762 100644 --- a/src/py/reactpy/reactpy/core/serve.py +++ b/src/py/reactpy/reactpy/core/serve.py @@ -188,7 +188,10 @@ async def _do_state_rebuild_for_reconnection(self, layout: Layout) -> str: serializer = self._state_recovery_manager.create_serializer( client_state_msg["salt"] ) - client_state = serializer.deserialize_client_state(state_vars) + try: + client_state = serializer.deserialize_client_state(state_vars) + except Exception as err: + raise StateRecoveryFailureError() from err layout.reconnecting.set_current(True) layout.client_state = client_state From 09eb44cc679f0364a8f0819127ca0ff87bbbde62 Mon Sep 17 00:00:00 2001 From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com> Date: Fri, 8 Mar 2024 07:25:29 -0800 Subject: [PATCH 08/11] Move socket loop interval to a variable (#8) --- src/js/packages/@reactpy/client/src/reactpy-client.ts | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/js/packages/@reactpy/client/src/reactpy-client.ts b/src/js/packages/@reactpy/client/src/reactpy-client.ts index 005a11473..b840479a0 100644 --- a/src/js/packages/@reactpy/client/src/reactpy-client.ts +++ b/src/js/packages/@reactpy/client/src/reactpy-client.ts @@ -107,6 +107,7 @@ export type SimpleReactPyClientProps = { idleDisconnectTimeSeconds?: number; connectionTimeout?: number; debugMessages?: boolean; + socketLoopThrottle?: number; }; /** @@ -178,6 +179,7 @@ export class SimpleReactPyClient private reconnectedCallback: () => void; private didReconnectingCallback: boolean; private willReconnect: boolean; + private socketLoopThrottle: number; constructor(props: SimpleReactPyClientProps) { super(); @@ -194,6 +196,7 @@ export class SimpleReactPyClient this.lastActivityTime = Date.now() this.reconnectOptions = props.reconnectOptions this.debugMessages = props.debugMessages || false; + this.socketLoopThrottle = props.socketLoopThrottle || 5; this.sleeping = false; this.isReconnecting = false; this.willReconnect = false; @@ -421,7 +424,7 @@ export class SimpleReactPyClient onMessage: async ({ data }) => { this.lastActivityTime = Date.now(); this.handleIncoming(JSON.parse(data)) }, ...this.reconnectOptions, }); - this.socketLoopIntervalId = window.setInterval(() => { this.socketLoop() }, 30); + this.socketLoopIntervalId = window.setInterval(() => { this.socketLoop() }, this.socketLoopThrottle); this.idleCheckIntervalId = window.setInterval(() => { this.idleTimeoutCheck() }, 10000); }, interval) From 5f5562f8ac18a00229d5cd612de48dfbd41d136c Mon Sep 17 00:00:00 2001 From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com> Date: Tue, 12 Mar 2024 10:29:08 -0700 Subject: [PATCH 09/11] mypy typing updates (#9) --- src/py/reactpy/reactpy/core/component.py | 6 +++++- src/py/reactpy/reactpy/core/hooks.py | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/py/reactpy/reactpy/core/component.py b/src/py/reactpy/reactpy/core/component.py index 9d4955546..f11bada12 100644 --- a/src/py/reactpy/reactpy/core/component.py +++ b/src/py/reactpy/reactpy/core/component.py @@ -2,7 +2,7 @@ import inspect from functools import wraps -from typing import Any, Callable, ParamSpec, TypeVar +from typing import Any, Callable, ParamSpec, TypeVar, overload from reactpy.core.types import ComponentType, VdomDict @@ -10,6 +10,10 @@ P = ParamSpec("P") +@overload +def component(function: None = None, *, priority: int) -> Callable[P, Component]: ... + + def component( function: Callable[P, T] | None = None, *, diff --git a/src/py/reactpy/reactpy/core/hooks.py b/src/py/reactpy/reactpy/core/hooks.py index 72f7c2c4e..95fefaf94 100644 --- a/src/py/reactpy/reactpy/core/hooks.py +++ b/src/py/reactpy/reactpy/core/hooks.py @@ -54,11 +54,11 @@ class ReconnectingOnly(list): @overload -def use_state(initial_value: Callable[[], _Type]) -> State[_Type]: ... +def use_state(initial_value: Callable[[], _Type], *, server_only: bool = False) -> State[_Type]: ... @overload -def use_state(initial_value: _Type) -> State[_Type]: ... +def use_state(initial_value: _Type, *, server_only: bool = False) -> State[_Type]: ... def use_state( @@ -509,7 +509,7 @@ def empty(self) -> bool: return False -def use_ref(initial_value: _Type, server_only: bool = True) -> Ref[_Type]: +def use_ref(initial_value: _Type, *, server_only: bool = True) -> Ref[_Type]: """See the full :ref:`Use State` docs for details Parameters: From 6203da576411f864e3e173e343a363a5a46e23e6 Mon Sep 17 00:00:00 2001 From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com> Date: Tue, 12 Mar 2024 10:34:34 -0700 Subject: [PATCH 10/11] add overload (#10) --- src/py/reactpy/reactpy/core/component.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/py/reactpy/reactpy/core/component.py b/src/py/reactpy/reactpy/core/component.py index f11bada12..71206b94d 100644 --- a/src/py/reactpy/reactpy/core/component.py +++ b/src/py/reactpy/reactpy/core/component.py @@ -14,6 +14,10 @@ def component(function: None = None, *, priority: int) -> Callable[P, Component]: ... +@overload +def component(function: Callable[P, T] | None) -> Callable[P, Component]: ... + + def component( function: Callable[P, T] | None = None, *, From 9704cbccc189bd9cdfac88bbc84d8c5f9bd46af3 Mon Sep 17 00:00:00 2001 From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com> Date: Tue, 12 Mar 2024 11:00:33 -0700 Subject: [PATCH 11/11] finally fix priority typing (#11) --- src/py/reactpy/reactpy/core/component.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/py/reactpy/reactpy/core/component.py b/src/py/reactpy/reactpy/core/component.py index 71206b94d..8ffaae1cf 100644 --- a/src/py/reactpy/reactpy/core/component.py +++ b/src/py/reactpy/reactpy/core/component.py @@ -11,11 +11,13 @@ @overload -def component(function: None = None, *, priority: int) -> Callable[P, Component]: ... +def component( + function: None = None, *, priority: int +) -> Callable[[Callable[P, T]], Callable[P, Component]]: ... @overload -def component(function: Callable[P, T] | None) -> Callable[P, Component]: ... +def component(function: Callable[P, T]) -> Callable[P, Component]: ... def component(