From 8f6c8ec01da8ad4845d0a5493a33e5eca84bfcac Mon Sep 17 00:00:00 2001
From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com>
Date: Mon, 4 Mar 2024 17:44:54 -0800
Subject: [PATCH 01/11] Fix static type hints for component decorator (#2)
* Fix component decorator eating static type hints
---
src/py/reactpy/reactpy/core/component.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/src/py/reactpy/reactpy/core/component.py b/src/py/reactpy/reactpy/core/component.py
index f825aac71..e3c6b068d 100644
--- a/src/py/reactpy/reactpy/core/component.py
+++ b/src/py/reactpy/reactpy/core/component.py
@@ -2,14 +2,17 @@
import inspect
from functools import wraps
-from typing import Any, Callable
+from typing import Any, Callable, TypeVar, ParamSpec
from reactpy.core.types import ComponentType, VdomDict
+T = TypeVar("T", bound=ComponentType | VdomDict | str | None)
+P = ParamSpec("P")
+
def component(
- function: Callable[..., ComponentType | VdomDict | str | None]
-) -> Callable[..., Component]:
+ function: Callable[P, T],
+) -> Callable[P, Component]:
"""A decorator for defining a new component.
Parameters:
@@ -25,7 +28,7 @@ def component(
raise TypeError(msg)
@wraps(function)
- def constructor(*args: Any, key: Any | None = None, **kwargs: Any) -> Component:
+ def constructor(*args: P.args, key: Any | None = None, **kwargs: P.kwargs) -> Component:
return Component(function, key, args, kwargs, sig)
return constructor
From efa547f4606a6d5f8447a901ac12f880a7e34f28 Mon Sep 17 00:00:00 2001
From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com>
Date: Mon, 4 Mar 2024 18:09:10 -0800
Subject: [PATCH 02/11] Connection resume (#1)
Add reconnection and client state side state capabilities
---
src/js/package-lock.json | 4 +-
.../@reactpy/client/src/components.tsx | 27 +-
.../packages/@reactpy/client/src/messages.ts | 9 +-
.../@reactpy/client/src/reactpy-client.ts | 336 +++++++++++++++---
src/py/reactpy/pyproject.toml | 3 +
src/py/reactpy/reactpy/__init__.py | 10 +-
src/py/reactpy/reactpy/backend/hooks.py | 17 +-
src/py/reactpy/reactpy/backend/sanic.py | 61 ++--
.../reactpy/reactpy/core/_life_cycle_hook.py | 59 ++-
src/py/reactpy/reactpy/core/_thread_local.py | 21 --
src/py/reactpy/reactpy/core/component.py | 51 ++-
src/py/reactpy/reactpy/core/hooks.py | 103 +++++-
src/py/reactpy/reactpy/core/layout.py | 243 +++++++++++--
src/py/reactpy/reactpy/core/serve.py | 179 ++++++++--
src/py/reactpy/reactpy/core/state_recovery.py | 276 ++++++++++++++
src/py/reactpy/reactpy/core/types.py | 37 ++
src/py/reactpy/reactpy/testing/common.py | 4 +-
src/py/reactpy/reactpy/utils.py | 20 +-
src/py/reactpy/tests/test_core/test_layout.py | 6 +-
src/py/reactpy/tests/tooling/hooks.py | 4 +-
20 files changed, 1245 insertions(+), 225 deletions(-)
delete mode 100644 src/py/reactpy/reactpy/core/_thread_local.py
create mode 100644 src/py/reactpy/reactpy/core/state_recovery.py
diff --git a/src/js/package-lock.json b/src/js/package-lock.json
index 2edfdd260..2904bba0e 100644
--- a/src/js/package-lock.json
+++ b/src/js/package-lock.json
@@ -28,7 +28,7 @@
"@types/react": "^17.0",
"@types/react-dom": "^17.0",
"typescript": "^4.9.5",
- "vite": "^3.1.8"
+ "vite": "^3.2.7"
}
},
"app/node_modules/@reactpy/client": {
@@ -3955,7 +3955,7 @@
"@types/react-dom": "^17.0",
"preact": "^10.7.0",
"typescript": "^4.9.5",
- "vite": "^3.1.8"
+ "vite": "^3.2.7"
},
"dependencies": {
"@reactpy/client": {
diff --git a/src/js/packages/@reactpy/client/src/components.tsx b/src/js/packages/@reactpy/client/src/components.tsx
index 728c4cec7..fd23d3a8a 100644
--- a/src/js/packages/@reactpy/client/src/components.tsx
+++ b/src/js/packages/@reactpy/client/src/components.tsx
@@ -29,12 +29,13 @@ export function Layout(props: { client: ReactPyClient }): JSX.Element {
useEffect(
() =>
- props.client.onMessage("layout-update", ({ path, model }) => {
+ props.client.onMessage("layout-update", ({ path, model, state_vars }) => {
if (path === "") {
Object.assign(currentModel, model);
} else {
setJsonPointer(currentModel, path, model);
}
+ props.client.updateStateVars(state_vars);
forceUpdate();
}),
[currentModel, props.client],
@@ -125,23 +126,15 @@ function ScriptElement({ model }: { model: ReactPyVdom }) {
(value): value is string => typeof value == "string",
)[0];
- let scriptElement: HTMLScriptElement;
- if (model.attributes) {
- scriptElement = document.createElement("script");
- for (const [k, v] of Object.entries(model.attributes)) {
- scriptElement.setAttribute(k, v);
- }
- if (scriptContent) {
- scriptElement.appendChild(document.createTextNode(scriptContent));
- }
- ref.current.appendChild(scriptElement);
- } else if (scriptContent) {
- const scriptResult = eval(scriptContent);
- if (typeof scriptResult == "function") {
- return scriptResult();
- }
+ const scriptElement: HTMLScriptElement = document.createElement("script");
+ for (const [k, v] of Object.entries(model.attributes || {})) {
+ scriptElement.setAttribute(k, v);
+ }
+ if (scriptContent) {
+ scriptElement.appendChild(document.createTextNode(scriptContent));
}
- }, [model.key, ref.current]);
+ ref.current.appendChild(scriptElement);
+ }, [model.key]);
return
;
}
diff --git a/src/js/packages/@reactpy/client/src/messages.ts b/src/js/packages/@reactpy/client/src/messages.ts
index 34001dcb0..5fbfc24bf 100644
--- a/src/js/packages/@reactpy/client/src/messages.ts
+++ b/src/js/packages/@reactpy/client/src/messages.ts
@@ -12,6 +12,11 @@ export type LayoutEventMessage = {
data: any;
};
-export type IncomingMessage = LayoutUpdateMessage;
-export type OutgoingMessage = LayoutEventMessage;
+export type ReconnectingCheckMessage = {
+ type: "reconnecting-check";
+ value: string;
+}
+
+export type IncomingMessage = LayoutUpdateMessage | ReconnectingCheckMessage;
+export type OutgoingMessage = LayoutEventMessage | ReconnectingCheckMessage;
export type Message = IncomingMessage | OutgoingMessage;
diff --git a/src/js/packages/@reactpy/client/src/reactpy-client.ts b/src/js/packages/@reactpy/client/src/reactpy-client.ts
index 6f37b55a1..c5018e9a5 100644
--- a/src/js/packages/@reactpy/client/src/reactpy-client.ts
+++ b/src/js/packages/@reactpy/client/src/reactpy-client.ts
@@ -29,16 +29,26 @@ export interface ReactPyClient {
* @returns A promise that resolves to the module.
*/
loadModule(moduleName: string): Promise;
+
+ /**
+ * Update state vars from the server for reconnections
+ * @param givenStateVars State vars to store
+ */
+ updateStateVars(givenStateVars: object): void;
}
export abstract class BaseReactPyClient implements ReactPyClient {
private readonly handlers: { [key: string]: ((message: any) => void)[] } = {};
protected readonly ready: Promise;
private resolveReady: (value: undefined) => void;
+ protected stateVars: object;
+ protected debugMessages: boolean;
constructor() {
- this.resolveReady = () => {};
+ this.resolveReady = () => { };
this.ready = new Promise((resolve) => (this.resolveReady = resolve));
+ this.stateVars = {};
+ this.debugMessages = false;
}
onMessage(type: string, handler: (message: any) => void): () => void {
@@ -52,6 +62,10 @@ export abstract class BaseReactPyClient implements ReactPyClient {
abstract sendMessage(message: any): void;
abstract loadModule(moduleName: string): Promise;
+ updateStateVars(givenStateVars: object): void {
+ this.stateVars = Object.assign(this.stateVars, givenStateVars);
+ }
+
/**
* Handle an incoming message.
*
@@ -65,6 +79,10 @@ export abstract class BaseReactPyClient implements ReactPyClient {
return;
}
+ if (this.debugMessages) {
+ logger.log("Got message", message);
+ }
+
const messageHandlers: ((m: any) => void)[] | undefined =
this.handlers[message.type];
if (!messageHandlers) {
@@ -79,6 +97,9 @@ export abstract class BaseReactPyClient implements ReactPyClient {
export type SimpleReactPyClientProps = {
serverLocation?: LocationProps;
reconnectOptions?: ReconnectProps;
+ idleDisconnectTimeSeconds?: number;
+ connectionTimeout?: number;
+ debugMessages?: boolean;
};
/**
@@ -117,14 +138,37 @@ type ReconnectProps = {
maxRetries?: number;
backoffRate?: number;
intervalJitter?: number;
+ reconnectingCallback?: Function;
+ reconnectedCallback?: Function;
+};
+
+enum messageTypes {
+ isReady = "is-ready",
+ reconnectingCheck = "reconnecting-check",
+ clientState = "client-state",
+ stateUpdate = "state-update"
};
export class SimpleReactPyClient
extends BaseReactPyClient
- implements ReactPyClient
-{
+ implements ReactPyClient {
private readonly urls: ServerUrls;
- private readonly socket: { current?: WebSocket };
+ private socket!: { current?: WebSocket };
+ private idleDisconnectTimeMillis: number;
+ private lastMessageTime: number;
+ private reconnectOptions: ReconnectProps | undefined;
+ private messageQueue: any[] = [];
+ private socketLoopIntervalId?: number | null;
+ private idleCheckIntervalId?: number | null;
+ private sleeping: boolean;
+ private isReconnecting: boolean;
+ private isReady: boolean;
+ private salt: string;
+ private shouldReconnect: boolean;
+ private connectionTimeout: number;
+ private reconnectingCallback: Function;
+ private reconnectedCallback: Function;
+ private didReconnectingCallback: boolean;
constructor(props: SimpleReactPyClientProps) {
super();
@@ -136,17 +180,238 @@ export class SimpleReactPyClient
query: document.location.search,
},
);
+ this.idleDisconnectTimeMillis = (props.idleDisconnectTimeSeconds || 240) * 1000;
+ this.connectionTimeout = props.connectionTimeout || 5000;
+ this.lastMessageTime = Date.now()
+ this.reconnectOptions = props.reconnectOptions
+ this.debugMessages = props.debugMessages || false;
+ this.sleeping = false;
+ this.isReconnecting = false;
+ this.isReady = false
+ this.salt = "";
+ this.shouldReconnect = false;
+ this.didReconnectingCallback = false;
+ this.reconnectingCallback = props.reconnectOptions?.reconnectingCallback || this.showReconnectingGrayout;
+ this.reconnectedCallback = props.reconnectOptions?.reconnectedCallback || this.hideReconnectingGrayout;
+
+ this.onMessage(messageTypes.reconnectingCheck, () => { this.indicateReconnect() })
+ this.onMessage(messageTypes.isReady, (msg) => { this.isReady = true; this.salt = msg.salt; });
+ this.onMessage(messageTypes.clientState, () => { this.sendClientState() });
+ this.onMessage(messageTypes.stateUpdate, (msg) => { this.updateClientState(msg.state_vars) });
+
+ this.reconnect()
+
+ const reconnectOnUserAction = (ev: any) => {
+ if (!this.isReady && !this.isReconnecting) {
+ this.reconnect();
+ }
+ }
+
+ window.addEventListener('mousemove', reconnectOnUserAction);
+ window.addEventListener('scroll', reconnectOnUserAction);
+ }
+
+ showReconnectingGrayout() {
+ const overlay = document.createElement('div');
+ overlay.id = 'reactpy-reconnect-overlay';
+
+ const pipeContainer = document.createElement('div');
+ const pipeSymbol = document.createElement('div');
+ pipeSymbol.textContent = '|'; // Set the pipe symbol
+
+ overlay.style.cssText = `
+ position: fixed;
+ top: 0;
+ left: 0;
+ width: 100%;
+ height: 100%;
+ background-color: rgba(0,0,0,0.5);
+ display: flex;
+ justify-content: center;
+ align-items: center;
+ z-index: 1000;
+ `;
+
+ pipeContainer.style.cssText = `
+ display: flex;
+ justify-content: center;
+ align-items: center;
+ width: 40px;
+ height: 40px;
+ `;
+
+ pipeSymbol.style.cssText = `
+ font-size: 24px;
+ color: #FFF;
+ display: inline-block;
+ width: 100%;
+ height: 100%;
+ text-align: center;
+ transform-origin: center;
+ `;
+
+ pipeContainer.appendChild(pipeSymbol);
+ overlay.appendChild(pipeContainer);
+ document.body.appendChild(overlay);
+
+ // Create and start the spin animation
+ let angle = 0;
+ function spin() {
+ angle = (angle + 2) % 360; // Adjust rotation speed as needed
+ pipeSymbol.style.transform = `rotate(${angle}deg)`;
+ requestAnimationFrame(spin);
+ }
+ spin();
+ }
+
+ hideReconnectingGrayout() {
+ const overlay = document.getElementById('reactpy-reconnect-overlay');
+ if (overlay && overlay.parentNode) {
+ overlay.parentNode.removeChild(overlay);
+ }
+ }
- this.socket = createReconnectingWebSocket({
- readyPromise: this.ready,
- url: this.urls.stream,
- onMessage: async ({ data }) => this.handleIncoming(JSON.parse(data)),
- ...props.reconnectOptions,
+ indicateReconnect(): void {
+ const isReconnecting = this.isReconnecting ? "yes" : "no";
+ this.sendMessage({ "type": messageTypes.reconnectingCheck, "value": isReconnecting }, true)
+ }
+
+ sendClientState(): void {
+ if (!this.socket)
+ return;
+ this.transmitMessage({
+ "type": "client-state",
+ "value": this.stateVars,
+ "salt": this.salt
});
}
- sendMessage(message: any): void {
- this.socket.current?.send(JSON.stringify(message));
+ updateClientState(stateVars: object): void {
+ if (!this.socket)
+ return;
+ this.updateStateVars(stateVars)
+ }
+
+ socketLoop(): void {
+ if (!this.socket)
+ return;
+ if (this.messageQueue.length > 0 && this.isReady && this.socket.current && this.socket.current.readyState === WebSocket.OPEN) {
+ const message = this.messageQueue.shift(); // Remove the first message from the queue
+ this.transmitMessage(message);
+ }
+ }
+
+ transmitMessage(message: any): void {
+ if (this.socket && this.socket.current) {
+ if (this.debugMessages) {
+ logger.log("Sending message", message);
+ }
+ this.socket.current.send(JSON.stringify(message));
+ }
+ }
+
+ idleTimeoutCheck(): void {
+ if (!this.socket)
+ return;
+ if (Date.now() - this.lastMessageTime > this.idleDisconnectTimeMillis) {
+ if (this.socket.current && this.socket.current.readyState === WebSocket.OPEN) {
+ logger.warn("Closing socket connection due to idle activity");
+ this.sleeping = true;
+ this.socket.current.close();
+ }
+ }
+ }
+
+ reconnect(onOpen?: () => void, interval: number = 750, connectionAttemptsRemaining: number = 20, lastAttempt: number = 0): void {
+ const intervalJitter = this.reconnectOptions?.intervalJitter || 0.5;
+ const backoffRate = this.reconnectOptions?.backoffRate || 1.2;
+ const maxInterval = this.reconnectOptions?.maxInterval || 20000;
+ const maxRetries = this.reconnectOptions?.maxRetries || 20;
+
+
+ if (connectionAttemptsRemaining <= 0) {
+ logger.warn("Giving up on reconnecting (hit retry limit)");
+ this.shouldReconnect = false;
+ this.isReconnecting = false;
+ return
+ }
+
+ if (this.shouldReconnect) {
+ // already reconnecting
+ return;
+ }
+ lastAttempt = lastAttempt || Date.now();
+ this.shouldReconnect = true;
+
+ window.setTimeout(() => {
+ if (!this.didReconnectingCallback && this.reconnectingCallback) {
+ this.didReconnectingCallback = true;
+ this.reconnectingCallback();
+ }
+
+ if (maxRetries < connectionAttemptsRemaining)
+ connectionAttemptsRemaining = maxRetries;
+
+ this.socket = createWebSocket({
+ connectionTimeout: this.connectionTimeout,
+ readyPromise: this.ready,
+ url: this.urls.stream,
+ onOpen: () => {
+ lastAttempt = Date.now();
+ if (this.reconnectedCallback) {
+ this.reconnectedCallback();
+ this.didReconnectingCallback = false;
+ }
+ if (onOpen)
+ onOpen();
+ },
+ onClose: () => {
+ // reset retry interval
+ if (Date.now() - lastAttempt > maxInterval * 2) {
+ interval = 750;
+ connectionAttemptsRemaining = maxRetries;
+ }
+ lastAttempt = Date.now()
+ this.shouldReconnect = false;
+ this.isReconnecting = true;
+ this.isReady = false;
+ if (this.socketLoopIntervalId)
+ clearInterval(this.socketLoopIntervalId);
+ if (this.idleCheckIntervalId)
+ clearInterval(this.idleCheckIntervalId);
+ if (!this.sleeping) {
+ const thisInterval = nextInterval(addJitter(interval, intervalJitter), backoffRate, maxInterval);
+ const newRetriesRemaining = connectionAttemptsRemaining - 1;
+ logger.log(
+ `reconnecting in ${(thisInterval / 1000).toPrecision(4)} seconds... (${newRetriesRemaining} retries remaining)`,
+ );
+ this.reconnect(onOpen, thisInterval, newRetriesRemaining, lastAttempt);
+ }
+ },
+ onMessage: async ({ data }) => { this.lastMessageTime = Date.now(); this.handleIncoming(JSON.parse(data)) },
+ ...this.reconnectOptions,
+ });
+ this.socketLoopIntervalId = window.setInterval(() => { this.socketLoop() }, 30);
+ this.idleCheckIntervalId = window.setInterval(() => { this.idleTimeoutCheck() }, 10000);
+
+ }, interval)
+ }
+
+ ensureConnected(): void {
+ if (this.socket.current?.readyState == WebSocket.CLOSED) {
+ this.reconnect();
+ }
+ }
+
+ sendMessage(message: any, immediate: boolean = false): void {
+ if (immediate) {
+ this.transmitMessage(message);
+ } else {
+ this.messageQueue.push(message);
+ }
+ this.lastMessageTime = Date.now()
+ this.sleeping = false;
+ this.ensureConnected();
}
loadModule(moduleName: string): Promise {
@@ -173,25 +438,16 @@ function getServerUrls(props: LocationProps): ServerUrls {
return { base, modules, assets, stream };
}
-function createReconnectingWebSocket(
+function createWebSocket(
props: {
url: string;
readyPromise: Promise;
+ connectionTimeout: number;
onOpen?: () => void;
onMessage: (message: MessageEvent) => void;
onClose?: () => void;
- } & ReconnectProps,
+ },
) {
- const {
- maxInterval = 60000,
- maxRetries = 50,
- backoffRate = 1.1,
- intervalJitter = 0.1,
- } = props;
-
- const startInterval = 750;
- let retries = 0;
- let interval = startInterval;
const closed = false;
let everConnected = false;
const socket: { current?: WebSocket } = {};
@@ -201,11 +457,18 @@ function createReconnectingWebSocket(
return;
}
socket.current = new WebSocket(props.url);
+ const connectionTimeout = props.connectionTimeout; // Timeout in milliseconds
+
+ const timeoutId = setTimeout(() => {
+ if (socket.current && socket.current.readyState !== WebSocket.OPEN) {
+ socket.current.close();
+ console.error('Connection attempt timed out');
+ }
+ }, connectionTimeout);
socket.current.onopen = () => {
+ clearTimeout(timeoutId);
everConnected = true;
logger.log("client connected");
- interval = startInterval;
- retries = 0;
if (props.onOpen) {
props.onOpen();
}
@@ -214,25 +477,12 @@ function createReconnectingWebSocket(
socket.current.onclose = () => {
if (!everConnected) {
logger.log("failed to connect");
- return;
+ } else {
+ logger.log("client disconnected");
}
-
- logger.log("client disconnected");
if (props.onClose) {
props.onClose();
}
-
- if (retries >= maxRetries) {
- return;
- }
-
- const thisInterval = addJitter(interval, intervalJitter);
- logger.log(
- `reconnecting in ${(thisInterval / 1000).toPrecision(4)} seconds...`,
- );
- setTimeout(connect, thisInterval);
- interval = nextInterval(interval, backoffRate, maxInterval);
- retries++;
};
};
@@ -247,16 +497,16 @@ function nextInterval(
maxInterval: number,
): number {
return Math.min(
- currentInterval *
+ (currentInterval *
// increase interval by backoff rate
- backoffRate,
+ backoffRate),
// don't exceed max interval
maxInterval,
);
}
function addJitter(interval: number, jitter: number): number {
- return interval + (Math.random() * jitter * interval * 2 - jitter * interval);
+ return interval + (Math.random() * jitter * interval);
}
function rtrim(text: string, trim: string): string {
diff --git a/src/py/reactpy/pyproject.toml b/src/py/reactpy/pyproject.toml
index 309248507..c16e5f065 100644
--- a/src/py/reactpy/pyproject.toml
+++ b/src/py/reactpy/pyproject.toml
@@ -35,6 +35,9 @@ dependencies = [
"colorlog >=6",
"asgiref >=3",
"lxml >=4",
+ "pyotp",
+ "orjson",
+ "more-itertools",
]
[project.optional-dependencies]
all = ["reactpy[starlette,sanic,fastapi,flask,tornado,testing]"]
diff --git a/src/py/reactpy/reactpy/__init__.py b/src/py/reactpy/reactpy/__init__.py
index 49e357441..32b4712a1 100644
--- a/src/py/reactpy/reactpy/__init__.py
+++ b/src/py/reactpy/reactpy/__init__.py
@@ -1,10 +1,16 @@
from reactpy import backend, config, html, logging, sample, svg, types, web, widgets
-from reactpy.backend.hooks import use_connection, use_location, use_scope
+from reactpy.backend.hooks import (
+ use_connection,
+ use_location,
+ use_reconnect_effect,
+ use_scope,
+)
from reactpy.backend.utils import run
from reactpy.core import hooks
from reactpy.core.component import component
from reactpy.core.events import event
from reactpy.core.hooks import (
+ ReconnectingOnly,
create_context,
use_callback,
use_context,
@@ -33,6 +39,7 @@
"html",
"Layout",
"logging",
+ "ReconnectingOnly",
"Ref",
"run",
"sample",
@@ -46,6 +53,7 @@
"use_effect",
"use_location",
"use_memo",
+ "use_reconnect_effect",
"use_reducer",
"use_ref",
"use_scope",
diff --git a/src/py/reactpy/reactpy/backend/hooks.py b/src/py/reactpy/reactpy/backend/hooks.py
index ee4ce1b5c..3424b9b86 100644
--- a/src/py/reactpy/reactpy/backend/hooks.py
+++ b/src/py/reactpy/reactpy/backend/hooks.py
@@ -1,10 +1,16 @@
from __future__ import annotations
from collections.abc import MutableMapping
-from typing import Any
+from typing import Any, Callable
from reactpy.backend.types import Connection, Location
-from reactpy.core.hooks import create_context, use_context
+from reactpy.core.hooks import (
+ ReconnectingOnly,
+ _EffectApplyFunc,
+ create_context,
+ use_context,
+ use_effect,
+)
from reactpy.core.types import Context
# backend implementations should establish this context at the root of an app
@@ -28,3 +34,10 @@ def use_scope() -> MutableMapping[str, Any]:
def use_location() -> Location:
"""Get the current :class:`~reactpy.backend.types.Connection`'s location."""
return use_connection().location
+
+
+def use_reconnect_effect(
+ function: _EffectApplyFunc | None = None,
+) -> Callable[[_EffectApplyFunc], None] | None:
+ """Apply an effect only on reconnection"""
+ return use_effect(function, ReconnectingOnly())
diff --git a/src/py/reactpy/reactpy/backend/sanic.py b/src/py/reactpy/reactpy/backend/sanic.py
index 76eb0423e..bad90b072 100644
--- a/src/py/reactpy/reactpy/backend/sanic.py
+++ b/src/py/reactpy/reactpy/backend/sanic.py
@@ -7,6 +7,7 @@
from typing import Any
from urllib import parse as urllib_parse
from uuid import uuid4
+import orjson
from sanic import Blueprint, Sanic, request, response
from sanic.config import Config
@@ -24,11 +25,15 @@
safe_web_modules_dir_path,
serve_with_uvicorn,
)
-from reactpy.backend.hooks import ConnectionContext
from reactpy.backend.hooks import use_connection as _use_connection
from reactpy.backend.types import Connection, Location
-from reactpy.core.layout import Layout
-from reactpy.core.serve import RecvCoroutine, SendCoroutine, Stop, serve_layout
+from reactpy.core.serve import (
+ RecvCoroutine,
+ SendCoroutine,
+ Stop,
+ WebsocketServer,
+)
+from reactpy.core.state_recovery import StateRecoveryManager
from reactpy.core.types import RootComponentConstructor
logger = logging.getLogger(__name__)
@@ -51,6 +56,7 @@ def configure(
app: Sanic[Any, Any],
component: RootComponentConstructor,
options: Options | None = None,
+ state_recovery_manager: StateRecoveryManager | None = None,
) -> None:
"""Configure an application instance to display the given component"""
options = options or Options()
@@ -59,7 +65,9 @@ def configure(
api_bp = Blueprint(f"reactpy_api_{id(app)}", url_prefix=str(PATH_PREFIX))
_setup_common_routes(api_bp, spa_bp, options)
- _setup_single_view_dispatcher_route(api_bp, component, options)
+ _setup_single_view_dispatcher_route(
+ api_bp, component, options, state_recovery_manager
+ )
app.blueprint([spa_bp, api_bp])
@@ -159,6 +167,7 @@ def _setup_single_view_dispatcher_route(
api_blueprint: Blueprint,
constructor: RootComponentConstructor,
options: Options,
+ state_recovery_manager: StateRecoveryManager | None,
) -> None:
async def model_stream(
request: request.Request[Any, Any],
@@ -171,27 +180,29 @@ async def model_stream(
logger.warning("No scope. Sanic may not be running with an ASGI server")
send, recv = _make_send_recv_callbacks(socket)
- await serve_layout(
- Layout(
- ConnectionContext(
- constructor(),
- value=Connection(
- scope=scope,
- location=Location(
- pathname=f"/{path[len(options.url_prefix):]}",
- search=(
- f"?{request.query_string}"
- if request.query_string
- else ""
- ),
- ),
- carrier=_SanicCarrier(request, socket),
- ),
- )
+
+ server = WebsocketServer(send, recv, state_recovery_manager)
+ await server.handle_connection(
+ Connection(
+ scope=scope,
+ location=Location(
+ pathname=f"/{path[len(options.url_prefix):]}",
+ search=(f"?{request.query_string}" if request.query_string else ""),
+ ),
+ carrier=_SanicCarrier(request, socket),
),
- send,
- recv,
+ constructor,
)
+ # await serve_layout(
+ # Layout(
+ # ConnectionContext(
+ # constructor(),
+ # value=,
+ # )
+ # ),
+ # send,
+ # recv,
+ # )
api_blueprint.add_websocket_route(
model_stream,
@@ -209,13 +220,13 @@ def _make_send_recv_callbacks(
socket: WebSocketConnection,
) -> tuple[SendCoroutine, RecvCoroutine]:
async def sock_send(value: Any) -> None:
- await socket.send(json.dumps(value))
+ await socket.send(orjson.dumps(value).decode("utf-8"))
async def sock_recv() -> Any:
data = await socket.recv()
if data is None:
raise Stop()
- return json.loads(data)
+ return orjson.loads(data)
return sock_send, sock_recv
diff --git a/src/py/reactpy/reactpy/core/_life_cycle_hook.py b/src/py/reactpy/reactpy/core/_life_cycle_hook.py
index 88d3386a8..c0e3dfc6f 100644
--- a/src/py/reactpy/reactpy/core/_life_cycle_hook.py
+++ b/src/py/reactpy/reactpy/core/_life_cycle_hook.py
@@ -2,12 +2,16 @@
import logging
from asyncio import Event, Task, create_task, gather
-from typing import Any, Callable, Protocol, TypeVar
+from contextvars import ContextVar, Token
+from typing import TYPE_CHECKING, Any, Callable, Protocol, TypeVar
from anyio import Semaphore
-from reactpy.core._thread_local import ThreadLocal
from reactpy.core.types import ComponentType, Context, ContextProviderType
+from reactpy.utils import Ref
+
+if TYPE_CHECKING:
+ from reactpy.core.hooks import _CurrentState
T = TypeVar("T")
@@ -18,12 +22,27 @@ async def __call__(self, stop: Event) -> None: ...
logger = logging.getLogger(__name__)
-_HOOK_STATE: ThreadLocal[list[LifeCycleHook]] = ThreadLocal(list)
+_hook_state = ContextVar("_hook_state")
+
+
+def create_hook_state(initial: list | None = None) -> Token[list]:
+ return _hook_state.set(initial or [])
+
+
+def clear_hook_state(token: Token[list]) -> None:
+ hook_stack = _hook_state.get()
+ if hook_stack:
+ logger.warning("clear_hook_state: Hook stack was not empty")
+ _hook_state.reset(token)
+
+
+def get_hook_state() -> list[LifeCycleHook]:
+ return _hook_state.get()
-def current_hook() -> LifeCycleHook:
+def get_current_hook() -> LifeCycleHook:
"""Get the current :class:`LifeCycleHook`"""
- hook_stack = _HOOK_STATE.get()
+ hook_stack = _hook_state.get()
if not hook_stack:
msg = "No life cycle hook is active. Are you rendering in a layout?"
raise RuntimeError(msg)
@@ -117,6 +136,10 @@ async def my_effect(stop_event):
"_scheduled_render",
"_state",
"component",
+ "reconnecting",
+ "client_state",
+ "_updated_states",
+ "_previous_states",
)
component: ComponentType
@@ -124,17 +147,35 @@ async def my_effect(stop_event):
def __init__(
self,
schedule_render: Callable[[], None],
+ reconnecting: Ref,
+ client_state: dict[str, Any],
+ updated_states: dict[str, Any],
+ previous_states: dict[str, Any],
) -> None:
self._context_providers: dict[Context[Any], ContextProviderType[Any]] = {}
self._schedule_render_callback = schedule_render
self._scheduled_render = False
self._rendered_atleast_once = False
self._current_state_index = 0
- self._state: tuple[Any, ...] = ()
+ self._state: list = []
self._effect_funcs: list[EffectFunc] = []
self._effect_tasks: list[Task[None]] = []
self._effect_stops: list[Event] = []
self._render_access = Semaphore(1) # ensure only one render at a time
+ self.reconnecting = reconnecting
+ self.client_state = client_state or {}
+ self._updated_states = updated_states
+ self._previous_states = previous_states
+
+ def add_state_update(self, updated_state: _CurrentState | Ref) -> None:
+ if (
+ updated_state.key
+ and self._previous_states.get(
+ updated_state.key, "__missing_lifecycle_key_value__"
+ )
+ is not updated_state.value
+ ):
+ self._updated_states[updated_state.key] = updated_state.value
def schedule_render(self) -> None:
if self._scheduled_render:
@@ -157,7 +198,7 @@ def use_state(self, function: Callable[[], T]) -> T:
if not self._rendered_atleast_once:
# since we're not initialized yet we're just appending state
result = function()
- self._state += (result,)
+ self._state.append(result)
else:
# once finalized we iterate over each succesively used piece of state
result = self._state[self._current_state_index]
@@ -232,7 +273,7 @@ def set_current(self) -> None:
This method is called by a layout before entering the render method
of this hook's associated component.
"""
- hook_stack = _HOOK_STATE.get()
+ hook_stack = get_hook_state()
if hook_stack:
parent = hook_stack[-1]
self._context_providers.update(parent._context_providers)
@@ -240,5 +281,5 @@ def set_current(self) -> None:
def unset_current(self) -> None:
"""Unset this hook as the active hook in this thread"""
- if _HOOK_STATE.get().pop() is not self:
+ if get_hook_state().pop() is not self:
raise RuntimeError("Hook stack is in an invalid state") # nocov
diff --git a/src/py/reactpy/reactpy/core/_thread_local.py b/src/py/reactpy/reactpy/core/_thread_local.py
deleted file mode 100644
index b3d6a14b0..000000000
--- a/src/py/reactpy/reactpy/core/_thread_local.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from threading import Thread, current_thread
-from typing import Callable, Generic, TypeVar
-from weakref import WeakKeyDictionary
-
-_StateType = TypeVar("_StateType")
-
-
-class ThreadLocal(Generic[_StateType]):
- """Utility for managing per-thread state information"""
-
- def __init__(self, default: Callable[[], _StateType]):
- self._default = default
- self._state: WeakKeyDictionary[Thread, _StateType] = WeakKeyDictionary()
-
- def get(self) -> _StateType:
- thread = current_thread()
- if thread not in self._state:
- state = self._state[thread] = self._default()
- else:
- state = self._state[thread]
- return state
diff --git a/src/py/reactpy/reactpy/core/component.py b/src/py/reactpy/reactpy/core/component.py
index e3c6b068d..9d4955546 100644
--- a/src/py/reactpy/reactpy/core/component.py
+++ b/src/py/reactpy/reactpy/core/component.py
@@ -2,7 +2,7 @@
import inspect
from functools import wraps
-from typing import Any, Callable, TypeVar, ParamSpec
+from typing import Any, Callable, ParamSpec, TypeVar
from reactpy.core.types import ComponentType, VdomDict
@@ -11,33 +11,52 @@
def component(
- function: Callable[P, T],
+ function: Callable[P, T] | None = None,
+ *,
+ priority: int = 0,
) -> Callable[P, Component]:
"""A decorator for defining a new component.
Parameters:
- function: The component's :meth:`reactpy.core.proto.ComponentType.render` function.
+ priority: The rendering priority. Lower numbers are higher priority.
"""
- sig = inspect.signature(function)
- if "key" in sig.parameters and sig.parameters["key"].kind in (
- inspect.Parameter.KEYWORD_ONLY,
- inspect.Parameter.POSITIONAL_OR_KEYWORD,
- ):
- msg = f"Component render function {function} uses reserved parameter 'key'"
- raise TypeError(msg)
+ def _component(function: Callable[P, T]) -> Callable[P, Component]:
+ sig = inspect.signature(function)
- @wraps(function)
- def constructor(*args: P.args, key: Any | None = None, **kwargs: P.kwargs) -> Component:
- return Component(function, key, args, kwargs, sig)
+ if "key" in sig.parameters and sig.parameters["key"].kind in (
+ inspect.Parameter.KEYWORD_ONLY,
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ ):
+ msg = f"Component render function {function} uses reserved parameter 'key'"
+ raise TypeError(msg)
- return constructor
+ @wraps(function)
+ def constructor(
+ *args: P.args, key: Any | None = None, **kwargs: P.kwargs
+ ) -> Component:
+ return Component(function, key, args, kwargs, sig, priority)
+
+ return constructor
+
+ if function:
+ return _component(function)
+ return _component
class Component:
"""An object for rending component models."""
- __slots__ = "__weakref__", "_func", "_args", "_kwargs", "_sig", "key", "type"
+ __slots__ = (
+ "__weakref__",
+ "_func",
+ "_args",
+ "_kwargs",
+ "_sig",
+ "key",
+ "type",
+ "priority",
+ )
def __init__(
self,
@@ -46,12 +65,14 @@ def __init__(
args: tuple[Any, ...],
kwargs: dict[str, Any],
sig: inspect.Signature,
+ priority: int = 0,
) -> None:
self.key = key
self.type = function
self._args = args
self._kwargs = kwargs
self._sig = sig
+ self.priority = priority
def render(self) -> ComponentType | VdomDict | str | None:
return self.type(*self._args, **self._kwargs)
diff --git a/src/py/reactpy/reactpy/core/hooks.py b/src/py/reactpy/reactpy/core/hooks.py
index 640cbf14c..72f7c2c4e 100644
--- a/src/py/reactpy/reactpy/core/hooks.py
+++ b/src/py/reactpy/reactpy/core/hooks.py
@@ -1,7 +1,11 @@
from __future__ import annotations
import asyncio
+from functools import lru_cache
+import hashlib
+import sys
from collections.abc import Coroutine, Sequence
+from hashlib import md5
from logging import getLogger
from types import FunctionType
from typing import (
@@ -18,7 +22,8 @@
from typing_extensions import TypeAlias
from reactpy.config import REACTPY_DEBUG_MODE
-from reactpy.core._life_cycle_hook import current_hook
+from reactpy.core._life_cycle_hook import get_current_hook
+from reactpy.core.state_recovery import StateRecoveryFailureError
from reactpy.core.types import Context, Key, State, VdomDict
from reactpy.utils import Ref
@@ -38,6 +43,13 @@
logger = getLogger(__name__)
+
+class ReconnectingOnly(list):
+ """
+ Used to indicate that a hook should only be used during reconnection
+ """
+
+
_Type = TypeVar("_Type")
@@ -49,7 +61,9 @@ def use_state(initial_value: Callable[[], _Type]) -> State[_Type]: ...
def use_state(initial_value: _Type) -> State[_Type]: ...
-def use_state(initial_value: _Type | Callable[[], _Type]) -> State[_Type]:
+def use_state(
+ initial_value: _Type | Callable[[], _Type], *, server_only: bool = False
+) -> State[_Type]:
"""See the full :ref:`Use State` docs for details
Parameters:
@@ -61,23 +75,62 @@ def use_state(initial_value: _Type | Callable[[], _Type]) -> State[_Type]:
Returns:
A tuple containing the current state and a function to update it.
"""
- current_state = _use_const(lambda: _CurrentState(initial_value))
+ if server_only:
+ key = None
+ else:
+ hook = get_current_hook()
+ caller_info = get_caller_info()
+ key = get_state_key(caller_info)
+ if hook.reconnecting.current:
+ try:
+ initial_value = hook.client_state[key]
+ except KeyError as err:
+ raise StateRecoveryFailureError(
+ f"Missing expected key {key} on client"
+ ) from err
+ current_state = _use_const(lambda: _CurrentState(key, initial_value))
return State(current_state.value, current_state.dispatch)
+def get_caller_info():
+ # Get the current stack frame and then the frame above it
+ caller_frame = sys._getframe(2)
+ for i in range(50):
+ render_frame = sys._getframe(4 + i)
+ patch_path = render_frame.f_locals.get("patch_path_for_state")
+ if patch_path is not None:
+ break
+ # Extract the relevant information: file path and line number and hash it
+ return f"{caller_frame.f_code.co_filename} {caller_frame.f_lineno} {patch_path}"
+
+
+__DEBUG_CALLER_INFO_TO_STATE_KEY = {}
+
+
+@lru_cache(8192)
+def get_state_key(caller_info: str) -> str:
+ result = hashlib.sha256(caller_info.encode("utf8")).hexdigest()[:20]
+ if __debug__:
+ __DEBUG_CALLER_INFO_TO_STATE_KEY[result] = caller_info
+ return result
+
+
class _CurrentState(Generic[_Type]):
- __slots__ = "value", "dispatch"
+ __slots__ = "key", "value", "dispatch"
def __init__(
self,
+ key: str | None,
initial_value: _Type | Callable[[], _Type],
) -> None:
+ self.key = key
if callable(initial_value):
self.value = initial_value()
else:
self.value = initial_value
- hook = current_hook()
+ hook = get_current_hook()
+ hook.add_state_update(self)
def dispatch(new: _Type | Callable[[_Type], _Type]) -> None:
if callable(new):
@@ -86,6 +139,7 @@ def dispatch(new: _Type | Callable[[_Type], _Type]) -> None:
next_value = new
if not strictly_equal(next_value, self.value):
self.value = next_value
+ hook.add_state_update(self)
hook.schedule_render()
self.dispatch = dispatch
@@ -131,11 +185,17 @@ def use_effect(
Returns:
If not function is provided, a decorator. Otherwise ``None``.
"""
- hook = current_hook()
-
- dependencies = _try_to_infer_closure_values(function, dependencies)
memoize = use_memo(dependencies=dependencies)
last_clean_callback: Ref[_EffectCleanFunc | None] = use_ref(None)
+ hook = get_current_hook()
+ if hook.reconnecting.current:
+ if not isinstance(dependencies, ReconnectingOnly):
+ return
+ dependencies = None
+ else:
+ if isinstance(dependencies, ReconnectingOnly):
+ return
+ dependencies = _try_to_infer_closure_values(function, dependencies)
def add_effect(function: _EffectApplyFunc) -> None:
if not asyncio.iscoroutinefunction(function):
@@ -204,7 +264,7 @@ def use_debug_value(
if REACTPY_DEBUG_MODE.current and old.current != new:
old.current = new
- logger.debug(f"{current_hook().component} {new}")
+ logger.debug(f"{get_current_hook().component} {new}")
def create_context(default_value: _Type) -> Context[_Type]:
@@ -232,7 +292,7 @@ def use_context(context: Context[_Type]) -> _Type:
See the full :ref:`Use Context` docs for more information.
"""
- hook = current_hook()
+ hook = get_current_hook()
provider = hook.get_context_provider(context)
if provider is None:
@@ -255,14 +315,16 @@ def __init__(
value: _Type,
key: Key | None,
type: Context[_Type],
+ priority: int = -1,
) -> None:
self.children = children
self.key = key
self.type = type
self.value = value
+ self.priority = priority
def render(self) -> VdomDict:
- current_hook().set_context_provider(self)
+ get_current_hook().set_context_provider(self)
return {"tagName": "", "children": self.children}
def __repr__(self) -> str:
@@ -447,7 +509,7 @@ def empty(self) -> bool:
return False
-def use_ref(initial_value: _Type) -> Ref[_Type]:
+def use_ref(initial_value: _Type, server_only: bool = True) -> Ref[_Type]:
"""See the full :ref:`Use State` docs for details
Parameters:
@@ -456,11 +518,24 @@ def use_ref(initial_value: _Type) -> Ref[_Type]:
Returns:
A :class:`Ref` object.
"""
- return _use_const(lambda: Ref(initial_value))
+ if server_only:
+ key = None
+ else:
+ hook = get_current_hook()
+ caller_info = get_caller_info()
+ key = get_state_key(caller_info)
+ if hook.reconnecting.current:
+ try:
+ initial_value = hook.client_state[key]
+ except KeyError as err:
+ raise StateRecoveryFailureError(
+ f"Missing expected key {key} on client"
+ ) from err
+ return _use_const(lambda: Ref(initial_value, key))
def _use_const(function: Callable[[], _Type]) -> _Type:
- return current_hook().use_state(function)
+ return get_current_hook().use_state(function)
def _try_to_infer_closure_values(
diff --git a/src/py/reactpy/reactpy/core/layout.py b/src/py/reactpy/reactpy/core/layout.py
index 70bdbbbff..8db467e30 100644
--- a/src/py/reactpy/reactpy/core/layout.py
+++ b/src/py/reactpy/reactpy/core/layout.py
@@ -1,9 +1,12 @@
from __future__ import annotations
import abc
+import asyncio
+import copy
from asyncio import (
FIRST_COMPLETED,
CancelledError,
+ PriorityQueue,
Queue,
Task,
create_task,
@@ -16,7 +19,9 @@
from logging import getLogger
from typing import (
Any,
+ Awaitable,
Callable,
+ Coroutine,
Generic,
NamedTuple,
NewType,
@@ -34,13 +39,22 @@
REACTPY_CHECK_VDOM_SPEC,
REACTPY_DEBUG_MODE,
)
-from reactpy.core._life_cycle_hook import LifeCycleHook
+from reactpy.core._life_cycle_hook import (
+ LifeCycleHook,
+ clear_hook_state,
+ create_hook_state,
+ get_hook_state,
+)
+from reactpy.core.component import Component
+from reactpy.core.hooks import _ContextProvider
+from reactpy.core.state_recovery import StateRecoverySerializer
from reactpy.core.types import (
ComponentType,
EventHandlerDict,
Key,
LayoutEventMessage,
LayoutUpdateMessage,
+ StateUpdateMessage,
VdomChild,
VdomDict,
VdomJson,
@@ -62,34 +76,64 @@ class Layout:
"_render_tasks_ready",
"_root_life_cycle_state_id",
"_model_states_by_life_cycle_state_id",
+ "reconnecting",
+ "client_state",
+ "_state_recovery_serializer",
+ "_state_var_lock",
+ "_hook_state_token",
+ "_previous_states",
)
if not hasattr(abc.ABC, "__weakref__"): # nocov
__slots__ += ("__weakref__",)
- def __init__(self, root: ComponentType) -> None:
+ def __init__(
+ self,
+ root: ComponentType,
+ ) -> None:
super().__init__()
- if not isinstance(root, ComponentType):
- msg = f"Expected a ComponentType, not {type(root)!r}."
- raise TypeError(msg)
+ # slow
+ # if not isinstance(root, ComponentType):
+ # msg = f"Expected a ComponentType, not {type(root)!r}."
+ # raise TypeError(msg)
self.root = root
+ self.reconnecting = Ref(False)
+ self._state_recovery_serializer = None
+ self.client_state = {}
+ self._previous_states = {}
+
+ def set_recovery_serializer(self, serializer: StateRecoverySerializer) -> None:
+ self._state_recovery_serializer = serializer
async def __aenter__(self) -> Layout:
+ return await self.start()
+
+ async def start(self) -> Layout:
+ self._hook_state_token = create_hook_state()
+
# create attributes here to avoid access before entering context manager
self._event_handlers: EventHandlerDict = {}
self._render_tasks: set[Task[LayoutUpdateMessage]] = set()
self._render_tasks_ready: Semaphore = Semaphore(0)
self._rendering_queue: _ThreadSafeQueue[_LifeCycleStateId] = _ThreadSafeQueue()
- root_model_state = _new_root_model_state(self.root, self._schedule_render_task)
+ root_model_state = _new_root_model_state(
+ self.root,
+ self._schedule_render_task,
+ self.reconnecting,
+ self.client_state,
+ self._previous_states,
+ )
self._root_life_cycle_state_id = root_id = root_model_state.life_cycle_state.id
self._model_states_by_life_cycle_state_id = {root_id: root_model_state}
- self._schedule_render_task(root_id)
return self
async def __aexit__(self, *exc: Any) -> None:
+ return await self.finish()
+
+ async def finish(self) -> None:
root_csid = self._root_life_cycle_state_id
root_model_state = self._model_states_by_life_cycle_state_id[root_csid]
@@ -108,6 +152,14 @@ async def __aexit__(self, *exc: Any) -> None:
del self._root_life_cycle_state_id
del self._model_states_by_life_cycle_state_id
+ clear_hook_state(self._hook_state_token)
+
+ def start_rendering(self) -> None:
+ self._schedule_render_task(self._root_life_cycle_state_id)
+
+ def start_rendering_for_reconnect(self) -> None:
+ self._rendering_queue.put(self._root_life_cycle_state_id)
+
async def deliver(self, event: LayoutEventMessage) -> None:
"""Dispatch an event to the targeted handler"""
# It is possible for an element in the frontend to produce an event
@@ -122,7 +174,7 @@ async def deliver(self, event: LayoutEventMessage) -> None:
except Exception:
logger.exception(f"Failed to execute event handler {handler}")
else:
- logger.info(
+ logger.warning(
f"Ignored event - handler {event['target']!r} "
"does not exist or its component unmounted"
)
@@ -133,6 +185,32 @@ async def render(self) -> LayoutUpdateMessage:
else: # nocov
return await self._serial_render()
+ async def render_until_queue_empty(self) -> None:
+ model_state_id = await self._rendering_queue.get()
+ while True:
+ try:
+ model_state = self._model_states_by_life_cycle_state_id[model_state_id]
+ except KeyError:
+ logger.debug(
+ "Did not render component with model state ID "
+ f"{model_state_id!r} - component already unmounted"
+ )
+ else:
+ await self._create_layout_update(model_state, get_hook_state())
+ # this might seem counterintuitive. What's happening is that events can get kicked off
+ # and currently there's no (obvious) visibility on if we're waiting for them to finish
+ # so this will wait up to 0.15 * 5 = 750 ms to see if any renders come in before
+ # declaring it done. In the future, it would be better to just track the pending events
+ for _ in range(5):
+ try:
+ model_state_id = await self._rendering_queue.get_nowait()
+ except asyncio.QueueEmpty:
+ await asyncio.sleep(0.15) # make sure
+ else:
+ break
+ else:
+ return
+
async def _serial_render(self) -> LayoutUpdateMessage: # nocov
"""Await the next available render. This will block until a component is updated"""
while True:
@@ -145,7 +223,7 @@ async def _serial_render(self) -> LayoutUpdateMessage: # nocov
f"{model_state_id!r} - component already unmounted"
)
else:
- return await self._create_layout_update(model_state)
+ return await self._create_layout_update(model_state, get_hook_state())
async def _concurrent_render(self) -> LayoutUpdateMessage:
"""Await the next available render. This will block until a component is updated"""
@@ -156,8 +234,9 @@ async def _concurrent_render(self) -> LayoutUpdateMessage:
return update_task.result()
async def _create_layout_update(
- self, old_state: _ModelState
+ self, old_state: _ModelState, incoming_hook_state: list
) -> LayoutUpdateMessage:
+ token = create_hook_state(copy.copy(incoming_hook_state))
new_state = _copy_component_model_state(old_state)
component = new_state.life_cycle_state.component
@@ -167,11 +246,21 @@ async def _create_layout_update(
if REACTPY_CHECK_VDOM_SPEC.current:
validate_vdom_json(new_state.model.current)
- return {
- "type": "layout-update",
- "path": new_state.patch_path,
- "model": new_state.model.current,
- }
+ updated_states = new_state.life_cycle_state.hook._updated_states
+ state_vars = (
+ (await self._state_recovery_serializer.serialize_state_vars(updated_states))
+ if self._state_recovery_serializer
+ else {}
+ )
+ self._previous_states.update(updated_states)
+ updated_states.clear()
+ clear_hook_state(token)
+ return LayoutUpdateMessage(
+ type="layout-update",
+ path=new_state.patch_path,
+ model=new_state.model.current,
+ state_vars=state_vars,
+ )
async def _render_component(
self,
@@ -188,6 +277,7 @@ async def _render_component(
await life_cycle_hook.affect_component_will_render(component)
exit_stack.push_async_callback(life_cycle_hook.affect_layout_did_render)
try:
+ patch_path_for_state = new_state.patch_path # type: ignore # noqa
raw_model = component.render()
# wrap the model in a fragment (i.e. tagName="") to ensure components have
# a separate node in the model state tree. This could be removed if this
@@ -278,7 +368,11 @@ def _render_model_attributes(
if event in old_state.targets_by_event:
target = old_state.targets_by_event[event]
else:
- target = uuid4().hex if handler.target is None else handler.target
+ target = (
+ new_state.patch_path + event
+ if handler.target is None
+ else handler.target
+ )
new_state.targets_by_event[event] = target
self._event_handlers[target] = handler
model_event_handlers[event] = {
@@ -299,7 +393,11 @@ def _render_model_event_handlers_without_old_state(
model_event_handlers = new_state.model.current["eventHandlers"] = {}
for event, handler in handlers_by_event.items():
- target = uuid4().hex if handler.target is None else handler.target
+ target = (
+ new_state.patch_path + event
+ if handler.target is None
+ else handler.target
+ )
new_state.targets_by_event[event] = target
self._event_handlers[target] = handler
model_event_handlers[event] = {
@@ -385,6 +483,8 @@ async def _render_model_children(
key,
child,
self._schedule_render_task,
+ self.reconnecting,
+ self.client_state,
)
elif old_child_state.is_component_state and (
old_child_state.life_cycle_state.component.type != child.type
@@ -397,6 +497,8 @@ async def _render_model_children(
key,
child,
self._schedule_render_task,
+ self.reconnecting,
+ self.client_state,
)
else:
new_child_state = _update_component_model_state(
@@ -405,6 +507,8 @@ async def _render_model_children(
index,
child,
self._schedule_render_task,
+ self.reconnecting,
+ self.client_state,
)
await self._render_component(
exit_stack, old_child_state, new_child_state, child
@@ -439,7 +543,13 @@ async def _render_model_children_without_old_state(
new_state.children_by_key[key] = child_state
elif child_type is _COMPONENT_TYPE:
child_state = _make_component_model_state(
- new_state, index, key, child, self._schedule_render_task
+ new_state,
+ index,
+ key,
+ child,
+ self._schedule_render_task,
+ self.reconnecting,
+ self.client_state,
)
await self._render_component(exit_stack, None, child_state, child)
else:
@@ -455,14 +565,19 @@ async def _unmount_model_states(self, old_states: list[_ModelState]) -> None:
if model_state.is_component_state:
life_cycle_state = model_state.life_cycle_state
- del self._model_states_by_life_cycle_state_id[life_cycle_state.id]
- await life_cycle_state.hook.affect_component_will_unmount()
+ try:
+ del self._model_states_by_life_cycle_state_id[life_cycle_state.id]
+ await life_cycle_state.hook.affect_component_will_unmount()
+ except KeyError:
+ pass # sideeffect of reusing model states
to_unmount.extend(model_state.children_by_key.values())
- def _schedule_render_task(self, lcs_id: _LifeCycleStateId) -> None:
+ def _schedule_render_task(
+ self, lcs_id: _LifeCycleStateId, priority: int = 0
+ ) -> None:
if not REACTPY_ASYNC_RENDERING.current:
- self._rendering_queue.put(lcs_id)
+ self._rendering_queue.put(lcs_id, priority)
return None
try:
model_state = self._model_states_by_life_cycle_state_id[lcs_id]
@@ -480,7 +595,11 @@ def __repr__(self) -> str:
def _new_root_model_state(
- component: ComponentType, schedule_render: Callable[[_LifeCycleStateId], None]
+ component: ComponentType,
+ schedule_render: Callable[[_LifeCycleStateId], None],
+ reconnecting: bool,
+ client_state: dict[str, Any],
+ previous_states: dict[str, Any],
) -> _ModelState:
return _ModelState(
parent=None,
@@ -490,7 +609,9 @@ def _new_root_model_state(
patch_path="",
children_by_key={},
targets_by_event={},
- life_cycle_state=_make_life_cycle_state(component, schedule_render),
+ life_cycle_state=_make_life_cycle_state(
+ component, schedule_render, reconnecting, client_state, {}, previous_states
+ ),
)
@@ -499,8 +620,11 @@ def _make_component_model_state(
index: int,
key: Any,
component: ComponentType,
- schedule_render: Callable[[_LifeCycleStateId], None],
+ schedule_render: Callable[[_LifeCycleStateId, int], None],
+ reconnecting: bool,
+ client_state: dict[str, Any],
) -> _ModelState:
+ hook = (parent.life_cycle_state or parent.parent_life_cycle_state).hook
return _ModelState(
parent=parent,
index=index,
@@ -509,7 +633,14 @@ def _make_component_model_state(
patch_path=f"{parent.patch_path}/children/{index}",
children_by_key={},
targets_by_event={},
- life_cycle_state=_make_life_cycle_state(component, schedule_render),
+ life_cycle_state=_make_life_cycle_state(
+ component,
+ schedule_render,
+ reconnecting,
+ client_state,
+ hook._updated_states,
+ hook._previous_states,
+ ),
)
@@ -537,8 +668,11 @@ def _update_component_model_state(
new_parent: _ModelState,
new_index: int,
new_component: ComponentType,
- schedule_render: Callable[[_LifeCycleStateId], None],
+ schedule_render: Callable[[_LifeCycleStateId, int], None],
+ reconnecting: bool,
+ client_state: dict[str, Any],
) -> _ModelState:
+ hook = (new_parent.life_cycle_state or new_parent.parent_life_cycle_state).hook
return _ModelState(
parent=new_parent,
index=new_index,
@@ -550,7 +684,14 @@ def _update_component_model_state(
life_cycle_state=(
_update_life_cycle_state(old_model_state.life_cycle_state, new_component)
if old_model_state.is_component_state
- else _make_life_cycle_state(new_component, schedule_render)
+ else _make_life_cycle_state(
+ new_component,
+ schedule_render,
+ reconnecting,
+ client_state,
+ hook._updated_states,
+ hook._previous_states,
+ )
),
)
@@ -568,6 +709,9 @@ def _make_element_model_state(
patch_path=f"{parent.patch_path}/children/{index}",
children_by_key={},
targets_by_event={},
+ parent_life_cycle_state=(
+ parent.life_cycle_state or parent.parent_life_cycle_state
+ ),
)
@@ -584,6 +728,9 @@ def _update_element_model_state(
patch_path=old_model_state.patch_path,
children_by_key={},
targets_by_event={},
+ parent_life_cycle_state=(
+ new_parent.life_cycle_state or new_parent.parent_life_cycle_state
+ ),
)
@@ -598,6 +745,7 @@ class _ModelState:
"index",
"key",
"life_cycle_state",
+ "parent_life_cycle_state",
"model",
"patch_path",
"targets_by_event",
@@ -613,6 +761,7 @@ def __init__(
children_by_key: dict[Key, _ModelState],
targets_by_event: dict[str, str],
life_cycle_state: _LifeCycleState | None = None,
+ parent_life_cycle_state: _LifeCycleState | None = None,
):
self.index = index
"""The index of the element amongst its siblings"""
@@ -639,13 +788,14 @@ def __init__(
self._parent_ref = weakref(parent)
"""The parent model state"""
- if life_cycle_state is not None:
- self.life_cycle_state = life_cycle_state
- """The state for the element's component (if it has one)"""
+ self.life_cycle_state = life_cycle_state
+ """The state for the element's component (if it has one)"""
+
+ self.parent_life_cycle_state = parent_life_cycle_state
@property
def is_component_state(self) -> bool:
- return hasattr(self, "life_cycle_state")
+ return self.life_cycle_state is not None
@property
def parent(self) -> _ModelState:
@@ -663,12 +813,22 @@ def __repr__(self) -> str: # nocov
def _make_life_cycle_state(
component: ComponentType,
- schedule_render: Callable[[_LifeCycleStateId], None],
+ schedule_render: Callable[[_LifeCycleStateId, int], None],
+ reconnecting: bool,
+ client_state: dict[str, Any],
+ updated_states: dict[str, Any],
+ previous_states: dict[str, Any],
) -> _LifeCycleState:
life_cycle_state_id = _LifeCycleStateId(uuid4().hex)
return _LifeCycleState(
life_cycle_state_id,
- LifeCycleHook(lambda: schedule_render(life_cycle_state_id)),
+ LifeCycleHook(
+ lambda: schedule_render(life_cycle_state_id, component.priority),
+ reconnecting=reconnecting,
+ client_state=client_state,
+ updated_states=updated_states,
+ previous_states=previous_states,
+ ),
component,
)
@@ -707,16 +867,21 @@ class _LifeCycleState(NamedTuple):
class _ThreadSafeQueue(Generic[_Type]):
def __init__(self) -> None:
self._loop = get_running_loop()
- self._queue: Queue[_Type] = Queue()
+ self._queue: PriorityQueue[_Type] = PriorityQueue()
self._pending: set[_Type] = set()
- def put(self, value: _Type) -> None:
+ def put(self, value: _Type, priority: int = 0) -> None:
if value not in self._pending:
self._pending.add(value)
- self._loop.call_soon_threadsafe(self._queue.put_nowait, value)
+ self._loop.call_soon_threadsafe(self._queue.put_nowait, (priority, value))
async def get(self) -> _Type:
- value = await self._queue.get()
+ priority, value = await self._queue.get()
+ self._pending.remove(value)
+ return value
+
+ async def get_nowait(self) -> _Type:
+ priority, value = self._queue.get_nowait()
self._pending.remove(value)
return value
@@ -729,7 +894,7 @@ def _get_children_info(children: list[VdomChild]) -> Sequence[_ChildInfo]:
elif isinstance(child, dict):
child_type = _DICT_TYPE
key = child.get("key")
- elif isinstance(child, ComponentType):
+ elif isinstance(child, (Component, _ContextProvider)):
child_type = _COMPONENT_TYPE
key = child.key
else:
diff --git a/src/py/reactpy/reactpy/core/serve.py b/src/py/reactpy/reactpy/core/serve.py
index 3a540af59..b7832b565 100644
--- a/src/py/reactpy/reactpy/core/serve.py
+++ b/src/py/reactpy/reactpy/core/serve.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import random
+import string
from collections.abc import Awaitable
from logging import getLogger
from typing import Callable
@@ -8,16 +10,39 @@
from anyio import create_task_group
from anyio.abc import TaskGroup
+from reactpy.backend.hooks import ConnectionContext
+from reactpy.backend.types import Connection
from reactpy.config import REACTPY_DEBUG_MODE
-from reactpy.core.types import LayoutEventMessage, LayoutType, LayoutUpdateMessage
+from reactpy.core._life_cycle_hook import clear_hook_state, create_hook_state
+from reactpy.core.layout import Layout
+from reactpy.core.state_recovery import StateRecoveryFailureError, StateRecoveryManager
+from reactpy.core.types import (
+ ClientStateMessage,
+ IsReadyMessage,
+ LayoutEventMessage,
+ LayoutType,
+ LayoutUpdateMessage,
+ ReconnectingCheckMessage,
+ RootComponentConstructor,
+)
logger = getLogger(__name__)
-SendCoroutine = Callable[[LayoutUpdateMessage], Awaitable[None]]
+SendCoroutine = Callable[
+ [
+ LayoutUpdateMessage
+ | ReconnectingCheckMessage
+ | IsReadyMessage
+ | ClientStateMessage
+ ],
+ Awaitable[None],
+]
"""Send model patches given by a dispatcher"""
-RecvCoroutine = Callable[[], Awaitable[LayoutEventMessage]]
+RecvCoroutine = Callable[
+ [], Awaitable[LayoutEventMessage | ReconnectingCheckMessage | ClientStateMessage]
+]
"""Called by a dispatcher to return a :class:`reactpy.core.layout.LayoutEventMessage`
The event will then trigger an :class:`reactpy.core.proto.EventHandlerType` in a layout.
@@ -40,44 +65,146 @@ async def serve_layout(
recv: RecvCoroutine,
) -> None:
"""Run a dispatch loop for a single view instance"""
- async with layout:
- try:
- async with create_task_group() as task_group:
- task_group.start_soon(_single_outgoing_loop, layout, send)
- task_group.start_soon(_single_incoming_loop, task_group, layout, recv)
- except Stop: # nocov
- warn(
- "The Stop exception is deprecated and will be removed in a future version",
- UserWarning,
- stacklevel=1,
- )
- logger.info(f"Stopped serving {layout}")
+
+ try:
+ async with create_task_group() as task_group:
+ task_group.start_soon(_single_outgoing_loop, layout, send)
+ task_group.start_soon(_single_incoming_loop, task_group, layout, recv, send)
+ except Stop: # nocov
+ warn(
+ "The Stop exception is deprecated and will be removed in a future version",
+ UserWarning,
+ stacklevel=1,
+ )
+ logger.info(f"Stopped serving {layout}")
async def _single_outgoing_loop(
layout: LayoutType[LayoutUpdateMessage, LayoutEventMessage], send: SendCoroutine
) -> None:
while True:
- update = await layout.render()
+ token = create_hook_state()
try:
- await send(update)
- except Exception: # nocov
- if not REACTPY_DEBUG_MODE.current:
- msg = (
- "Failed to send update. More info may be available "
- "if you enabling debug mode by setting "
- "`reactpy.config.REACTPY_DEBUG_MODE.current = True`."
- )
- logger.error(msg)
- raise
+ update = await layout.render()
+ try:
+ await send(update)
+ except Exception: # nocov
+ if not REACTPY_DEBUG_MODE.current:
+ msg = (
+ "Failed to send update. More info may be available "
+ "if you enabling debug mode by setting "
+ "`reactpy.config.REACTPY_DEBUG_MODE.current = True`."
+ )
+ logger.error(msg)
+ raise
+ finally:
+ clear_hook_state(token)
async def _single_incoming_loop(
task_group: TaskGroup,
layout: LayoutType[LayoutUpdateMessage, LayoutEventMessage],
recv: RecvCoroutine,
+ send: SendCoroutine,
) -> None:
while True:
# We need to fire and forget here so that we avoid waiting on the completion
# of this event handler before receiving and running the next one.
task_group.start_soon(layout.deliver, await recv())
+
+
+class WebsocketServer:
+ def __init__(
+ self,
+ send: SendCoroutine,
+ recv: RecvCoroutine,
+ state_recovery_manager: StateRecoveryManager | None = None,
+ ) -> None:
+ self._send = send
+ self._recv = recv
+ self._state_recovery_manager = state_recovery_manager
+ self._salt: str | None = None
+
+ async def handle_connection(
+ self, connection: Connection, constructor: RootComponentConstructor
+ ):
+ layout = Layout(
+ ConnectionContext(
+ constructor(),
+ value=connection,
+ ),
+ )
+ async with layout:
+ await self._handshake(layout)
+ # salt may be set to client's old salt during handshake
+ if self._state_recovery_manager:
+ layout.set_recovery_serializer(
+ self._state_recovery_manager.create_serializer(self._salt)
+ )
+ await serve_layout(
+ layout,
+ self._send,
+ self._recv,
+ )
+
+ async def _handshake(self, layout: Layout) -> None:
+ await self._send(ReconnectingCheckMessage(type="reconnecting-check"))
+ result = await self._recv()
+ self._salt = "".join(random.choices(string.ascii_letters + string.digits, k=8))
+ if result["type"] == "reconnecting-check":
+ if result["value"] == "yes":
+ if self._state_recovery_manager is None:
+ logger.warning(
+ "Reconnection detected, but no state recovery manager provided"
+ )
+ layout.start_rendering()
+ else:
+ logger.info("Handshake: Doing state rebuild for reconnection")
+ self._salt = await self._do_state_rebuild_for_reconnection(layout)
+ logger.info("Handshake: Completed doing state rebuild")
+ else:
+ logger.info("Handshake: new connection")
+ layout.start_rendering()
+ else:
+ logger.warning(
+ f"Unexpected type when expecting reconnecting-check: {result['type']}"
+ )
+ await self._indicate_ready(),
+
+ async def _indicate_ready(self) -> None:
+ await self._send(IsReadyMessage(type="is-ready", salt=self._salt))
+
+ async def _do_state_rebuild_for_reconnection(self, layout: Layout) -> str:
+ salt = self._salt
+ await self._send(ClientStateMessage(type="client-state"))
+ client_state_msg = await self._recv()
+ if client_state_msg["type"] != "client-state":
+ logger.warning(
+ f"Unexpected type when expecting client-state: {client_state_msg['type']}"
+ )
+ return
+ state_vars = client_state_msg["value"]
+ try:
+ serializer = self._state_recovery_manager.create_serializer(
+ client_state_msg["salt"]
+ )
+ client_state = serializer.deserialize_client_state(state_vars)
+ layout.reconnecting.set_current(True)
+ layout.client_state = client_state
+ except StateRecoveryFailureError:
+ logger.exception("State recovery failed")
+ layout.reconnecting.set_current(False)
+ layout.client_state = {}
+ else:
+ salt = client_state_msg["salt"]
+ try:
+ layout.start_rendering_for_reconnect()
+ await layout.render_until_queue_empty()
+ except StateRecoveryFailureError:
+ logger.warning("Client state non-recoverable. Starting fresh")
+ await layout.finish()
+ await layout.start()
+ layout.start_rendering()
+ layout.reconnecting.set_current(False)
+ layout.client_state = {}
+ return salt
diff --git a/src/py/reactpy/reactpy/core/state_recovery.py b/src/py/reactpy/reactpy/core/state_recovery.py
new file mode 100644
index 000000000..38be33786
--- /dev/null
+++ b/src/py/reactpy/reactpy/core/state_recovery.py
@@ -0,0 +1,276 @@
+import asyncio
+import base64
+import datetime
+import hashlib
+import time
+from collections.abc import Iterable
+from decimal import Decimal
+from logging import getLogger
+from pathlib import Path
+from typing import Any, Callable
+from uuid import UUID
+
+import orjson
+import pyotp
+from more_itertools import chunked
+
+logger = getLogger(__name__)
+
+
+class StateRecoveryFailureError(Exception):
+ """
+ Raised when state recovery fails.
+ """
+
+
+class StateRecoveryManager:
+ def __init__(
+ self,
+ serializable_types: Iterable[type],
+ pepper: str,
+ otp_key: str | None = None,
+ otp_interval: int = (4 * 60 * 60),
+ otp_digits: int = 10, # 10 is the max allowed
+ otp_max_age: int = (48 * 60 * 60),
+ # OTP code is actually three codes, in the past and future concatenated
+ otp_mixer: float = (365 * 24 * 60 * 60 * 3),
+ max_num_state_objects: int = 512,
+ max_object_length: int = 40000,
+ default_serializer: Callable[[Any], bytes] | None = None,
+ deserializer_map: dict[type, Callable[[Any], Any]] | None = None,
+ ) -> None:
+ self._pepper = pepper
+ self._max_num_state_objects = max_num_state_objects
+ self._max_object_length = max_object_length
+ self._otp_key = base64.b32encode(
+ (otp_key or self._discover_otp_key()).encode("utf-8")
+ )
+ self._totp = pyotp.TOTP(self._otp_key, digits=otp_digits, interval=otp_interval)
+ self._otp_max_age = otp_max_age
+ self._default_serializer = default_serializer
+ self._deserializer_map = deserializer_map or {}
+ self._otp_mixer = otp_mixer
+
+ self._map_objects_to_ids(
+ [
+ *list(serializable_types),
+ Decimal,
+ datetime.datetime,
+ datetime.date,
+ datetime.time,
+ ]
+ )
+
+ def _map_objects_to_ids(self, serializable_types: Iterable[type]) -> dict:
+ self._object_to_type_id = {}
+ self._type_id_to_object = {}
+ for idx, typ in enumerate(
+ (None, bool, str, int, float, list, tuple, UUID, *serializable_types)
+ ):
+ idx_as_bytes = str(idx).encode("utf-8")
+ self._object_to_type_id[typ] = idx_as_bytes
+ self._type_id_to_object[idx_as_bytes] = typ
+
+ def _discover_otp_key(self) -> str:
+ """
+ Generate an OTP key by looking at the parent directory of where
+ ReactPy is installed and taking down the names and creation times
+ of everything in there.
+ """
+ hasher = hashlib.sha256()
+ parent_dir_of_root = Path(__file__).parent.parent.parent
+ for thing in parent_dir_of_root.iterdir():
+ hasher.update((thing.name + str(thing.stat().st_ctime)).encode("utf-8"))
+ return hasher.hexdigest()
+
+ def create_serializer(
+ self, salt: str, target_time: float | None = None
+ ) -> "StateRecoverySerializer":
+ return StateRecoverySerializer(
+ totp=self._totp,
+ target_time=target_time,
+ otp_max_age=self._otp_max_age,
+ otp_mixer=self._otp_mixer,
+ pepper=self._pepper,
+ salt=salt,
+ object_to_type_id=self._object_to_type_id,
+ type_id_to_object=self._type_id_to_object,
+ max_object_length=self._max_object_length,
+ max_num_state_objects=self._max_num_state_objects,
+ default_serializer=self._default_serializer,
+ deserializer_map=self._deserializer_map,
+ )
+
+
+class StateRecoverySerializer:
+
+ def __init__(
+ self,
+ totp: pyotp.TOTP,
+ target_time: float | None,
+ otp_max_age: int,
+ otp_mixer: float,
+ pepper: str,
+ salt: str,
+ object_to_type_id: dict[Any, bytes],
+ type_id_to_object: dict[bytes, Any],
+ max_object_length: int,
+ max_num_state_objects: int,
+ default_serializer: Callable[[Any], bytes] | None = None,
+ deserializer_map: dict[type, Callable[[Any], Any]] | None = None,
+ ) -> None:
+ self._totp = totp
+ self._otp_mixer = otp_mixer
+ target_time = target_time or time.time()
+ self._target_time = target_time
+ otp_code = self._get_otp_code(target_time)
+ self._otp_max_age = otp_max_age
+ self._otp_code = otp_code.encode("utf-8")
+ self._pepper = pepper.encode("utf-8")
+ self._salt = salt.encode("utf-8")
+ self._object_to_type_id = object_to_type_id
+ self._type_id_to_object = type_id_to_object
+ self._max_object_length = max_object_length
+ self._max_num_state_objects = max_num_state_objects
+ self._default_serializer = default_serializer
+ self._deserializer_map = deserializer_map or {}
+
+ def _get_otp_code(self, target_time: float) -> str:
+ at = self._totp.at
+ return f"{at(target_time)}{at(target_time - self._otp_mixer)}{at(target_time + self._otp_mixer)}"
+
+ async def serialize_state_vars(
+ self, state_vars: dict[str, Any]
+ ) -> dict[str, tuple[str, str, str]]:
+ if len(state_vars) > self._max_num_state_objects:
+ logger.warning(
+ f"State is too large ({len(state_vars)}). State will not be sent"
+ )
+ return {}
+ result = {}
+ for chunk in chunked(state_vars.items(), 50):
+ for key, value in chunk:
+ result[key] = self._serialize(key, value)
+ await asyncio.sleep(0) # relinquish CPU
+ return result
+
+ def _serialize(self, key: str, obj: object) -> tuple[str, str, str]:
+ type_id = b"1" # bool
+ if obj is None:
+ return "0", "", ""
+ match obj:
+ case True:
+ result = b"true"
+ case False:
+ result = b"false"
+ case _:
+ obj_type = type(obj)
+ if obj_type in (list, tuple):
+ if len(obj) != 0:
+ obj_type = type(obj[0])
+ for t in obj_type.__mro__:
+ type_id = self._object_to_type_id.get(t)
+ if type_id:
+ break
+ else:
+ raise ValueError(
+ f"Objects of type {obj_type} was not part of serializable_types"
+ )
+ result = self._serialize_object(obj)
+ if len(result) > self._max_object_length:
+ raise ValueError(
+ f"Serialized object {obj} is too long (length: {len(result)})"
+ )
+ signature = self._sign_serialization(key, type_id, result)
+ return (
+ type_id.decode("utf-8"),
+ base64.urlsafe_b64encode(result).decode("utf-8"),
+ signature,
+ )
+
+ def deserialize_client_state(
+ self, state_vars: dict[str, tuple[str, str, str]]
+ ) -> None:
+ return {
+ key: self._deserialize(key, type_id.encode("utf-8"), data, signature)
+ for key, (type_id, data, signature) in state_vars.items()
+ }
+
+ def _deserialize(
+ self, key: str, type_id: bytes, data: bytes, signature: str
+ ) -> Any:
+ if type_id == b"0":
+ return None
+ try:
+ typ = self._type_id_to_object[type_id]
+ except KeyError as err:
+ raise StateRecoveryFailureError(f"Unknown type id {type_id}") from err
+
+ result = base64.urlsafe_b64decode(data)
+ expected_signature = self._sign_serialization(key, type_id, result)
+ if expected_signature != signature:
+ if not self._try_future_code(key, type_id, result, signature):
+ if not self._try_older_codes_and_see_if_one_checks_out(
+ key, type_id, result, signature
+ ):
+ raise StateRecoveryFailureError(
+ f"Signature mismatch for type id {type_id}"
+ )
+ return self._deserialize_object(typ, result)
+
+ def _try_future_code(
+ self, key: str, type_id: bytes, data: bytes, signature: str
+ ) -> bool:
+ future_time = self._target_time + self._totp.interval
+ otp_code = self._get_otp_code(future_time).encode("utf-8")
+ return self._sign_serialization(key, type_id, data, otp_code) == signature
+
+ def _try_older_codes_and_see_if_one_checks_out(
+ self, key: str, type_id: bytes, data: bytes, signature: str
+ ) -> bool:
+ past_time = self._target_time
+ for _ in range(100):
+ past_time -= self._totp.interval
+ otp_code = self._get_otp_code(past_time).encode("utf-8")
+ if self._sign_serialization(key, type_id, data, otp_code) == signature:
+ return True
+ if past_time < self._target_time - self._otp_max_age:
+ return False
+ raise RuntimeError("Too many iterations: _try_older_codes_and_see_if_one_checks_out")
+
+ def _sign_serialization(
+ self, key: str, type_id: bytes, data: bytes, otp_code: bytes | None = None
+ ) -> str:
+ hasher = hashlib.sha256()
+ hasher.update(type_id)
+ hasher.update(data)
+ hasher.update(self._pepper)
+ hasher.update(otp_code or self._otp_code)
+ hasher.update(self._salt)
+ hasher.update(key.encode("utf-8"))
+ return hasher.hexdigest()
+
+ def _serialize_object(self, obj: Any) -> bytes:
+ return orjson.dumps(obj, default=self._default_serializer)
+
+ def _do_deserialize(
+ self, typ: type, result: Any, custom_deserializer: Callable | None
+ ) -> Any:
+ if custom_deserializer:
+ return custom_deserializer(result)
+ if isinstance(result, str):
+ return typ(result)
+ if isinstance(result, dict):
+ return typ(**result)
+ return result
+
+ def _deserialize_object(self, typ: Any, data: bytes) -> Any:
+ if typ is None and not data:
+ return None
+ result = orjson.loads(data)
+ custom_deserializer = self._deserializer_map.get(typ)
+ if type(result) in (list, tuple):
+ return [
+ self._do_deserialize(typ, item, custom_deserializer) for item in result
+ ]
+ return self._do_deserialize(typ, result, custom_deserializer)
diff --git a/src/py/reactpy/reactpy/core/types.py b/src/py/reactpy/reactpy/core/types.py
index b451be30a..a4be74f61 100644
--- a/src/py/reactpy/reactpy/core/types.py
+++ b/src/py/reactpy/reactpy/core/types.py
@@ -57,6 +57,7 @@ class ComponentType(Protocol):
This is used to see if two component instances share the same definition.
"""
+ priority: int
def render(self) -> VdomDict | ComponentType | str | None:
"""Render the component's view model."""
@@ -213,6 +214,42 @@ class LayoutUpdateMessage(TypedDict):
"""JSON Pointer path to the model element being updated"""
model: VdomJson
"""The model to assign at the given JSON Pointer path"""
+ state_vars: dict[str, Any]
+
+
+class StateUpdateMessage(TypedDict):
+ """A message describing an update to state variables"""
+
+ type: Literal["state-update"]
+ """The type of message"""
+ state_vars: dict[str, Any]
+
+
+class ReconnectingCheckMessage(TypedDict):
+ """A message describing an update to a layout"""
+
+ type: Literal["reconnecting-check"]
+ """The type of message"""
+ value: Literal["yes", "no"]
+
+
+class ClientStateMessage(TypedDict):
+ """A message requesting the current state of the client"""
+
+ type: Literal["client-state"]
+ """The type of message"""
+ value: dict[str, Any]
+ """The client state"""
+ salt: str
+ """The salt provided to the user"""
+
+
+class IsReadyMessage(TypedDict):
+ """Indicate server is ready for client events"""
+
+ type: Literal["is-ready"]
+
+ salt: str
class LayoutEventMessage(TypedDict):
diff --git a/src/py/reactpy/reactpy/testing/common.py b/src/py/reactpy/reactpy/testing/common.py
index c1eb18ba5..84f3243ae 100644
--- a/src/py/reactpy/reactpy/testing/common.py
+++ b/src/py/reactpy/reactpy/testing/common.py
@@ -13,7 +13,7 @@
from typing_extensions import ParamSpec
from reactpy.config import REACTPY_TESTING_DEFAULT_TIMEOUT, REACTPY_WEB_MODULES_DIR
-from reactpy.core._life_cycle_hook import LifeCycleHook, current_hook
+from reactpy.core._life_cycle_hook import LifeCycleHook, get_current_hook
from reactpy.core.events import EventHandler, to_event_handler_function
@@ -143,7 +143,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
if self is None:
raise RuntimeError("Hook catcher has been garbage collected")
- hook = current_hook()
+ hook = get_current_hook()
if self.index_by_kwarg is not None:
self.index[kwargs[self.index_by_kwarg]] = hook
self.latest = hook
diff --git a/src/py/reactpy/reactpy/utils.py b/src/py/reactpy/reactpy/utils.py
index 5624846a4..06599c445 100644
--- a/src/py/reactpy/reactpy/utils.py
+++ b/src/py/reactpy/reactpy/utils.py
@@ -27,12 +27,26 @@ class Ref(Generic[_RefValue]):
You can compare the contents for two ``Ref`` objects using the ``==`` operator.
"""
- __slots__ = ("current",)
+ __slots__ = ("current", "key", "_hook")
+
+ def __init__(
+ self, initial_value: _RefValue = _UNDEFINED, key: str | None = None
+ ) -> None:
+ from reactpy.core._life_cycle_hook import get_current_hook
- def __init__(self, initial_value: _RefValue = _UNDEFINED) -> None:
if initial_value is not _UNDEFINED:
self.current = initial_value
"""The present value"""
+ self.key = key
+ self._hook = None
+ if key:
+ hook = get_current_hook()
+ hook.add_state_update(self)
+ self._hook = hook
+
+ @property
+ def value(self) -> _RefValue:
+ return self.current
def set_current(self, new: _RefValue) -> _RefValue:
"""Set the current value and return what is now the old value
@@ -41,6 +55,8 @@ def set_current(self, new: _RefValue) -> _RefValue:
"""
old = self.current
self.current = new
+ if self.key:
+ self._hook.add_state_update(self)
return old
def __eq__(self, other: Any) -> bool:
diff --git a/src/py/reactpy/tests/test_core/test_layout.py b/src/py/reactpy/tests/test_core/test_layout.py
index cfb544758..276e36e2d 100644
--- a/src/py/reactpy/tests/test_core/test_layout.py
+++ b/src/py/reactpy/tests/test_core/test_layout.py
@@ -343,7 +343,7 @@ async def test_root_component_life_cycle_hook_is_garbage_collected():
def add_to_live_hooks(constructor):
def wrapper(*args, **kwargs):
result = constructor(*args, **kwargs)
- hook = reactpy.hooks.current_hook()
+ hook = reactpy.hooks.get_current_hook()
hook_id = id(hook)
live_hooks.add(hook_id)
finalize(hook, live_hooks.discard, hook_id)
@@ -375,7 +375,7 @@ async def test_life_cycle_hooks_are_garbage_collected():
def add_to_live_hooks(constructor):
def wrapper(*args, **kwargs):
result = constructor(*args, **kwargs)
- hook = reactpy.hooks.current_hook()
+ hook = reactpy.hooks.get_current_hook()
hook_id = id(hook)
live_hooks.add(hook_id)
finalize(hook, live_hooks.discard, hook_id)
@@ -625,7 +625,7 @@ def Outer():
@reactpy.component
def Inner(finalizer_id):
if finalizer_id not in registered_finalizers:
- hook = reactpy.hooks.current_hook()
+ hook = reactpy.hooks.get_current_hook()
finalize(hook, lambda: garbage_collect_items.append(finalizer_id))
registered_finalizers.add(finalizer_id)
return reactpy.html.div(finalizer_id)
diff --git a/src/py/reactpy/tests/tooling/hooks.py b/src/py/reactpy/tests/tooling/hooks.py
index 1926a93bc..b60040495 100644
--- a/src/py/reactpy/tests/tooling/hooks.py
+++ b/src/py/reactpy/tests/tooling/hooks.py
@@ -1,8 +1,8 @@
-from reactpy.core.hooks import current_hook, use_state
+from reactpy.core.hooks import get_current_hook, use_state
def use_force_render():
- return current_hook().schedule_render
+ return get_current_hook().schedule_render
def use_toggle(init=False):
From a7374d0fd440bf948f2a5c93d18f4f140ffa09c9 Mon Sep 17 00:00:00 2001
From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com>
Date: Mon, 4 Mar 2024 20:17:50 -0800
Subject: [PATCH 03/11] Follow up work (#3)
* Delete commented out code
* add serialization for timezone and timedelta
* don't show reconnecting layer on first attempt
* Only connect after layout update handlers are set up
* perf tweak that apparently was never saved
* deserialization as well for timezone
* timezone and timedelta
* alter z-index
---
.../@reactpy/client/src/components.tsx | 3 +-
.../@reactpy/client/src/reactpy-client.ts | 30 ++++++++++++++++---
src/py/reactpy/reactpy/backend/sanic.py | 10 -------
src/py/reactpy/reactpy/core/state_recovery.py | 25 ++++++++++++++--
src/py/reactpy/reactpy/core/vdom.py | 2 +-
5 files changed, 50 insertions(+), 20 deletions(-)
diff --git a/src/js/packages/@reactpy/client/src/components.tsx b/src/js/packages/@reactpy/client/src/components.tsx
index fd23d3a8a..0f1c1722d 100644
--- a/src/js/packages/@reactpy/client/src/components.tsx
+++ b/src/js/packages/@reactpy/client/src/components.tsx
@@ -29,13 +29,12 @@ export function Layout(props: { client: ReactPyClient }): JSX.Element {
useEffect(
() =>
- props.client.onMessage("layout-update", ({ path, model, state_vars }) => {
+ props.client.onLayoutUpdate((path: string, model: any) => {
if (path === "") {
Object.assign(currentModel, model);
} else {
setJsonPointer(currentModel, path, model);
}
- props.client.updateStateVars(state_vars);
forceUpdate();
}),
[currentModel, props.client],
diff --git a/src/js/packages/@reactpy/client/src/reactpy-client.ts b/src/js/packages/@reactpy/client/src/reactpy-client.ts
index c5018e9a5..c69db96bc 100644
--- a/src/js/packages/@reactpy/client/src/reactpy-client.ts
+++ b/src/js/packages/@reactpy/client/src/reactpy-client.ts
@@ -16,6 +16,8 @@ export interface ReactPyClient {
*/
onMessage(type: string, handler: (message: any) => void): () => void;
+ onLayoutUpdate(handler: (path: string, model: any) => void): void;
+
/**
* Send a message to the server.
*
@@ -43,6 +45,7 @@ export abstract class BaseReactPyClient implements ReactPyClient {
private resolveReady: (value: undefined) => void;
protected stateVars: object;
protected debugMessages: boolean;
+ protected layoutUpdateHandlers: Array<(path: string, model: any) => void> = [];
constructor() {
this.resolveReady = () => { };
@@ -59,6 +62,10 @@ export abstract class BaseReactPyClient implements ReactPyClient {
};
}
+ onLayoutUpdate(handler: (path: string, model: any) => void): void {
+ this.layoutUpdateHandlers.push(handler);
+ }
+
abstract sendMessage(message: any): void;
abstract loadModule(moduleName: string): Promise;
@@ -146,7 +153,8 @@ enum messageTypes {
isReady = "is-ready",
reconnectingCheck = "reconnecting-check",
clientState = "client-state",
- stateUpdate = "state-update"
+ stateUpdate = "state-update",
+ layoutUpdate = "layout-update",
};
export class SimpleReactPyClient
@@ -198,6 +206,10 @@ export class SimpleReactPyClient
this.onMessage(messageTypes.isReady, (msg) => { this.isReady = true; this.salt = msg.salt; });
this.onMessage(messageTypes.clientState, () => { this.sendClientState() });
this.onMessage(messageTypes.stateUpdate, (msg) => { this.updateClientState(msg.state_vars) });
+ this.onMessage(messageTypes.layoutUpdate, (msg) => {
+ this.updateClientState(msg.state_vars);
+ this.invokeLayoutUpdateHandlers(msg.path, msg.model);
+ })
this.reconnect()
@@ -211,7 +223,13 @@ export class SimpleReactPyClient
window.addEventListener('scroll', reconnectOnUserAction);
}
- showReconnectingGrayout() {
+ protected invokeLayoutUpdateHandlers(path: string, model: any) {
+ this.layoutUpdateHandlers.forEach(func => {
+ func(path, model);
+ });
+ }
+
+ protected showReconnectingGrayout() {
const overlay = document.createElement('div');
overlay.id = 'reactpy-reconnect-overlay';
@@ -229,7 +247,7 @@ export class SimpleReactPyClient
display: flex;
justify-content: center;
align-items: center;
- z-index: 1000;
+ z-index: 100000;
`;
pipeContainer.style.cssText = `
@@ -328,6 +346,10 @@ export class SimpleReactPyClient
const maxInterval = this.reconnectOptions?.maxInterval || 20000;
const maxRetries = this.reconnectOptions?.maxRetries || 20;
+ if (this.layoutUpdateHandlers.length == 0) {
+ setTimeout(() => { this.reconnect(onOpen, interval, connectionAttemptsRemaining, lastAttempt); }, 10);
+ return
+ }
if (connectionAttemptsRemaining <= 0) {
logger.warn("Giving up on reconnecting (hit retry limit)");
@@ -344,7 +366,7 @@ export class SimpleReactPyClient
this.shouldReconnect = true;
window.setTimeout(() => {
- if (!this.didReconnectingCallback && this.reconnectingCallback) {
+ if (!this.didReconnectingCallback && this.reconnectingCallback && maxRetries != connectionAttemptsRemaining) {
this.didReconnectingCallback = true;
this.reconnectingCallback();
}
diff --git a/src/py/reactpy/reactpy/backend/sanic.py b/src/py/reactpy/reactpy/backend/sanic.py
index bad90b072..e648747fa 100644
--- a/src/py/reactpy/reactpy/backend/sanic.py
+++ b/src/py/reactpy/reactpy/backend/sanic.py
@@ -193,16 +193,6 @@ async def model_stream(
),
constructor,
)
- # await serve_layout(
- # Layout(
- # ConnectionContext(
- # constructor(),
- # value=,
- # )
- # ),
- # send,
- # recv,
- # )
api_blueprint.add_websocket_route(
model_stream,
diff --git a/src/py/reactpy/reactpy/core/state_recovery.py b/src/py/reactpy/reactpy/core/state_recovery.py
index 38be33786..6979de7ff 100644
--- a/src/py/reactpy/reactpy/core/state_recovery.py
+++ b/src/py/reactpy/reactpy/core/state_recovery.py
@@ -1,5 +1,6 @@
import asyncio
import base64
+from dataclasses import asdict, is_dataclass
import datetime
import hashlib
import time
@@ -58,6 +59,8 @@ def __init__(
datetime.datetime,
datetime.date,
datetime.time,
+ datetime.timezone,
+ datetime.timedelta,
]
)
@@ -65,7 +68,7 @@ def _map_objects_to_ids(self, serializable_types: Iterable[type]) -> dict:
self._object_to_type_id = {}
self._type_id_to_object = {}
for idx, typ in enumerate(
- (None, bool, str, int, float, list, tuple, UUID, *serializable_types)
+ (None, bool, str, int, float, list, tuple, UUID, datetime.timezone, datetime.timedelta, *serializable_types)
):
idx_as_bytes = str(idx).encode("utf-8")
self._object_to_type_id[typ] = idx_as_bytes
@@ -132,8 +135,13 @@ def __init__(
self._type_id_to_object = type_id_to_object
self._max_object_length = max_object_length
self._max_num_state_objects = max_num_state_objects
- self._default_serializer = default_serializer
- self._deserializer_map = deserializer_map or {}
+ self._provided_default_serializer = default_serializer
+ deserialization_map = {
+ datetime.timezone: lambda x: datetime.timezone(
+ datetime.timedelta(**x["offset"]), x["name"]
+ ),
+ }
+ self._deserializer_map = deserialization_map | (deserializer_map or {})
def _get_otp_code(self, target_time: float) -> str:
at = self._totp.at
@@ -253,6 +261,17 @@ def _sign_serialization(
def _serialize_object(self, obj: Any) -> bytes:
return orjson.dumps(obj, default=self._default_serializer)
+ def _default_serializer(self, obj: Any) -> bytes:
+ if isinstance(obj, datetime.timezone):
+ return {"name": obj.tzname(None), "offset": obj.utcoffset(None)}
+ if isinstance(obj, datetime.timedelta):
+ return {"days": obj.days, "seconds": obj.seconds, "microseconds": obj.microseconds}
+ if is_dataclass(obj):
+ return asdict(obj)
+ if self._provided_default_serializer:
+ return self._provided_default_serializer(obj)
+ raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
+
def _do_deserialize(
self, typ: type, result: Any, custom_deserializer: Callable | None
) -> Any:
diff --git a/src/py/reactpy/reactpy/core/vdom.py b/src/py/reactpy/reactpy/core/vdom.py
index e494b5269..2bb3120dd 100644
--- a/src/py/reactpy/reactpy/core/vdom.py
+++ b/src/py/reactpy/reactpy/core/vdom.py
@@ -328,7 +328,7 @@ def _validate_child_key_integrity(value: Any) -> None:
)
else:
for child in value:
- if isinstance(child, ComponentType) and child.key is None:
+ if child.key is None and isinstance(child, ComponentType):
warn(f"Key not specified for child in list {child}", UserWarning)
elif isinstance(child, Mapping) and "key" not in child:
# remove 'children' to reduce log spam
From ef8537f451362809a80548fc96ef38ce50635923 Mon Sep 17 00:00:00 2001
From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com>
Date: Tue, 5 Mar 2024 12:29:49 -0800
Subject: [PATCH 04/11] add location to serialization (#4)
---
src/py/reactpy/reactpy/core/state_recovery.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/src/py/reactpy/reactpy/core/state_recovery.py b/src/py/reactpy/reactpy/core/state_recovery.py
index 6979de7ff..59c8c89da 100644
--- a/src/py/reactpy/reactpy/core/state_recovery.py
+++ b/src/py/reactpy/reactpy/core/state_recovery.py
@@ -14,6 +14,7 @@
import orjson
import pyotp
from more_itertools import chunked
+from reactpy.backend.types import Location
logger = getLogger(__name__)
@@ -61,6 +62,7 @@ def __init__(
datetime.time,
datetime.timezone,
datetime.timedelta,
+ Location,
]
)
From e23e43d6da24a3568f157853fd3ce81f96041961 Mon Sep 17 00:00:00 2001
From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com>
Date: Wed, 6 Mar 2024 11:01:29 -0800
Subject: [PATCH 05/11] Fix client not reloading on hash mismatch (#5)
---
src/py/reactpy/reactpy/core/serve.py | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/src/py/reactpy/reactpy/core/serve.py b/src/py/reactpy/reactpy/core/serve.py
index b7832b565..eef5c5c96 100644
--- a/src/py/reactpy/reactpy/core/serve.py
+++ b/src/py/reactpy/reactpy/core/serve.py
@@ -191,20 +191,20 @@ async def _do_state_rebuild_for_reconnection(self, layout: Layout) -> str:
client_state = serializer.deserialize_client_state(state_vars)
layout.reconnecting.set_current(True)
layout.client_state = client_state
- except StateRecoveryFailureError:
- logger.exception("State recovery failed")
- layout.reconnecting.set_current(False)
- layout.client_state = {}
- else:
+
salt = client_state_msg["salt"]
- try:
layout.start_rendering_for_reconnect()
await layout.render_until_queue_empty()
except StateRecoveryFailureError:
- logger.warning("Client state non-recoverable. Starting fresh")
+ logger.warning(
+ "State recovery failed (likely client from different version). Starting fresh"
+ )
await layout.finish()
+ layout.reconnecting.set_current(False)
+ layout.client_state = {}
await layout.start()
layout.start_rendering()
+ return salt
layout.reconnecting.set_current(False)
layout.client_state = {}
return salt
From d36a982bb144f4b03766c0fa14a5b8947acc2805 Mon Sep 17 00:00:00 2001
From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com>
Date: Wed, 6 Mar 2024 15:35:29 -0800
Subject: [PATCH 06/11] Fix user mouse / scroll activity not actually
triggering a reconnection (#6)
* Fix for client not reconnecting on activity
* Type improvements
* comment improvements
---
.../@reactpy/client/src/reactpy-client.ts | 38 +++++++++++--------
1 file changed, 23 insertions(+), 15 deletions(-)
diff --git a/src/js/packages/@reactpy/client/src/reactpy-client.ts b/src/js/packages/@reactpy/client/src/reactpy-client.ts
index c69db96bc..005a11473 100644
--- a/src/js/packages/@reactpy/client/src/reactpy-client.ts
+++ b/src/js/packages/@reactpy/client/src/reactpy-client.ts
@@ -145,8 +145,8 @@ type ReconnectProps = {
maxRetries?: number;
backoffRate?: number;
intervalJitter?: number;
- reconnectingCallback?: Function;
- reconnectedCallback?: Function;
+ reconnectingCallback?: () => void;
+ reconnectedCallback?: () => void;
};
enum messageTypes {
@@ -163,7 +163,7 @@ export class SimpleReactPyClient
private readonly urls: ServerUrls;
private socket!: { current?: WebSocket };
private idleDisconnectTimeMillis: number;
- private lastMessageTime: number;
+ private lastActivityTime: number;
private reconnectOptions: ReconnectProps | undefined;
private messageQueue: any[] = [];
private socketLoopIntervalId?: number | null;
@@ -174,9 +174,10 @@ export class SimpleReactPyClient
private salt: string;
private shouldReconnect: boolean;
private connectionTimeout: number;
- private reconnectingCallback: Function;
- private reconnectedCallback: Function;
+ private reconnectingCallback: () => void;
+ private reconnectedCallback: () => void;
private didReconnectingCallback: boolean;
+ private willReconnect: boolean;
constructor(props: SimpleReactPyClientProps) {
super();
@@ -190,11 +191,12 @@ export class SimpleReactPyClient
);
this.idleDisconnectTimeMillis = (props.idleDisconnectTimeSeconds || 240) * 1000;
this.connectionTimeout = props.connectionTimeout || 5000;
- this.lastMessageTime = Date.now()
+ this.lastActivityTime = Date.now()
this.reconnectOptions = props.reconnectOptions
this.debugMessages = props.debugMessages || false;
this.sleeping = false;
this.isReconnecting = false;
+ this.willReconnect = false;
this.isReady = false
this.salt = "";
this.shouldReconnect = false;
@@ -209,18 +211,21 @@ export class SimpleReactPyClient
this.onMessage(messageTypes.layoutUpdate, (msg) => {
this.updateClientState(msg.state_vars);
this.invokeLayoutUpdateHandlers(msg.path, msg.model);
+ this.willReconnect = true; // don't indicate a reconnect until at least one successful layout update
})
this.reconnect()
- const reconnectOnUserAction = (ev: any) => {
+ const handleUserAction = (ev: any) => {
+ this.lastActivityTime = Date.now();
if (!this.isReady && !this.isReconnecting) {
+ this.sleeping = false;
this.reconnect();
}
}
- window.addEventListener('mousemove', reconnectOnUserAction);
- window.addEventListener('scroll', reconnectOnUserAction);
+ window.addEventListener('mousemove', handleUserAction);
+ window.addEventListener('scroll', handleUserAction);
}
protected invokeLayoutUpdateHandlers(path: string, model: any) {
@@ -290,7 +295,7 @@ export class SimpleReactPyClient
}
indicateReconnect(): void {
- const isReconnecting = this.isReconnecting ? "yes" : "no";
+ const isReconnecting = this.willReconnect ? "yes" : "no";
this.sendMessage({ "type": messageTypes.reconnectingCheck, "value": isReconnecting }, true)
}
@@ -324,6 +329,7 @@ export class SimpleReactPyClient
if (this.debugMessages) {
logger.log("Sending message", message);
}
+ this.lastActivityTime = Date.now();
this.socket.current.send(JSON.stringify(message));
}
}
@@ -331,10 +337,11 @@ export class SimpleReactPyClient
idleTimeoutCheck(): void {
if (!this.socket)
return;
- if (Date.now() - this.lastMessageTime > this.idleDisconnectTimeMillis) {
+ if (Date.now() - this.lastActivityTime > this.idleDisconnectTimeMillis) {
if (this.socket.current && this.socket.current.readyState === WebSocket.OPEN) {
logger.warn("Closing socket connection due to idle activity");
this.sleeping = true;
+ this.isReconnecting = false;
this.socket.current.close();
}
}
@@ -388,14 +395,15 @@ export class SimpleReactPyClient
onOpen();
},
onClose: () => {
- // reset retry interval
+ // reset retry interval on successful connection
if (Date.now() - lastAttempt > maxInterval * 2) {
interval = 750;
connectionAttemptsRemaining = maxRetries;
+ } else if (!this.sleeping) {
+ this.isReconnecting = true;
}
lastAttempt = Date.now()
this.shouldReconnect = false;
- this.isReconnecting = true;
this.isReady = false;
if (this.socketLoopIntervalId)
clearInterval(this.socketLoopIntervalId);
@@ -410,7 +418,7 @@ export class SimpleReactPyClient
this.reconnect(onOpen, thisInterval, newRetriesRemaining, lastAttempt);
}
},
- onMessage: async ({ data }) => { this.lastMessageTime = Date.now(); this.handleIncoming(JSON.parse(data)) },
+ onMessage: async ({ data }) => { this.lastActivityTime = Date.now(); this.handleIncoming(JSON.parse(data)) },
...this.reconnectOptions,
});
this.socketLoopIntervalId = window.setInterval(() => { this.socketLoop() }, 30);
@@ -431,7 +439,7 @@ export class SimpleReactPyClient
} else {
this.messageQueue.push(message);
}
- this.lastMessageTime = Date.now()
+ this.lastActivityTime = Date.now()
this.sleeping = false;
this.ensureConnected();
}
From 6118197685ee7a454d8043656cd855533c5c3c9d Mon Sep 17 00:00:00 2001
From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com>
Date: Thu, 7 Mar 2024 16:57:37 -0800
Subject: [PATCH 07/11] fix state recovery error not propagating (#7)
Fix Exception catch-all in render swallowing state recovery errors
Fix serialization errors not resulting in a state recovery error
---
src/py/reactpy/reactpy/core/layout.py | 7 ++++++-
src/py/reactpy/reactpy/core/serve.py | 5 ++++-
2 files changed, 10 insertions(+), 2 deletions(-)
diff --git a/src/py/reactpy/reactpy/core/layout.py b/src/py/reactpy/reactpy/core/layout.py
index 8db467e30..f2274c6a8 100644
--- a/src/py/reactpy/reactpy/core/layout.py
+++ b/src/py/reactpy/reactpy/core/layout.py
@@ -47,7 +47,10 @@
)
from reactpy.core.component import Component
from reactpy.core.hooks import _ContextProvider
-from reactpy.core.state_recovery import StateRecoverySerializer
+from reactpy.core.state_recovery import (
+ StateRecoveryFailureError,
+ StateRecoverySerializer,
+)
from reactpy.core.types import (
ComponentType,
EventHandlerDict,
@@ -284,6 +287,8 @@ async def _render_component(
# components are given a node in the tree some other way
wrapper_model: VdomDict = {"tagName": "", "children": [raw_model]}
await self._render_model(exit_stack, old_state, new_state, wrapper_model)
+ except StateRecoveryFailureError:
+ raise
except Exception as error:
logger.exception(f"Failed to render {component}")
new_state.model.current = {
diff --git a/src/py/reactpy/reactpy/core/serve.py b/src/py/reactpy/reactpy/core/serve.py
index eef5c5c96..85799c762 100644
--- a/src/py/reactpy/reactpy/core/serve.py
+++ b/src/py/reactpy/reactpy/core/serve.py
@@ -188,7 +188,10 @@ async def _do_state_rebuild_for_reconnection(self, layout: Layout) -> str:
serializer = self._state_recovery_manager.create_serializer(
client_state_msg["salt"]
)
- client_state = serializer.deserialize_client_state(state_vars)
+ try:
+ client_state = serializer.deserialize_client_state(state_vars)
+ except Exception as err:
+ raise StateRecoveryFailureError() from err
layout.reconnecting.set_current(True)
layout.client_state = client_state
From 09eb44cc679f0364a8f0819127ca0ff87bbbde62 Mon Sep 17 00:00:00 2001
From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com>
Date: Fri, 8 Mar 2024 07:25:29 -0800
Subject: [PATCH 08/11] Move socket loop interval to a variable (#8)
---
src/js/packages/@reactpy/client/src/reactpy-client.ts | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/src/js/packages/@reactpy/client/src/reactpy-client.ts b/src/js/packages/@reactpy/client/src/reactpy-client.ts
index 005a11473..b840479a0 100644
--- a/src/js/packages/@reactpy/client/src/reactpy-client.ts
+++ b/src/js/packages/@reactpy/client/src/reactpy-client.ts
@@ -107,6 +107,7 @@ export type SimpleReactPyClientProps = {
idleDisconnectTimeSeconds?: number;
connectionTimeout?: number;
debugMessages?: boolean;
+ socketLoopThrottle?: number;
};
/**
@@ -178,6 +179,7 @@ export class SimpleReactPyClient
private reconnectedCallback: () => void;
private didReconnectingCallback: boolean;
private willReconnect: boolean;
+ private socketLoopThrottle: number;
constructor(props: SimpleReactPyClientProps) {
super();
@@ -194,6 +196,7 @@ export class SimpleReactPyClient
this.lastActivityTime = Date.now()
this.reconnectOptions = props.reconnectOptions
this.debugMessages = props.debugMessages || false;
+ this.socketLoopThrottle = props.socketLoopThrottle || 5;
this.sleeping = false;
this.isReconnecting = false;
this.willReconnect = false;
@@ -421,7 +424,7 @@ export class SimpleReactPyClient
onMessage: async ({ data }) => { this.lastActivityTime = Date.now(); this.handleIncoming(JSON.parse(data)) },
...this.reconnectOptions,
});
- this.socketLoopIntervalId = window.setInterval(() => { this.socketLoop() }, 30);
+ this.socketLoopIntervalId = window.setInterval(() => { this.socketLoop() }, this.socketLoopThrottle);
this.idleCheckIntervalId = window.setInterval(() => { this.idleTimeoutCheck() }, 10000);
}, interval)
From 5f5562f8ac18a00229d5cd612de48dfbd41d136c Mon Sep 17 00:00:00 2001
From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com>
Date: Tue, 12 Mar 2024 10:29:08 -0700
Subject: [PATCH 09/11] mypy typing updates (#9)
---
src/py/reactpy/reactpy/core/component.py | 6 +++++-
src/py/reactpy/reactpy/core/hooks.py | 6 +++---
2 files changed, 8 insertions(+), 4 deletions(-)
diff --git a/src/py/reactpy/reactpy/core/component.py b/src/py/reactpy/reactpy/core/component.py
index 9d4955546..f11bada12 100644
--- a/src/py/reactpy/reactpy/core/component.py
+++ b/src/py/reactpy/reactpy/core/component.py
@@ -2,7 +2,7 @@
import inspect
from functools import wraps
-from typing import Any, Callable, ParamSpec, TypeVar
+from typing import Any, Callable, ParamSpec, TypeVar, overload
from reactpy.core.types import ComponentType, VdomDict
@@ -10,6 +10,10 @@
P = ParamSpec("P")
+@overload
+def component(function: None = None, *, priority: int) -> Callable[P, Component]: ...
+
+
def component(
function: Callable[P, T] | None = None,
*,
diff --git a/src/py/reactpy/reactpy/core/hooks.py b/src/py/reactpy/reactpy/core/hooks.py
index 72f7c2c4e..95fefaf94 100644
--- a/src/py/reactpy/reactpy/core/hooks.py
+++ b/src/py/reactpy/reactpy/core/hooks.py
@@ -54,11 +54,11 @@ class ReconnectingOnly(list):
@overload
-def use_state(initial_value: Callable[[], _Type]) -> State[_Type]: ...
+def use_state(initial_value: Callable[[], _Type], *, server_only: bool = False) -> State[_Type]: ...
@overload
-def use_state(initial_value: _Type) -> State[_Type]: ...
+def use_state(initial_value: _Type, *, server_only: bool = False) -> State[_Type]: ...
def use_state(
@@ -509,7 +509,7 @@ def empty(self) -> bool:
return False
-def use_ref(initial_value: _Type, server_only: bool = True) -> Ref[_Type]:
+def use_ref(initial_value: _Type, *, server_only: bool = True) -> Ref[_Type]:
"""See the full :ref:`Use State` docs for details
Parameters:
From 6203da576411f864e3e173e343a363a5a46e23e6 Mon Sep 17 00:00:00 2001
From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com>
Date: Tue, 12 Mar 2024 10:34:34 -0700
Subject: [PATCH 10/11] add overload (#10)
---
src/py/reactpy/reactpy/core/component.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/src/py/reactpy/reactpy/core/component.py b/src/py/reactpy/reactpy/core/component.py
index f11bada12..71206b94d 100644
--- a/src/py/reactpy/reactpy/core/component.py
+++ b/src/py/reactpy/reactpy/core/component.py
@@ -14,6 +14,10 @@
def component(function: None = None, *, priority: int) -> Callable[P, Component]: ...
+@overload
+def component(function: Callable[P, T] | None) -> Callable[P, Component]: ...
+
+
def component(
function: Callable[P, T] | None = None,
*,
From 9704cbccc189bd9cdfac88bbc84d8c5f9bd46af3 Mon Sep 17 00:00:00 2001
From: James Hutchison <122519877+JamesHutchison@users.noreply.github.com>
Date: Tue, 12 Mar 2024 11:00:33 -0700
Subject: [PATCH 11/11] finally fix priority typing (#11)
---
src/py/reactpy/reactpy/core/component.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/src/py/reactpy/reactpy/core/component.py b/src/py/reactpy/reactpy/core/component.py
index 71206b94d..8ffaae1cf 100644
--- a/src/py/reactpy/reactpy/core/component.py
+++ b/src/py/reactpy/reactpy/core/component.py
@@ -11,11 +11,13 @@
@overload
-def component(function: None = None, *, priority: int) -> Callable[P, Component]: ...
+def component(
+ function: None = None, *, priority: int
+) -> Callable[[Callable[P, T]], Callable[P, Component]]: ...
@overload
-def component(function: Callable[P, T] | None) -> Callable[P, Component]: ...
+def component(function: Callable[P, T]) -> Callable[P, Component]: ...
def component(