diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index de6d598342..519960b7a5 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -336,7 +336,7 @@ def sample_blackjax_nuts( var_names: Optional[Sequence[str]] = None, keep_untransformed: bool = False, chain_method: str = "parallel", - postprocessing_backend: Literal["cpu", "gpu"] | None = None, + postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None, postprocessing_vectorize: Literal["vmap", "scan"] = "scan", idata_kwargs: Optional[Dict[str, Any]] = None, postprocessing_chunks=None, # deprecated @@ -546,7 +546,7 @@ def sample_numpyro_nuts( progressbar: bool = True, keep_untransformed: bool = False, chain_method: str = "parallel", - postprocessing_backend: Literal["cpu", "gpu"] | None = None, + postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None, postprocessing_vectorize: Literal["vmap", "scan"] = "scan", idata_kwargs: Optional[Dict] = None, nuts_kwargs: Optional[Dict] = None,