diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index cf37ccc779..4429a3abcc 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -14,7 +14,7 @@ jobs:
     - uses: actions/checkout@9bb56186c3b09b4f86b1c65136769dd318469633
     - uses: actions/setup-python@v5
       with:
-        python-version: "3.9" # Run pre-commit on oldest supported Python version
+        python-version: "3.10" # Run pre-commit on oldest supported Python version
     - uses: pre-commit/action@v3.0.1
   mypy:
     runs-on: ubuntu-latest
@@ -52,7 +52,7 @@ jobs:
           activate-environment: pymc-test
           channel-priority: strict
           environment-file: conda-envs/environment-test.yml
-          python-version: "3.9" # Run pre-commit on oldest supported Python version
+          python-version: "3.10" # Run pre-commit on oldest supported Python version
           use-mamba: true
           use-only-tar-bz2: false # IMPORTANT: This may break caching of conda packages! See https://github.com/conda-incubator/setup-miniconda/issues/267
       - name: Install-pymc and mypy dependencies
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index b31fd0538e..55ccc3d2e0 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -188,7 +188,7 @@ jobs:
       matrix:
         os: [windows-latest]
         floatx: [float64]
-        python-version: ["3.9"]
+        python-version: ["3.10"]
         test-subset:
           - tests/variational/test_approximations.py tests/variational/test_callbacks.py tests/variational/test_inference.py tests/variational/test_opvi.py tests/test_initial_point.py
           - tests/model/test_core.py tests/sampling/test_mcmc.py
diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml
index a3f4a41a8c..db4f5840cb 100644
--- a/conda-envs/environment-dev.yml
+++ b/conda-envs/environment-dev.yml
@@ -13,7 +13,7 @@ dependencies:
 - numpy>=1.15.0
 - pandas>=0.24.0
 - pip
-- pytensor>=2.19,<2.20
+- pytensor>=2.20,<2.21
 - python-graphviz
 - networkx
 - scipy>=1.4.1
diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml
index d50328df26..3609ae618c 100644
--- a/conda-envs/environment-docs.yml
+++ b/conda-envs/environment-docs.yml
@@ -11,7 +11,7 @@ dependencies:
 - numpy>=1.15.0
 - pandas>=0.24.0
 - pip
-- pytensor>=2.19,<2.20
+- pytensor>=2.20,<2.21
 - python-graphviz
 - rich>=13.7.1
 - scipy>=1.4.1
diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml
index 3437904801..0ef618dd2c 100644
--- a/conda-envs/environment-jax.yml
+++ b/conda-envs/environment-jax.yml
@@ -20,7 +20,7 @@ dependencies:
 - numpyro>=0.8.0
 - pandas>=0.24.0
 - pip
-- pytensor>=2.19,<2.20
+- pytensor>=2.20,<2.21
 - python-graphviz
 - networkx
 - rich>=13.7.1
diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml
index 6c0c2a0b61..ab6ce313fd 100644
--- a/conda-envs/environment-test.yml
+++ b/conda-envs/environment-test.yml
@@ -16,7 +16,7 @@ dependencies:
 - numpy>=1.15.0
 - pandas>=0.24.0
 - pip
-- pytensor>=2.19,<2.20
+- pytensor>=2.20,<2.21
 - python-graphviz
 - networkx
 - rich>=13.7.1
diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml
index 91df7bfbac..253794abf4 100644
--- a/conda-envs/windows-environment-dev.yml
+++ b/conda-envs/windows-environment-dev.yml
@@ -13,7 +13,7 @@ dependencies:
 - numpy>=1.15.0
 - pandas>=0.24.0
 - pip
-- pytensor>=2.19,<2.20
+- pytensor>=2.20,<2.21
 - python-graphviz
 - networkx
 - rich>=13.7.1
diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml
index aaa958e985..f851b9a357 100644
--- a/conda-envs/windows-environment-test.yml
+++ b/conda-envs/windows-environment-test.yml
@@ -16,7 +16,7 @@ dependencies:
 - numpy>=1.15.0
 - pandas>=0.24.0
 - pip
-- pytensor>=2.19,<2.20
+- pytensor>=2.20,<2.21
 - python-graphviz
 - networkx
 - rich>=13.7.1
diff --git a/mypy.ini b/mypy.ini
index ef28231ab4..5433d9bf16 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,5 +1,5 @@
 [mypy]
-python_version = 3.9
+python_version = 3.10
 no_implicit_optional = False
 strict_optional = True
 warn_redundant_casts = False
diff --git a/pymc/_version.py b/pymc/_version.py
index ad73ee06a2..263391ecda 100644
--- a/pymc/_version.py
+++ b/pymc/_version.py
@@ -29,7 +29,7 @@
 import subprocess
 import sys
 
-from typing import Callable
+from collections.abc import Callable
 
 
 def get_keywords():
diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py
index b63f68acc6..1347f55838 100644
--- a/pymc/backends/__init__.py
+++ b/pymc/backends/__init__.py
@@ -63,12 +63,11 @@
 from collections.abc import Mapping, Sequence
 from copy import copy
-from typing import Optional, Union
+from typing import Optional, TypeAlias, Union
 
 import numpy as np
 
 from pytensor.tensor.variable import TensorVariable
-from typing_extensions import TypeAlias
 
 from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
 from pymc.backends.base import BaseTrace, IBaseTrace
@@ -82,7 +81,7 @@
     from pymc.backends.mcbackend import init_chain_adapters
 
-    TraceOrBackend = Union[BaseTrace, Backend]
+    TraceOrBackend = BaseTrace | Backend
     RunType: TypeAlias = Run
     HAS_MCB = True
 except ImportError:
@@ -98,9 +97,9 @@ def _init_trace(
     expected_length: int,
     chain_number: int,
     stats_dtypes: list[dict[str, type]],
-    trace: Optional[BaseTrace],
+    trace: BaseTrace | None,
     model: Model,
-    trace_vars: Optional[list[TensorVariable]] = None,
+    trace_vars: list[TensorVariable] | None = None,
 ) -> BaseTrace:
     """Initializes a trace backend for a chain."""
     strace: BaseTrace
@@ -119,14 +118,14 @@
 def init_traces(
     *,
-    backend: Optional[TraceOrBackend],
+    backend: TraceOrBackend | None,
     chains: int,
     expected_length: int,
-    step: Union[BlockedStep, CompoundStep],
+    step: BlockedStep | CompoundStep,
     initial_point: Mapping[str, np.ndarray],
     model: Model,
-    trace_vars: Optional[list[TensorVariable]] = None,
-) -> tuple[Optional[RunType], Sequence[IBaseTrace]]:
+    trace_vars: list[TensorVariable] | None = None,
+) -> tuple[RunType | None, Sequence[IBaseTrace]]:
     """Initializes a trace recorder for each chain."""
     if HAS_MCB and isinstance(backend, Backend):
         return init_chain_adapters(
diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py
index 199e27a137..4f32fe65cf 100644
--- a/pymc/backends/arviz.py
+++ b/pymc/backends/arviz.py
@@ -78,7 +78,7 @@ def is_data(name, var, model) -> bool:
         and var not in model.potentials
         and var not in model.value_vars
         and name not in observations
-        and isinstance(var, (Constant, SharedVariable))
+        and isinstance(var, Constant | SharedVariable)
     )
 
     # The assumption is that constants (like pm.Data) are named
@@ -163,10 +163,10 @@ def insert(self, k: str, v, idx: int):
 class InferenceDataConverter:
     """Encapsulate InferenceData specific logic."""
 
-    model: Optional[Model] = None
-    posterior_predictive: Optional[Mapping[str, np.ndarray]] = None
-    predictions: Optional[Mapping[str, np.ndarray]] = None
-    prior: Optional[Mapping[str, np.ndarray]] = None
+    model: Model | None = None
+    posterior_predictive: Mapping[str, np.ndarray] | None = None
+    predictions: Mapping[str, np.ndarray] | None = None
+    prior: Mapping[str, np.ndarray] | None = None
 
     def __init__(
         self,
@@ -177,11 +177,11 @@ def __init__(
         log_likelihood=False,
         log_prior=False,
         predictions=None,
-        coords: Optional[CoordSpec] = None,
-        dims: Optional[DimSpec] = None,
-        sample_dims: Optional[list] = None,
+        coords: CoordSpec | None = None,
+        dims: DimSpec | None = None,
+        sample_dims: list | None = None,
         model=None,
-        save_warmup: Optional[bool] = None,
+        save_warmup: bool | None = None,
         include_transformed: bool = False,
     ):
         self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
@@ -466,15 +466,15 @@ def to_inference_data(self):
 def to_inference_data(
     trace: Optional["MultiTrace"] = None,
     *,
-    prior: Optional[Mapping[str, Any]] = None,
-    posterior_predictive: Optional[Mapping[str, Any]] = None,
-    log_likelihood: Union[bool, Iterable[str]] = False,
-    log_prior: Union[bool, Iterable[str]] = False,
-    coords: Optional[CoordSpec] = None,
-    dims: Optional[DimSpec] = None,
-    sample_dims: Optional[list] = None,
+    prior: Mapping[str, Any] | None = None,
+    posterior_predictive: Mapping[str, Any] | None = None,
+    log_likelihood: bool | Iterable[str] = False,
+    log_prior: bool | Iterable[str] = False,
+    coords: CoordSpec | None = None,
+    dims: DimSpec | None = None,
+    sample_dims: list | None = None,
     model: Optional["Model"] = None,
-    save_warmup: Optional[bool] = None,
+    save_warmup: bool | None = None,
     include_transformed: bool = False,
 ) -> InferenceData:
     """Convert pymc data into an InferenceData object.
@@ -543,10 +543,10 @@ def predictions_to_inference_data(
     predictions,
     posterior_trace: Optional["MultiTrace"] = None,
     model: Optional["Model"] = None,
-    coords: Optional[CoordSpec] = None,
-    dims: Optional[DimSpec] = None,
-    sample_dims: Optional[list] = None,
-    idata_orig: Optional[InferenceData] = None,
+    coords: CoordSpec | None = None,
+    dims: DimSpec | None = None,
+    sample_dims: list | None = None,
+    idata_orig: InferenceData | None = None,
     inplace: bool = False,
 ) -> InferenceData:
     """Translate out-of-sample predictions into ``InferenceData``.
diff --git a/pymc/backends/base.py b/pymc/backends/base.py
index 764896cf4a..7854cc0931 100644
--- a/pymc/backends/base.py
+++ b/pymc/backends/base.py
@@ -25,9 +25,7 @@
 from collections.abc import Mapping, Sequence, Sized
 from typing import (
     Any,
-    Optional,
     TypeVar,
-    Union,
     cast,
 )
 
@@ -53,7 +51,7 @@ class IBaseTrace(ABC, Sized):
     varnames: list[str]
     """Names of tracked variables."""
 
-    sampler_vars: list[dict[str, Union[type, np.dtype]]]
+    sampler_vars: list[dict[str, type | np.dtype]]
     """Sampler stats for each sampler."""
 
     def __len__(self):
@@ -75,7 +73,7 @@ def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray:
         raise NotImplementedError()
 
     def get_sampler_stats(
-        self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1
+        self, stat_name: str, sampler_idx: int | None = None, burn=0, thin=1
     ) -> np.ndarray:
         """Get sampler statistics from the trace.
 
@@ -219,7 +217,7 @@ def __getitem__(self, idx):
         raise ValueError("Can only index with slice or integer")
 
     def get_sampler_stats(
-        self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1
+        self, stat_name: str, sampler_idx: int | None = None, burn=0, thin=1
     ) -> np.ndarray:
         """Get sampler statistics from the trace.
 
@@ -443,7 +441,7 @@ def get_values(
         burn: int = 0,
         thin: int = 1,
         combine: bool = True,
-        chains: Optional[Union[int, Sequence[int]]] = None,
+        chains: int | Sequence[int] | None = None,
         squeeze: bool = True,
     ) -> list[np.ndarray]:
         """Get values from traces.
@@ -482,9 +480,9 @@ def get_sampler_stats(
         burn: int = 0,
         thin: int = 1,
         combine: bool = True,
-        chains: Optional[Union[int, Sequence[int]]] = None,
+        chains: int | Sequence[int] | None = None,
         squeeze: bool = True,
-    ) -> Union[list[np.ndarray], np.ndarray]:
+    ) -> list[np.ndarray] | np.ndarray:
         """Get sampler statistics from the trace.
 
         Note: This implementation attempts to squeeze object arrays into a consistent dtype,
@@ -534,7 +532,7 @@ def _slice(self, slice: slice):
         trace._report = self._report._slice(*idxs)
         return trace
 
-    def point(self, idx: int, chain: Optional[int] = None) -> dict[str, np.ndarray]:
+    def point(self, idx: int, chain: int | None = None) -> dict[str, np.ndarray]:
         """Return a dictionary of point values at `idx`.
 
         Parameters
diff --git a/pymc/backends/mcbackend.py b/pymc/backends/mcbackend.py
index 32f6d8f34c..f60b4e293e 100644
--- a/pymc/backends/mcbackend.py
+++ b/pymc/backends/mcbackend.py
@@ -17,7 +17,7 @@
 import pickle
 
 from collections.abc import Mapping, Sequence
-from typing import Any, Optional, Union, cast
+from typing import Any, cast
 
 import hagelkorn
 import mcbackend as mcb
@@ -144,7 +144,7 @@ def _get_stats(self, fname: str, slc: slice) -> np.ndarray:
         return values
 
     def get_sampler_stats(
-        self, stat_name: str, sampler_idx: Optional[int] = None, burn=0, thin=1
+        self, stat_name: str, sampler_idx: int | None = None, burn=0, thin=1
     ) -> np.ndarray:
         slc = slice(burn, None, thin)
         # When there's just one sampler, default to remove the sampler dimension
@@ -204,7 +204,7 @@ def point(self, idx: int) -> dict[str, np.ndarray]:
 def make_runmeta_and_point_fn(
     *,
     initial_point: Mapping[str, np.ndarray],
-    step: Union[CompoundStep, BlockedStep],
+    step: CompoundStep | BlockedStep,
     model: Model,
 ) -> tuple[mcb.RunMeta, PointFunc]:
     variables, point_fn = get_variables_and_point_fn(model, initial_point)
@@ -254,7 +254,7 @@ def init_chain_adapters(
     backend: mcb.Backend,
     chains: int,
     initial_point: Mapping[str, np.ndarray],
-    step: Union[CompoundStep, BlockedStep],
+    step: CompoundStep | BlockedStep,
     model: Model,
 ) -> tuple[mcb.Run, list[ChainRecordAdapter]]:
     """Create an McBackend metadata description for the MCMC run.
diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py
index ec5cef1c9b..23f05488b9 100644
--- a/pymc/backends/ndarray.py
+++ b/pymc/backends/ndarray.py
@@ -17,7 +17,7 @@
 Store sampling values in memory as a NumPy array.
 """
 
-from typing import Any, Optional
+from typing import Any
 
 import numpy as np
 
@@ -210,7 +210,7 @@ def _slice_as_ndarray(strace, idx):
 
 def point_list_to_multitrace(
-    point_list: list[dict[str, np.ndarray]], model: Optional[Model] = None
+    point_list: list[dict[str, np.ndarray]], model: Model | None = None
 ) -> MultiTrace:
     """transform point list into MultiTrace"""
     _model = modelcontext(model)
diff --git a/pymc/backends/report.py b/pymc/backends/report.py
index b6548914d0..49e584a979 100644
--- a/pymc/backends/report.py
+++ b/pymc/backends/report.py
@@ -16,8 +16,6 @@
 import itertools
 import logging
 
-from typing import Optional
-
 from pymc.stats.convergence import _LEVELS, SamplerWarning
 
 logger = logging.getLogger(__name__)
@@ -44,17 +42,17 @@ def ok(self):
         return all(_LEVELS[warn.level] < _LEVELS["warn"] for warn in self._warnings)
 
     @property
-    def n_tune(self) -> Optional[int]:
+    def n_tune(self) -> int | None:
         """Number of tune iterations - not necessarily kept in trace!"""
         return self._n_tune
 
     @property
-    def n_draws(self) -> Optional[int]:
+    def n_draws(self) -> int | None:
         """Number of draw iterations."""
         return self._n_draws
 
     @property
-    def t_sampling(self) -> Optional[float]:
+    def t_sampling(self) -> float | None:
         """
         Number of seconds that the sampling procedure took.
diff --git a/pymc/blocking.py b/pymc/blocking.py
index 443a0ddd93..287a06d530 100644
--- a/pymc/blocking.py
+++ b/pymc/blocking.py
@@ -20,22 +20,18 @@
 
 from __future__ import annotations
 
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import partial
 from typing import (
     Any,
-    Callable,
     Generic,
     NamedTuple,
-    Optional,
+    TypeAlias,
     TypeVar,
-    Union,
 )
 
 import numpy as np
 
-from typing_extensions import TypeAlias
-
 __all__ = ["DictToArrayBijection"]
@@ -43,8 +39,8 @@
 PointType: TypeAlias = dict[str, np.ndarray]
 StatsDict: TypeAlias = dict[str, Any]
 StatsType: TypeAlias = list[StatsDict]
-StatDtype: TypeAlias = Union[type, np.dtype]
-StatShape: TypeAlias = Optional[Sequence[Optional[int]]]
+StatDtype: TypeAlias = type | np.dtype
+StatShape: TypeAlias = Sequence[int | None] | None
 
 
 # `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
diff --git a/pymc/data.py b/pymc/data.py
index 576cad6b11..15aa15e5e3 100644
--- a/pymc/data.py
+++ b/pymc/data.py
@@ -18,7 +18,7 @@
 
 from collections.abc import Sequence
 from copy import copy
-from typing import Optional, Union, cast
+from typing import cast
 
 import numpy as np
 import pandas as pd
@@ -203,10 +203,10 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
 
 def determine_coords(
     model,
-    value: Union[pd.DataFrame, pd.Series, xr.DataArray],
-    dims: Optional[Sequence[Optional[str]]] = None,
-    coords: Optional[dict[str, Union[Sequence, np.ndarray]]] = None,
-) -> tuple[dict[str, Union[Sequence, np.ndarray]], Sequence[Optional[str]]]:
+    value: pd.DataFrame | pd.Series | xr.DataArray,
+    dims: Sequence[str | None] | None = None,
+    coords: dict[str, Sequence | np.ndarray] | None = None,
+) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None]]:
     """Determines coordinate values from data or the model (via ``dims``)."""
     if coords is None:
         coords = {}
@@ -260,8 +260,8 @@ def ConstantData(
     name: str,
     value,
     *,
-    dims: Optional[Sequence[str]] = None,
-    coords: Optional[dict[str, Union[Sequence, np.ndarray]]] = None,
+    dims: Sequence[str] | None = None,
+    coords: dict[str, Sequence | np.ndarray] | None = None,
     infer_dims_and_coords=False,
     **kwargs,
 ) -> TensorConstant:
@@ -290,8 +290,8 @@ def MutableData(
     name: str,
     value,
     *,
-    dims: Optional[Sequence[str]] = None,
-    coords: Optional[dict[str, Union[Sequence, np.ndarray]]] = None,
+    dims: Sequence[str] | None = None,
+    coords: dict[str, Sequence | np.ndarray] | None = None,
     infer_dims_and_coords=False,
     **kwargs,
 ) -> SharedVariable:
@@ -320,12 +320,12 @@ def Data(
     name: str,
     value,
     *,
-    dims: Optional[Sequence[str]] = None,
-    coords: Optional[dict[str, Union[Sequence, np.ndarray]]] = None,
+    dims: Sequence[str] | None = None,
+    coords: dict[str, Sequence | np.ndarray] | None = None,
     infer_dims_and_coords=False,
-    mutable: Optional[bool] = None,
+    mutable: bool | None = None,
     **kwargs,
-) -> Union[SharedVariable, TensorConstant]:
+) -> SharedVariable | TensorConstant:
     """Data container that registers a data variable with the model.
 
     Depending on the ``mutable`` setting (default: True), the variable
diff --git a/pymc/distributions/censored.py b/pymc/distributions/censored.py
index 8b717d3c24..18b45ce821 100644
--- a/pymc/distributions/censored.py
+++ b/pymc/distributions/censored.py
@@ -89,7 +89,7 @@ class Censored(Distribution):
     @classmethod
     def dist(cls, dist, lower, upper, **kwargs):
         if not isinstance(dist, TensorVariable) or not isinstance(
-            dist.owner.op, (RandomVariable, SymbolicRandomVariable)
+            dist.owner.op, RandomVariable | SymbolicRandomVariable
         ):
             raise ValueError(
                 f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}"
diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py
index b83260f671..2074dbcd54 100644
--- a/pymc/distributions/continuous.py
+++ b/pymc/distributions/continuous.py
@@ -22,8 +22,6 @@
 
 import warnings
 
-from typing import Optional, Union
-
 import numpy as np
 import pytensor
 import pytensor.tensor as pt
@@ -150,7 +148,7 @@ class BoundedContinuous(Continuous):
     """Base class for bounded continuous distributions"""
 
     # Indices of the arguments that define the lower and upper bounds of the distribution
-    bound_args_indices: Optional[list[int]] = None
+    bound_args_indices: list[int] | None = None
 
 
 @_default_transform.register(PositiveContinuous)
@@ -553,11 +551,11 @@ class TruncatedNormalRV(RandomVariable):
     def rng_fn(
         cls,
         rng: np.random.RandomState,
-        mu: Union[np.ndarray, float],
-        sigma: Union[np.ndarray, float],
-        lower: Union[np.ndarray, float],
-        upper: Union[np.ndarray, float],
-        size: Optional[Union[list[int], int]],
+        mu: np.ndarray | float,
+        sigma: np.ndarray | float,
+        lower: np.ndarray | float,
+        upper: np.ndarray | float,
+        size: list[int] | int | None,
     ) -> np.ndarray:
         # Upcast to float64. (Caller will downcast to desired dtype if needed)
         # (Work-around for https://github.com/scipy/scipy/issues/15928)
@@ -657,12 +655,12 @@ class TruncatedNormal(BoundedContinuous):
     @classmethod
     def dist(
         cls,
-        mu: Optional[DIST_PARAMETER_TYPES] = 0,
-        sigma: Optional[DIST_PARAMETER_TYPES] = None,
+        mu: DIST_PARAMETER_TYPES | None = 0,
+        sigma: DIST_PARAMETER_TYPES | None = None,
         *,
-        tau: Optional[DIST_PARAMETER_TYPES] = None,
-        lower: Optional[DIST_PARAMETER_TYPES] = None,
-        upper: Optional[DIST_PARAMETER_TYPES] = None,
+        tau: DIST_PARAMETER_TYPES | None = None,
+        lower: DIST_PARAMETER_TYPES | None = None,
+        upper: DIST_PARAMETER_TYPES | None = None,
         **kwargs,
     ) -> RandomVariable:
         tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
@@ -837,8 +835,8 @@ class HalfNormal(PositiveContinuous):
     @classmethod
     def dist(
         cls,
-        sigma: Optional[DIST_PARAMETER_TYPES] = None,
-        tau: Optional[DIST_PARAMETER_TYPES] = None,
+        sigma: DIST_PARAMETER_TYPES | None = None,
+        tau: DIST_PARAMETER_TYPES | None = None,
         *args,
         **kwargs,
     ):
@@ -981,10 +979,10 @@ class Wald(PositiveContinuous):
     @classmethod
     def dist(
         cls,
-        mu: Optional[DIST_PARAMETER_TYPES] = None,
-        lam: Optional[DIST_PARAMETER_TYPES] = None,
-        phi: Optional[DIST_PARAMETER_TYPES] = None,
-        alpha: Optional[DIST_PARAMETER_TYPES] = 0.0,
+        mu: DIST_PARAMETER_TYPES | None = None,
+        lam: DIST_PARAMETER_TYPES | None = None,
+        phi: DIST_PARAMETER_TYPES | None = None,
+        alpha: DIST_PARAMETER_TYPES | None = 0.0,
         **kwargs,
     ):
         mu, lam, phi = cls.get_mu_lam_phi(mu, lam, phi)
@@ -1155,11 +1153,11 @@ class Beta(UnitContinuous):
     @classmethod
     def dist(
         cls,
-        alpha: Optional[DIST_PARAMETER_TYPES] = None,
-        beta: Optional[DIST_PARAMETER_TYPES] = None,
-        mu: Optional[DIST_PARAMETER_TYPES] = None,
-        sigma: Optional[DIST_PARAMETER_TYPES] = None,
-        nu: Optional[DIST_PARAMETER_TYPES] = None,
+        alpha: DIST_PARAMETER_TYPES | None = None,
+        beta: DIST_PARAMETER_TYPES | None = None,
+        mu: DIST_PARAMETER_TYPES | None = None,
+        sigma: DIST_PARAMETER_TYPES | None = None,
+        nu: DIST_PARAMETER_TYPES | None = None,
         *args,
         **kwargs,
     ):
diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index 22ed4d3400..739d05a348 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -18,9 +18,9 @@
 import warnings
 
 from abc import ABCMeta
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import singledispatch
-from typing import Callable, Optional, Union
+from typing import TypeAlias
 
 import numpy as np
@@ -41,7 +41,6 @@
 from pytensor.tensor.rewriting.shape import ShapeFeature
 from pytensor.tensor.utils import _parse_gufunc_signature
 from pytensor.tensor.variable import TensorVariable
-from typing_extensions import TypeAlias
 
 from pymc.distributions.shape_utils import (
     Dims,
@@ -79,9 +78,9 @@
     "SymbolicRandomVariable",
 ]
 
-DIST_PARAMETER_TYPES: TypeAlias = Union[np.ndarray, int, float, TensorVariable]
+DIST_PARAMETER_TYPES: TypeAlias = np.ndarray | int | float | TensorVariable
 
-vectorized_ppc: contextvars.ContextVar[Optional[Callable]] = contextvars.ContextVar(
+vectorized_ppc: contextvars.ContextVar[Callable | None] = contextvars.ContextVar(
     "vectorized_ppc", default=None
 )
@@ -103,7 +102,7 @@ def rewrite_support_point_scan_node(self, node):
 
         for nd in local_fgraph_topo:
             if nd not in to_replace_set and isinstance(
-                nd.op, (RandomVariable, SymbolicRandomVariable)
+                nd.op, RandomVariable | SymbolicRandomVariable
             ):
                 replace_with_support_point.append(nd.out)
                 to_replace_set.add(nd)
@@ -133,7 +132,7 @@ def add_requirements(self, fgraph):
 
     def apply(self, fgraph):
         for node in fgraph.toposort():
-            if isinstance(node.op, (RandomVariable, SymbolicRandomVariable)):
+            if isinstance(node.op, RandomVariable | SymbolicRandomVariable):
                 fgraph.replace(node.out, support_point(node.out))
             elif isinstance(node.op, Scan):
                 new_node = self.rewrite_support_point_scan_node(node)
@@ -263,7 +262,7 @@ class SymbolicRandomVariable(OpFromGraph):
     (0 for scalar, 1 for vector, ...)
     """
 
-    ndims_params: Optional[Sequence[int]] = None
+    ndims_params: Sequence[int] | None = None
     """Number of core dimensions of the distribution's parameters."""
 
     signature: str = None
@@ -327,7 +326,7 @@ def __new__(
         name: str,
         *args,
         rng=None,
-        dims: Optional[Dims] = None,
+        dims: Dims | None = None,
         initval=None,
         observed=None,
         total_size=None,
@@ -436,7 +435,7 @@ def dist(
         cls,
         dist_params,
         *,
-        shape: Optional[Shape] = None,
+        shape: Shape | None = None,
         **kwargs,
     ) -> TensorVariable:
         """Creates a tensor variable corresponding to the `cls` distribution.
@@ -591,13 +590,13 @@ class _CustomDist(Distribution):
     def dist(
         cls,
         *dist_params,
-        logp: Optional[Callable] = None,
-        logcdf: Optional[Callable] = None,
-        random: Optional[Callable] = None,
-        support_point: Optional[Callable] = None,
-        ndim_supp: Optional[int] = None,
-        ndims_params: Optional[Sequence[int]] = None,
-        signature: Optional[str] = None,
+        logp: Callable | None = None,
+        logcdf: Callable | None = None,
+        random: Callable | None = None,
+        support_point: Callable | None = None,
+        ndim_supp: int | None = None,
+        ndims_params: Sequence[int] | None = None,
+        signature: str | None = None,
         dtype: str = "floatX",
         class_name: str = "CustomDist",
         **kwargs,
@@ -652,10 +651,10 @@ def dist(
     def rv_op(
         cls,
         *dist_params,
-        logp: Optional[Callable],
-        logcdf: Optional[Callable],
-        random: Optional[Callable],
-        support_point: Optional[Callable],
+        logp: Callable | None,
+        logcdf: Callable | None,
+        random: Callable | None,
+        support_point: Callable | None,
         ndim_supp: int,
         ndims_params: Sequence[int],
         dtype: str,
@@ -743,12 +742,12 @@ def dist(
         cls,
         *dist_params,
         dist: Callable,
-        logp: Optional[Callable] = None,
-        logcdf: Optional[Callable] = None,
-        support_point: Optional[Callable] = None,
-        ndim_supp: Optional[int] = None,
-        ndims_params: Optional[Sequence[int]] = None,
-        signature: Optional[str] = None,
+        logp: Callable | None = None,
+        logcdf: Callable | None = None,
+        support_point: Callable | None = None,
+        ndim_supp: int | None = None,
+        ndims_params: Sequence[int] | None = None,
+        signature: str | None = None,
         dtype: str = "floatX",
         class_name: str = "CustomDist",
         **kwargs,
@@ -784,9 +783,9 @@ def rv_op(
         cls,
         *dist_params,
         dist: Callable,
-        logp: Optional[Callable],
-        logcdf: Optional[Callable],
-        support_point: Optional[Callable],
+        logp: Callable | None,
+        logcdf: Callable | None,
+        support_point: Callable | None,
         size=None,
         signature: str,
         class_name: str,
@@ -838,7 +837,7 @@ def custom_dist_get_support_point(op, rv, size, *params):
         *[
             p
             for p in params
-            if not isinstance(p.type, (RandomType, RandomGeneratorType))
+            if not isinstance(p.type, RandomType | RandomGeneratorType)
         ],
     )
@@ -1125,15 +1124,15 @@ def __new__(
         cls,
         name,
         *dist_params,
-        dist: Optional[Callable] = None,
-        random: Optional[Callable] = None,
-        logp: Optional[Callable] = None,
-        logcdf: Optional[Callable] = None,
-        support_point: Optional[Callable] = None,
+        dist: Callable | None = None,
+        random: Callable | None = None,
+        logp: Callable | None = None,
+        logcdf: Callable | None = None,
+        support_point: Callable | None = None,
         # TODO: Deprecate ndim_supp / ndims_params in favor of signature?
-        ndim_supp: Optional[int] = None,
-        ndims_params: Optional[Sequence[int]] = None,
-        signature: Optional[str] = None,
+        ndim_supp: int | None = None,
+        ndims_params: Sequence[int] | None = None,
+        signature: str | None = None,
         dtype: str = "floatX",
         **kwargs,
     ):
@@ -1188,14 +1187,14 @@ def __new__(
     def dist(
         cls,
         *dist_params,
-        dist: Optional[Callable] = None,
-        random: Optional[Callable] = None,
-        logp: Optional[Callable] = None,
-        logcdf: Optional[Callable] = None,
-        support_point: Optional[Callable] = None,
-        ndim_supp: Optional[int] = None,
-        ndims_params: Optional[Sequence[int]] = None,
-        signature: Optional[str] = None,
+        dist: Callable | None = None,
+        random: Callable | None = None,
+        logp: Callable | None = None,
+        logcdf: Callable | None = None,
+        support_point: Callable | None = None,
+        ndim_supp: int | None = None,
+        ndims_params: Sequence[int] | None = None,
+        signature: str | None = None,
         dtype: str = "floatX",
         **kwargs,
     ):
@@ -1369,7 +1368,7 @@ class PartialObservedRV(SymbolicRandomVariable):
 
 def create_partial_observed_rv(
     rv: TensorVariable,
-    mask: Union[np.ndarray, TensorVariable],
+    mask: np.ndarray | TensorVariable,
 ) -> tuple[
     tuple[TensorVariable, TensorVariable], tuple[TensorVariable, TensorVariable], TensorVariable
 ]:
diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py
index 10c2bb14ad..62f3008ac2 100644
--- a/pymc/distributions/mixture.py
+++ b/pymc/distributions/mixture.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import itertools
 import warnings
 
 import numpy as np
@@ -178,7 +179,7 @@ class Mixture(Distribution):
 
     @classmethod
     def dist(cls, w, comp_dists, **kwargs):
-        if not isinstance(comp_dists, (tuple, list)):
+        if not isinstance(comp_dists, tuple | list):
             # comp_dists is a single component
             comp_dists = [comp_dists]
         elif len(comp_dists) == 1:
@@ -204,7 +205,7 @@ def dist(cls, w, comp_dists, **kwargs):
             # TODO: Allow these to not be a RandomVariable as long as we can call `ndim_supp` on them
             # and resize them
             if not isinstance(dist, TensorVariable) or not isinstance(
-                dist.owner.op, (RandomVariable, SymbolicRandomVariable)
+                dist.owner.op, RandomVariable | SymbolicRandomVariable
             ):
                 raise ValueError(
                     f"Component dist must be a distribution created via the `.dist()` API, got {type(dist)}"
@@ -480,7 +481,7 @@ def transform_warning():
             transform.backward(value, *component.owner.inputs)
             for transform, component in zip(default_transforms, components)
         ]
-        for expr1, expr2 in zip(backward_expressions[:-1], backward_expressions[1:]):
+        for expr1, expr2 in itertools.pairwise(backward_expressions):
             if not equal_computations([expr1], [expr2]):
                 transform_warning()
         return None
diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index 956bca276d..b39dfa903b 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -18,7 +18,6 @@
 import warnings
 
 from functools import partial, reduce
-from typing import Optional
 
 import numpy as np
 import pytensor
@@ -1084,7 +1083,7 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initv
 
 def _lkj_normalizing_constant(eta, n):
     # TODO: This is mixing python branching with the potentially symbolic n and eta variables
-    if not isinstance(eta, (int, float)):
+    if not isinstance(eta, int | float):
         raise NotImplementedError("eta must be an int or float")
     if not isinstance(n, int):
         raise NotImplementedError("n must be an integer")
@@ -1186,7 +1185,7 @@ def dist(cls, n, eta, sd_dist, **kwargs):
         if not (
             isinstance(sd_dist, Variable)
             and sd_dist.owner is not None
-            and isinstance(sd_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
+            and isinstance(sd_dist.owner.op, RandomVariable | SymbolicRandomVariable)
             and sd_dist.owner.op.ndim_supp < 2
         ):
             raise TypeError("sd_dist must be a scalar or vector distribution variable")
@@ -2263,7 +2262,7 @@ def logp(value, mu, W, alpha, tau):
         TensorVariable
         """
 
-        sparse = isinstance(W, (pytensor.sparse.SparseConstant, pytensor.sparse.SparseVariable))
+        sparse = isinstance(W, pytensor.sparse.SparseConstant | pytensor.sparse.SparseVariable)
 
         if sparse:
             D = sp_sum(W, axis=0)
@@ -2755,7 +2754,7 @@ def dist(cls, sigma=1, n_zerosum_axes=None, support_shape=None, **kwargs):
         )
 
     @classmethod
-    def check_zerosum_axes(cls, n_zerosum_axes: Optional[int]) -> int:
+    def check_zerosum_axes(cls, n_zerosum_axes: int | None) -> int:
         if n_zerosum_axes is None:
             n_zerosum_axes = 1
         if not isinstance(n_zerosum_axes, int):
diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py
index 99cacbd977..1f4f501943 100644
--- a/pymc/distributions/shape_utils.py
+++ b/pymc/distributions/shape_utils.py
@@ -22,7 +22,7 @@
 
 from collections.abc import Sequence
 from functools import singledispatch
-from typing import Any, Optional, Union, cast
+from typing import Any, TypeAlias, cast
 
 import numpy as np
@@ -34,7 +34,6 @@
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.shape import SpecifyShape
 from pytensor.tensor.variable import TensorVariable
-from typing_extensions import TypeAlias
 
 from pymc.model import modelcontext
 from pymc.pytensorf import convert_observed_data
@@ -177,24 +176,24 @@ def broadcast_dist_samples_shape(shapes, size=None):
 
 # User-provided can be lazily specified as scalars
-Shape: TypeAlias = Union[int, TensorVariable, Sequence[Union[int, Variable]]]
-Dims: TypeAlias = Union[str, Sequence[Optional[str]]]
-Size: TypeAlias = Union[int, TensorVariable, Sequence[Union[int, Variable]]]
+Shape: TypeAlias = int | TensorVariable | Sequence[int | Variable]
+Dims: TypeAlias = str | Sequence[str | None]
+Size: TypeAlias = int | TensorVariable | Sequence[int | Variable]
 
 # After conversion to vectors
-StrongShape: TypeAlias = Union[TensorVariable, tuple[Union[int, Variable], ...]]
-StrongDims: TypeAlias = Sequence[Optional[str]]
-StrongSize: TypeAlias = Union[TensorVariable, tuple[Union[int, Variable], ...]]
+StrongShape: TypeAlias = TensorVariable | tuple[int | Variable, ...]
+StrongDims: TypeAlias = Sequence[str | None]
+StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...]
 
 
-def convert_dims(dims: Optional[Dims]) -> Optional[StrongDims]:
+def convert_dims(dims: Dims | None) -> StrongDims | None:
     """Process a user-provided dims variable into None or a valid dims tuple."""
     if dims is None:
         return None
 
     if isinstance(dims, str):
         dims = (dims,)
-    elif isinstance(dims, (list, tuple)):
+    elif isinstance(dims, list | tuple):
         dims = tuple(dims)
     else:
         raise ValueError(f"The `dims` parameter must be a tuple, str or list. Actual: {type(dims)}")
@@ -202,7 +201,7 @@ def convert_dims(dims: Optional[Dims]) -> Optional[StrongDims]:
     return dims
 
 
-def convert_shape(shape: Shape) -> Optional[StrongShape]:
+def convert_shape(shape: Shape) -> StrongShape | None:
     """Process a user-provided shape variable into None or a valid shape object."""
     if shape is None:
         return None
@@ -210,7 +209,7 @@
         shape = (shape,)
     elif isinstance(shape, TensorVariable) and shape.ndim == 1:
         shape = tuple(shape)
-    elif isinstance(shape, (list, tuple)):
+    elif isinstance(shape, list | tuple):
         shape = tuple(shape)
     else:
         raise ValueError(
@@ -220,7 +219,7 @@
-def convert_size(size: Size) -> Optional[StrongSize]:
+def convert_size(size: Size) -> StrongSize | None:
     """Process a user-provided size variable into None or a valid size object."""
     if size is None:
         return None
@@ -228,7 +227,7 @@
         size = (size,)
     elif isinstance(size, TensorVariable) and size.ndim == 1:
         size = tuple(size)
-    elif isinstance(size, (list, tuple)):
+    elif isinstance(size, list | tuple):
         size = tuple(size)
     else:
         raise ValueError(
@@ -265,10 +264,10 @@ def shape_from_dims(dims: StrongDims, model) -> StrongShape:
 
 def find_size(
-    shape: Optional[StrongShape],
-    size: Optional[StrongSize],
+    shape: StrongShape | None,
+    size: StrongSize | None,
     ndim_supp: int,
-) -> Optional[StrongSize]:
+) -> StrongSize | None:
     """Determines the size keyword argument for creating a Distribution.
 
     Parameters
@@ -421,14 +420,14 @@ def change_specify_shape_size(op, ss, new_size, expand) -> TensorVariable:
 
 def get_support_shape(
-    support_shape: Optional[Sequence[Union[int, np.ndarray, TensorVariable]]],
+    support_shape: Sequence[int | np.ndarray | TensorVariable] | None,
     *,
-    shape: Optional[Shape] = None,
-    dims: Optional[Dims] = None,
-    observed: Optional[Any] = None,
-    support_shape_offset: Optional[Sequence[int]] = None,
+    shape: Shape | None = None,
+    dims: Dims | None = None,
+    observed: Any | None = None,
+    support_shape_offset: Sequence[int] | None = None,
     ndim_supp: int = 1,
-) -> Optional[TensorVariable]:
+) -> TensorVariable | None:
     """Extract the support shapes from shape / dims / observed information
 
     Parameters
@@ -461,7 +460,7 @@ def get_support_shape(
         support_shape_offset = [0] * ndim_supp
     elif isinstance(support_shape_offset, int):
         support_shape_offset = [support_shape_offset] * ndim_supp
-    inferred_support_shape: Optional[Sequence[Union[int, np.ndarray, Variable]]] = None
+    inferred_support_shape: Sequence[int | np.ndarray | Variable] | None = None
 
     if shape is not None:
         shape = to_tuple(shape)
@@ -518,13 +517,13 @@ def get_support_shape(
 
 def get_support_shape_1d(
-    support_shape: Optional[Union[int, np.ndarray, TensorVariable]],
+    support_shape: int | np.ndarray | TensorVariable | None,
     *,
-    shape: Optional[Shape] = None,
-    dims: Optional[Dims] = None,
-    observed: Optional[Any] = None,
+    shape: Shape | None = None,
+    dims: Dims | None = None,
+    observed: Any | None = None,
     support_shape_offset: int = 0,
-) -> Optional[TensorVariable]:
+) -> TensorVariable | None:
     """Helper function for cases when you just care about one dimension."""
     support_shape_tuple = get_support_shape(
         support_shape=(support_shape,) if support_shape is not None else None,
diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py
index 5ff3948458..1412d3e446 100644
--- a/pymc/distributions/timeseries.py
+++ b/pymc/distributions/timeseries.py
@@ -15,7 +15,7 @@
 import warnings
 
 from abc import ABCMeta
-from typing import Callable, Optional
+from collections.abc import Callable
 
 import numpy as np
 import pytensor
@@ -88,7 +88,7 @@ def dist(cls, init_dist, innovation_dist, steps=None, **kwargs) -> pt.TensorVari
         if not (
             isinstance(init_dist, pt.TensorVariable)
             and init_dist.owner is not None
-            and isinstance(init_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
+            and isinstance(init_dist.owner.op, RandomVariable | SymbolicRandomVariable)
         ):
             raise TypeError("init_dist must be a distribution variable")
         check_dist_not_registered(init_dist)
@@ -96,7 +96,7 @@ def dist(cls, init_dist, innovation_dist, steps=None, **kwargs) -> pt.TensorVari
         if not (
             isinstance(innovation_dist, pt.TensorVariable)
             and innovation_dist.owner is not None
-            and isinstance(innovation_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
+            and isinstance(innovation_dist.owner.op, RandomVariable | SymbolicRandomVariable)
         ):
             raise TypeError("innovation_dist must be a distribution variable")
         check_dist_not_registered(innovation_dist)
@@ -129,7 +129,7 @@ def get_steps(cls, innovation_dist, steps, shape, dims, observed):
         if not (
             isinstance(innovation_dist, pt.TensorVariable)
             and innovation_dist.owner is not None
-            and isinstance(innovation_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
+            and isinstance(innovation_dist.owner.op, RandomVariable | SymbolicRandomVariable)
         ):
             raise TypeError("innovation_dist must be a distribution variable")
@@ -549,7 +549,7 @@ def dist(
 
         if init_dist is not None:
             if not isinstance(init_dist, TensorVariable) or not isinstance(
-                init_dist.owner.op, (RandomVariable, SymbolicRandomVariable)
+                init_dist.owner.op, RandomVariable | SymbolicRandomVariable
             ):
                 raise ValueError(
                     f"Init dist must be a distribution created via the `.dist()` API, "
@@ -573,7 +573,7 @@ def dist(
         return super().dist([rhos, sigma, init_dist, steps, ar_order, constant], **kwargs)
 
     @classmethod
-    def _get_ar_order(cls, rhos: TensorVariable, ar_order: Optional[int], constant: bool) -> int:
+    def _get_ar_order(cls, rhos: TensorVariable, ar_order: int | None, constant: bool) -> int:
         """Compute ar_order given inputs
 
         If ar_order is not specified we do constant folding on the shape of rhos
@@ -948,7 +948,7 @@ def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs):
 
         if init_dist is not None:
             if not isinstance(init_dist, TensorVariable) or not isinstance(
-                init_dist.owner.op, (RandomVariable, SymbolicRandomVariable)
+                init_dist.owner.op, RandomVariable | SymbolicRandomVariable
             ):
                 raise ValueError(
                     f"Init dist must be a distribution created via the `.dist()` API, "
diff --git a/pymc/func_utils.py b/pymc/func_utils.py
index 84cb632337..d101826371 100644
--- a/pymc/func_utils.py
+++ b/pymc/func_utils.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, Optional, Union
+from collections.abc import Callable
 
 import numpy as np
 import pytensor.tensor as pt
@@ -30,8 +30,8 @@ def find_constrained_prior(
     upper: float,
     init_guess: dict[str, float],
     mass: float = 0.95,
-    fixed_params: Optional[dict[str, float]] = None,
-    mass_below_lower: Optional[float] = None,
+    fixed_params: dict[str, float] | None = None,
+    mass_below_lower: float | None = None,
     **kwargs,
 ) -> dict[str, float]:
     """
@@ -165,8 +165,8 @@ def find_constrained_prior(
     constraint = pt.exp(logcdf_upper) - pt.exp(logcdf_lower)
     constraint_fn = pm.pytensorf.compile_pymc([dist_params], constraint, allow_input_downcast=True)
 
-    jac: Union[str, Callable]
-    constraint_jac: Union[str, Callable]
+    jac: str | Callable
+    constraint_jac: str | Callable
     try:
         pytensor_jac = pm.gradient(target, [dist_params])
         jac = pm.pytensorf.compile_pymc([dist_params], pytensor_jac, allow_input_downcast=True)
diff --git a/pymc/gp/cov.py b/pymc/gp/cov.py
index 0eddaea932..d7f5c66569 100644
--- a/pymc/gp/cov.py
+++ b/pymc/gp/cov.py
@@ -16,10 +16,10 @@
 import warnings
 
 from collections import Counter
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import reduce
 from operator import add, mul
-from typing import Any, Callable, Optional, Union
+from typing import Any
 
 import numpy as np
 import pytensor.tensor as pt
@@ -51,8 +51,8 @@
 from pymc.pytensorf import constant_fold
 
-TensorLike = Union[np.ndarray, TensorVariable]
-IntSequence = Union[np.ndarray, Sequence[int]]
+TensorLike = np.ndarray | TensorVariable
+IntSequence = np.ndarray | Sequence[int]
 
 
 class BaseCovariance:
@@ -63,7 +63,7 @@ class BaseCovariance:
     def __call__(
         self,
        X: TensorLike,
-        Xs: Optional[TensorLike] = None,
+        Xs: TensorLike | None = None,
         diag: bool = False,
     ) -> TensorVariable:
         r"""
@@ -86,7 +86,7 @@ def __call__(
     def diag(self, X: TensorLike) -> TensorVariable:
         raise NotImplementedError
 
-    def full(self, X: TensorLike, Xs: Optional[TensorLike] = None) -> TensorVariable:
+    def full(self, X: TensorLike, Xs: TensorLike | None = None) -> TensorVariable:
         raise NotImplementedError
 
     def __add__(self, other) -> "Add":
@@ -165,7 +165,7 @@ class Covariance(BaseCovariance):
         function operates on.
     """
 
-    def __init__(self, input_dim: int, active_dims: Optional[IntSequence] = None):
+    def __init__(self, input_dim: int, active_dims: IntSequence | None = None):
         self.input_dim = input_dim
         if active_dims is None:
             self.active_dims = np.arange(input_dim)
@@ -256,11 +256,7 @@ def _merge_factors_cov(self, X, Xs=None, diag=False):
             elif isinstance(
                 factor,
-                (
-                    TensorConstant,
-                    TensorVariable,
-                    TensorSharedVariable,
-                ),
+                TensorConstant | TensorVariable | TensorSharedVariable,
             ):
                 if factor.ndim == 2 and diag:
                     factor_list.append(pt.diag(factor))
@@ -318,7 +314,7 @@ class Add(Combination):
     def __call__(
         self,
         X: TensorLike,
-        Xs: Optional[TensorLike] = None,
+        Xs: TensorLike | None = None,
         diag: bool = False,
     ) -> TensorVariable:
         return reduce(add, self._merge_factors_cov(X, Xs, diag))
@@ -331,7 +327,7 @@ class Prod(Combination):
     def __call__(
         self,
         X: TensorLike,
-        Xs: Optional[TensorLike] = None,
+        Xs: TensorLike | None = None,
         diag: bool = False,
     ) -> TensorVariable:
         return reduce(mul, self._merge_factors_cov(X, Xs, diag))
@@ -353,7 +349,7 @@ def __init__(self, kernel: Covariance, power):
         super().__init__(input_dim=self.kernel.input_dim, active_dims=self.kernel.active_dims)
 
     def __call__(
-        self, X: TensorLike, Xs: Optional[TensorLike] = None, diag: bool = False
+        self, X: TensorLike, Xs: TensorLike | None = None, diag: bool = False
     ) -> TensorVariable:
         return self.kernel(X, Xs, diag=diag) ** self.power
@@ -390,7 +386,7 @@ def _split(self, X, Xs):
         return X_split, Xs_split
 
     def __call__(
-        self, X: TensorLike, Xs: Optional[TensorLike] = None, diag: bool = False
+        self, X: TensorLike, Xs: TensorLike | None = None, diag: bool = False
     ) -> TensorVariable:
         X_split, Xs_split = self._split(X, Xs)
         covs = [cov(x, xs, diag) for cov, x, xs in zip(self._factor_list, X_split, Xs_split)]
@@ -412,7 +408,7 @@ def __init__(self, c):
     def diag(self, X: TensorLike) -> TensorVariable:
         return self._alloc(self.c, X.shape[0])
 
-    def full(self, X: TensorLike, Xs: Optional[TensorLike] = None) -> TensorVariable:
+    def full(self, X: TensorLike, Xs: TensorLike | None = None) -> TensorVariable:
         if Xs is None:
             return self._alloc(self.c, X.shape[0], X.shape[0])
         else:
@@ -434,7 +430,7 @@ def __init__(self, sigma):
     def diag(self, X: TensorLike) -> TensorVariable:
         return self._alloc(pt.square(self.sigma), X.shape[0])
 
-    def full(self, X: TensorLike, Xs: Optional[TensorLike] = None) -> TensorVariable:
+    def full(self, X: TensorLike, Xs: TensorLike | None = None) -> TensorVariable:
         if Xs is None:
             return pt.diag(self.diag(X))
         else:
@@ -478,7 +474,7 @@ def __init__(
         input_dim: int,
         period,
         tau=4,
-        active_dims: Optional[IntSequence] = None,
+        active_dims: IntSequence | None = None,
     ):
         super().__init__(input_dim, active_dims)
         self.c = pt.as_tensor_variable(period / 2)
@@ -494,7 +490,7 @@ def dist(self, X, Xs):
     def weinland(self, t):
         return (1 + self.tau * t / self.c) * pt.clip(1 - t / self.c, 0, np.inf) ** self.tau
 
-    def full(self, X: TensorLike, Xs: Optional[TensorLike] = None) -> TensorVariable:
+    def full(self, X: TensorLike, Xs: TensorLike | None = None) -> TensorVariable:
         X, Xs = self._slice(X, Xs)
         return self.weinland(self.dist(X, Xs))
@@ -518,13 +514,13 @@ def __init__(
         input_dim: int,
         ls=None,
         ls_inv=None,
-        active_dims: Optional[IntSequence] = None,
+        active_dims: IntSequence | None = None,
     ):
         super().__init__(input_dim, active_dims)
         if (ls is None and ls_inv is None) or (ls is not None and ls_inv is not None):
             raise ValueError("Only one of 'ls' or 'ls_inv' must be provided")
         elif ls_inv is not None:
-            if isinstance(ls_inv, (list, tuple)):
+            if isinstance(ls_inv, list | tuple):
                 ls = 1.0 / np.asarray(ls_inv)
             else:
                 ls = 1.0 / ls_inv
@@ -555,7 +551,7 @@ def _sqrt(self, r2):
     def diag(self, X: TensorLike) -> TensorVariable:
         return self._alloc(1.0, X.shape[0])
 
-    def full(self, X: TensorLike, Xs: Optional[TensorLike] = None) -> TensorVariable:
+    def full(self, X: TensorLike, Xs: TensorLike | None = None) -> TensorVariable:
         X, Xs = self._slice(X, Xs)
         r2 = self.square_dist(X, Xs)
         return self.full_from_distance(r2, squared=True)
@@ -613,7 +609,7 @@ def __init__(
         alpha,
         ls=None,
         ls_inv=None,
-        active_dims: Optional[IntSequence] = None,
+        active_dims: IntSequence | None = None,
     ):
         super().__init__(input_dim, ls, ls_inv, active_dims)
         self.alpha = alpha
@@ -771,12 +767,12 @@ def __init__(
         period,
         ls=None,
         ls_inv=None,
-        active_dims: Optional[IntSequence] = None,
+        active_dims: IntSequence | None = None,
     ):
         super().__init__(input_dim, ls, ls_inv, active_dims)
         self.period = period
 
-    def full(self, X: TensorLike, Xs: Optional[TensorLike] = None) -> TensorVariable:
+    def full(self, X: TensorLike, Xs: TensorLike | None = None) -> TensorVariable:
         X, Xs = self._slice(X, Xs)
         if Xs is None:
             Xs = X
@@ -823,7 +819,7 @@ class Linear(Covariance):
         k(x, x') = (x - c)(x' - c)
     """
 
-    def __init__(self, input_dim: int, c, active_dims: Optional[IntSequence] = None):
+    def __init__(self, input_dim: int, c, active_dims: IntSequence | None = None):
         super().__init__(input_dim, active_dims)
         self.c = c
@@ -832,7 +828,7 @@ def _common(self, X, Xs=None):
         Xc = pt.sub(X, self.c)
         return X, Xc, Xs
 
-    def full(self, X: TensorLike, Xs: Optional[TensorLike] = None) -> TensorVariable:
+    def full(self, X: TensorLike, Xs: TensorLike | None = None) -> TensorVariable:
         X, Xc, Xs = self._common(X, Xs)
         if Xs is None:
             return pt.dot(Xc, pt.transpose(Xc))
@@ -853,12 +849,12 @@ class Polynomial(Linear):
         k(x, x') = [(x - c)(x' - c) + \mathrm{offset}]^{d}
     """
 
-    def __init__(self, input_dim: int, c, d, offset, active_dims: Optional[IntSequence] = None):
+    def __init__(self, input_dim: int, c, d, offset, active_dims: IntSequence | None = None):
         super().__init__(input_dim, c, active_dims)
         self.d = d
         self.offset = offset
 
-    def full(self, X: TensorLike, Xs: Optional[TensorLike] = None) -> TensorVariable:
+    def full(self, X: TensorLike, Xs: TensorLike | None = None) -> TensorVariable:
         linear = super().full(X, Xs)
         return pt.power(linear + self.offset, self.d)
@@ -890,7 +886,7 @@ def __init__(
         cov_func: Covariance,
         warp_func: Callable,
         args=None,
-        active_dims: Optional[IntSequence] = None,
+        active_dims: IntSequence | None = None,
     ):
         super().__init__(input_dim, active_dims)
         if not callable(warp_func):
@@ -901,7 +897,7 @@ def __init__(
         self.args = args
         self.cov_func = cov_func
 
-    def full(self, X: TensorLike, Xs: Optional[TensorLike] = None) -> TensorVariable:
+    def full(self, X: TensorLike, Xs: TensorLike | None = None) -> TensorVariable:
         X, Xs = self._slice(X, Xs)
         if Xs is None:
             return self.cov_func(self.w(X, self.args), Xs)
@@ -965,7 +961,7 @@ def __init__(self, cov_func: Stationary, period):
         self.cov_func = cov_func
         self.period = period
 
-    def full(self, X: TensorLike, Xs: Optional[TensorLike] = None) -> TensorVariable:
+    def full(self, X: TensorLike, Xs: TensorLike | None = None) -> TensorVariable:
         X, Xs = self._slice(X, Xs)
         if Xs is None:
             Xs = X
@@ -1002,7 +998,7 @@ def __init__(
         input_dim: int,
         lengthscale_func: Callable,
         args=None,
-        active_dims: Optional[IntSequence] = None,
+        active_dims: IntSequence | None = None,
     ):
         super().__init__(input_dim, active_dims)
         if active_dims is not None:
@@ -1029,7 +1025,7 @@ def square_dist(self, X, Xs=None):
         )
         return pt.clip(sqd, 0.0, np.inf)
 
-    def full(self, X: TensorLike, Xs: Optional[TensorLike] = None) -> TensorVariable:
+    def full(self, X: TensorLike, Xs: TensorLike | None = None) -> TensorVariable:
         X, Xs = self._slice(X, Xs)
         rx = self.lfunc(pt.as_tensor_variable(X), self.args)
         if Xs is None:
@@ -1071,7 +1067,7 @@ def __init__(
         cov_func: Covariance,
         scaling_func: Callable,
         args=None,
-        active_dims: Optional[IntSequence] = None,
+        active_dims: IntSequence | None = None,
     ):
         super().__init__(input_dim, active_dims)
         if not callable(scaling_func):
@@ -1088,7 +1084,7 @@ def diag(self, X: TensorLike) -> TensorVariable:
         scf_diag = pt.square(pt.flatten(self.scaling_func(X, self.args)))
         return cov_diag * scf_diag
 
-    def full(self, X: TensorLike, Xs: Optional[TensorLike] = None) -> TensorVariable:
+    def full(self, X: TensorLike, Xs: TensorLike | None = None) -> TensorVariable:
         X, Xs = self._slice(X, Xs)
         scf_x = self.scaling_func(X, self.args)
         if Xs is None:
@@ -1137,7 +1133,7 @@ def __init__(
         W=None,
         kappa=None,
         B=None,
-        active_dims: Optional[IntSequence] = None,
+        active_dims: IntSequence | None = None,
     ):
         super().__init__(input_dim, active_dims)
         if len(self.active_dims) != 1:
@@ -1154,7 +1150,7 @@ def __init__(
         else:
             raise ValueError("Exactly one of (W, kappa) and B must be provided to Coregion")
 
-    def full(self, X: TensorLike, Xs: Optional[TensorLike] = None) -> TensorVariable:
+    def full(self, X: TensorLike, Xs: TensorLike | None = None) -> TensorVariable:
         X, Xs = self._slice(X, Xs)
         index = pt.cast(X, "int32")
         if Xs is None:
diff --git a/pymc/gp/hsgp_approx.py b/pymc/gp/hsgp_approx.py
index 2778fd370e..68255a7f98 100644
--- a/pymc/gp/hsgp_approx.py
+++ b/pymc/gp/hsgp_approx.py
@@ -17,7 +17,6 @@
 
 from collections.abc import Sequence
 from types import ModuleType
-from typing import Optional, Union
 
 import numpy as np
 import pytensor.tensor as pt
@@ -28,10 +27,10 @@
 from pymc.gp.gp import Base
 from pymc.gp.mean import Mean, Zero
 
-TensorLike = Union[np.ndarray, pt.TensorVariable]
+TensorLike = np.ndarray | pt.TensorVariable
 
 
-def set_boundary(Xs: TensorLike, c: Union[numbers.Real, TensorLike]) -> TensorLike:
+def set_boundary(Xs: TensorLike, c: numbers.Real | TensorLike) -> TensorLike:
     """Set the boundary using the mean-subtracted `Xs` and `c`.
 
     `c` is usually a scalar multiplyer greater than 1.0, but it may be one value per dimension
     or column of `Xs`.
""" @@ -176,10 +175,10 @@ class HSGP(Base): def __init__( self, m: Sequence[int], - L: Optional[Sequence[float]] = None, - c: Optional[numbers.Real] = None, + L: Sequence[float] | None = None, + c: numbers.Real | None = None, drop_first: bool = False, - parameterization: Optional[str] = "noncentered", + parameterization: str | None = "noncentered", *, mean_func: Mean = Zero(), cov_func: Covariance, @@ -220,7 +219,7 @@ def __init__( self._drop_first = drop_first self._m = m self._m_star = int(np.prod(self._m)) - self._L: Optional[pt.TensorVariable] = None + self._L: pt.TensorVariable | None = None if L is not None: self._L = pt.as_tensor(L) self._c = c @@ -328,7 +327,7 @@ def prior_linearized(self, Xs: TensorLike): # If not provided, use Xs and c to set L if self._L is None: - assert isinstance(self._c, (numbers.Real, np.ndarray, pt.TensorVariable)) + assert isinstance(self._c, numbers.Real | np.ndarray | pt.TensorVariable) self.L = pt.as_tensor(set_boundary(Xs, self._c)) else: self.L = self._L @@ -341,7 +340,7 @@ def prior_linearized(self, Xs: TensorLike): i = int(self._drop_first is True) return phi[:, i:], pt.sqrt(psd[i:]) - def prior(self, name: str, X: TensorLike, dims: Optional[str] = None): # type: ignore + def prior(self, name: str, X: TensorLike, dims: str | None = None): # type: ignore R""" Returns the (approximate) GP prior distribution evaluated over the input locations `X`. For usage examples, refer to `pm.gp.Latent`. @@ -396,7 +395,7 @@ def _build_conditional(self, Xnew): elif self._parameterization == "centered": return self.mean_func(Xnew) + phi[:, i:] @ beta - def conditional(self, name: str, Xnew: TensorLike, dims: Optional[str] = None): # type: ignore + def conditional(self, name: str, Xnew: TensorLike, dims: str | None = None): # type: ignore R""" Returns the (approximate) conditional distribution evaluated over new input locations `Xnew`. @@ -478,7 +477,7 @@ class HSGPPeriodic(Base): def __init__( self, m: int, - scale: Optional[Union[float, TensorLike]] = 1.0, + scale: float | TensorLike | None = 1.0, *, mean_func: Mean = Zero(), cov_func: Periodic, @@ -589,7 +588,7 @@ def prior_linearized(self, Xs: TensorLike): psd = self.scale * self.cov_func.power_spectral_density_approx(J) return (phi_cos, phi_sin), psd - def prior(self, name: str, X: TensorLike, dims: Optional[str] = None): # type: ignore + def prior(self, name: str, X: TensorLike, dims: str | None = None): # type: ignore R""" Returns the (approximate) GP prior distribution evaluated over the input locations `X`. For usage examples, refer to `pm.gp.Latent`. @@ -640,7 +639,7 @@ def _build_conditional(self, Xnew): phi = phi_cos @ (psd * beta[:m]) + phi_sin[..., 1:] @ (psd[1:] * beta[m:]) return self.mean_func(Xnew) + phi - def conditional(self, name: str, Xnew: TensorLike, dims: Optional[str] = None): # type: ignore + def conditional(self, name: str, Xnew: TensorLike, dims: str | None = None): # type: ignore R""" Returns the (approximate) conditional distribution evaluated over new input locations `Xnew`. 
diff --git a/pymc/gp/util.py b/pymc/gp/util.py index 3f829ab002..ba20130a3d 100644 --- a/pymc/gp/util.py +++ b/pymc/gp/util.py @@ -113,7 +113,7 @@ def kmeans_inducing_points(n_inducing, X, **kmeans_kwargs): # first whiten X if isinstance(X, TensorConstant): X = X.value - elif isinstance(X, (np.ndarray, tuple, list)): + elif isinstance(X, np.ndarray | tuple | list): X = np.asarray(X) else: raise TypeError( diff --git a/pymc/initial_point.py b/pymc/initial_point.py index f9d7855bbc..2e06f51f52 100644 --- a/pymc/initial_point.py +++ b/pymc/initial_point.py @@ -14,8 +14,7 @@ import functools import warnings -from collections.abc import Sequence -from typing import Callable, Optional, Union +from collections.abc import Callable, Sequence import numpy as np import pytensor @@ -29,13 +28,13 @@ from pymc.pytensorf import compile_pymc, find_rng_nodes, replace_rng_nodes, reseed_rngs from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name -StartDict = dict[Union[Variable, str], Union[np.ndarray, Variable, str]] +StartDict = dict[Variable | str, np.ndarray | Variable | str] PointType = dict[str, np.ndarray] def convert_str_to_rv_dict( model, start: StartDict -) -> dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]]: +) -> dict[TensorVariable, np.ndarray | Variable | str | None]: """Helper function for converting a user-provided start dict with str keys of (transformed) variable names to a dict mapping the RV tensors to untransformed initvals. TODO: Deprecate this functionality and only accept TensorVariables as keys @@ -56,8 +55,8 @@ def convert_str_to_rv_dict( def make_initial_point_fns_per_chain( *, model, - overrides: Optional[Union[StartDict, Sequence[Optional[StartDict]]]], - jitter_rvs: Optional[set[TensorVariable]] = None, + overrides: StartDict | Sequence[StartDict | None] | None, + jitter_rvs: set[TensorVariable] | None = None, chains: int, ) -> list[Callable]: """Create an initial point function for each chain, as defined by initvals @@ -112,8 +111,8 @@ def make_initial_point_fns_per_chain( def make_initial_point_fn( *, model, - overrides: Optional[StartDict] = None, - jitter_rvs: Optional[set[TensorVariable]] = None, + overrides: StartDict | None = None, + jitter_rvs: set[TensorVariable] | None = None, default_strategy: str = "support_point", return_transformed: bool = True, ) -> Callable: @@ -179,8 +178,8 @@ def make_initial_point_expression( *, free_rvs: Sequence[TensorVariable], rvs_to_transforms: dict[TensorVariable, Transform], - initval_strategies: dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]], - jitter_rvs: Optional[set[TensorVariable]] = None, + initval_strategies: dict[TensorVariable, np.ndarray | Variable | str | None], + jitter_rvs: set[TensorVariable] | None = None, default_strategy: str = "support_point", return_transformed: bool = False, ) -> list[TensorVariable]: diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py index f35ab4c523..41c92e422d 100644 --- a/pymc/logprob/abstract.py +++ b/pymc/logprob/abstract.py @@ -71,7 +71,7 @@ def _logprob_helper(rv, *values, **kwargs): if (not name) and (len(values) == 1): name = values[0].name if name: - if isinstance(logprob, (list, tuple)): + if isinstance(logprob, list | tuple): for i, term in enumerate(logprob): term.name = f"{name}_logprob.{i}" else: diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py index 446ef59355..52cc97b193 100644 --- a/pymc/logprob/basic.py +++ b/pymc/logprob/basic.py @@ -38,7 +38,7 @@ from collections import deque from 
collections.abc import Sequence -from typing import Optional, Union +from typing import TypeAlias import numpy as np import pytensor.tensor as pt @@ -54,7 +54,6 @@ from pytensor.graph.op import compute_test_value from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter from pytensor.tensor.variable import TensorVariable -from typing_extensions import TypeAlias from pymc.logprob.abstract import ( MeasurableVariable, @@ -69,7 +68,7 @@ from pymc.logprob.utils import rvs_in_graph from pymc.pytensorf import replace_vars_in_graphs -TensorLike: TypeAlias = Union[Variable, float, np.ndarray] +TensorLike: TypeAlias = Variable | float | np.ndarray def _find_unallowed_rvs_in_graph(graph): @@ -79,11 +78,11 @@ def _find_unallowed_rvs_in_graph(graph): return { rv for rv in rvs_in_graph(graph) - if not isinstance(rv.owner.op, (SimulatorRV, MinibatchIndexRV)) + if not isinstance(rv.owner.op, SimulatorRV | MinibatchIndexRV) } -def _warn_rvs_in_inferred_graph(graph: Union[TensorVariable, Sequence[TensorVariable]]): +def _warn_rvs_in_inferred_graph(graph: TensorVariable | Sequence[TensorVariable]): """Issue warning if any RVs are found in graph. RVs are usually an (implicit) conditional input of the derived probability expression, @@ -410,8 +409,8 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Tens def conditional_logp( rv_values: dict[TensorVariable, TensorVariable], warn_rvs=None, - ir_rewriter: Optional[GraphRewriter] = None, - extra_rewrites: Optional[Union[GraphRewriter, NodeRewriter]] = None, + ir_rewriter: GraphRewriter | None = None, + extra_rewrites: GraphRewriter | NodeRewriter | None = None, **kwargs, ) -> dict[TensorVariable, TensorVariable]: r"""Create a map between variables and conditional log-probabilities @@ -546,7 +545,7 @@ def conditional_logp( **kwargs, ) - if not isinstance(q_logprob_vars, (list, tuple)): + if not isinstance(q_logprob_vars, list | tuple): q_logprob_vars = [q_logprob_vars] for q_value_var, q_logprob_var in zip(q_values, q_logprob_vars): diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index f5d8cf848c..df37e33782 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional

 import numpy as np
 import pytensor.tensor as pt
@@ -40,10 +39,8 @@ class MeasurableComparison(MeasurableElemwise):

 @node_rewriter(tracks=[gt, lt, ge, le])
-def find_measurable_comparisons(
-    fgraph: FunctionGraph, node: Node
-) -> Optional[list[TensorVariable]]:
-    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+def find_measurable_comparisons(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
+    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
     if rv_map_feature is None:
         return None  # pragma: no cover
@@ -105,9 +102,9 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs):

     condn_exp = pt.eq(value, np.array(True))

-    if isinstance(op.scalar_op, (GT, GE)):
+    if isinstance(op.scalar_op, GT | GE):
         logprob = pt.switch(condn_exp, logccdf, logcdf)
-    elif isinstance(op.scalar_op, (LT, LE)):
+    elif isinstance(op.scalar_op, LT | LE):
         logprob = pt.switch(condn_exp, logcdf, logccdf)
     else:
         raise TypeError(f"Unsupported scalar_op {op.scalar_op}")
@@ -134,8 +131,8 @@ class MeasurableBitwise(MeasurableElemwise):

 @node_rewriter(tracks=[invert])
-def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]:
-    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
+    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
     if rv_map_feature is None:
         return None  # pragma: no cover
diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py
index b9221e08db..d582da0799 100644
--- a/pymc/logprob/censoring.py
+++ b/pymc/logprob/censoring.py
@@ -34,7 +34,6 @@
 #   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 #   SOFTWARE.

-from typing import Optional

 import numpy as np
 import pytensor.tensor as pt
@@ -63,10 +62,10 @@ class MeasurableClip(MeasurableElemwise):

 @node_rewriter(tracks=[clip])
-def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]:
+def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
     # TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

-    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
     if rv_map_feature is None:
         return None  # pragma: no cover
@@ -158,8 +157,8 @@ class MeasurableRound(MeasurableElemwise):

 @node_rewriter(tracks=[ceil, floor, round_half_to_even])
-def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]:
-    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
+    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
     if rv_map_feature is None:
         return None  # pragma: no cover
diff --git a/pymc/logprob/checks.py b/pymc/logprob/checks.py
index 1cf202ec5e..f7b483e599 100644
--- a/pymc/logprob/checks.py
+++ b/pymc/logprob/checks.py
@@ -34,7 +34,6 @@
 #   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 #   SOFTWARE.
-from typing import Optional

 import pytensor.tensor as pt
@@ -64,13 +63,13 @@ def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs):

 @node_rewriter([SpecifyShape])
-def find_measurable_specify_shapes(fgraph, node) -> Optional[list[TensorVariable]]:
+def find_measurable_specify_shapes(fgraph, node) -> list[TensorVariable] | None:
     r"""Finds `SpecifyShapeOp`\s for which a `logprob` can be computed."""

     if isinstance(node.op, MeasurableSpecifyShape):
         return None  # pragma: no cover

-    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)

     if rv_map_feature is None:
         return None  # pragma: no cover
@@ -117,13 +116,13 @@ def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs):

 @node_rewriter([CheckAndRaise])
-def find_measurable_check_and_raise(fgraph, node) -> Optional[list[TensorVariable]]:
+def find_measurable_check_and_raise(fgraph, node) -> list[TensorVariable] | None:
     r"""Finds `AssertOp`\s for which a `logprob` can be computed."""

     if isinstance(node.op, MeasurableCheckAndRaise):
         return None  # pragma: no cover

-    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
     if rv_map_feature is None:
         return None  # pragma: no cover
diff --git a/pymc/logprob/cumsum.py b/pymc/logprob/cumsum.py
index 810f226c8b..3c1c9d3e72 100644
--- a/pymc/logprob/cumsum.py
+++ b/pymc/logprob/cumsum.py
@@ -34,7 +34,6 @@
 #   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 #   SOFTWARE.

-from typing import Optional

 import pytensor.tensor as pt
@@ -78,7 +77,7 @@ def logprob_cumsum(op, values, base_rv, **kwargs):

 @node_rewriter([CumOp])
-def find_measurable_cumsums(fgraph, node) -> Optional[list[TensorVariable]]:
+def find_measurable_cumsums(fgraph, node) -> list[TensorVariable] | None:
     r"""Finds `Cumsums`\s for which a `logprob` can be computed."""

     if not (isinstance(node.op, CumOp) and node.op.mode == "add"):
@@ -87,7 +86,7 @@ def find_measurable_cumsums(fgraph, node) -> Optional[list[TensorVariable]]:
     if isinstance(node.op, MeasurableCumsum):
         return None  # pragma: no cover

-    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
     if rv_map_feature is None:
         return None  # pragma: no cover
diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py
index 011ce5e5fe..08e102f805 100644
--- a/pymc/logprob/mixture.py
+++ b/pymc/logprob/mixture.py
@@ -34,7 +34,7 @@
 #   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 #   SOFTWARE.

-from typing import Optional, Union, cast
+from typing import cast

 import pytensor
 import pytensor.tensor as pt
@@ -87,7 +87,7 @@ def is_newaxis(x):

 def expand_indices(
-    indices: tuple[Optional[Union[Variable, slice]], ...], shape: tuple[TensorVariable]
+    indices: tuple[Variable | slice | None, ...], shape: tuple[TensorVariable]
 ) -> tuple[TensorVariable]:
     """Convert basic and/or advanced indices into a single, broadcasted advanced indexing operation.
@@ -240,7 +240,7 @@ def perform(self, node, inputs, outputs):

 def get_stack_mixture_vars(
     node: Apply,
-) -> tuple[Optional[list[TensorVariable]], Optional[int]]:
+) -> tuple[list[TensorVariable] | None, int | None]:
     r"""Extract the mixture terms from a `*Subtensor*` applied to stacked `MeasurableVariable`\s."""
     assert isinstance(node.op, subtensor_ops)
@@ -248,7 +248,7 @@ def get_stack_mixture_vars(
     joined_rvs = node.inputs[0]

     # First, make sure that it's some sort of concatenation
-    if not (joined_rvs.owner and isinstance(joined_rvs.owner.op, (MakeVector, Join))):
+    if not (joined_rvs.owner and isinstance(joined_rvs.owner.op, MakeVector | Join)):
         return None, None

     if isinstance(joined_rvs.owner.op, MakeVector):
@@ -276,7 +276,7 @@ def find_measurable_index_mixture(fgraph, node):
     From these terms, new terms ``Z_rv[i] = mixture_comps[i][i == I_rv]`` are created
     for each ``i`` in ``enumerate(mixture_comps)``.
     """
-    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)

     if rv_map_feature is None:
         return None  # pragma: no cover
@@ -284,7 +284,7 @@ def find_measurable_index_mixture(fgraph, node):
     mixing_indices = node.inputs[1:]

     # TODO: Add check / test case for Advanced Boolean indexing
-    if isinstance(node.op, (AdvancedSubtensor, AdvancedSubtensor1)):
+    if isinstance(node.op, AdvancedSubtensor | AdvancedSubtensor1):
         # We don't support (non-scalar) integer array indexing as it can pick repeated values,
         # but the Mixture logprob assumes all mixture values are independent
         if any(
@@ -298,7 +298,7 @@ def find_measurable_index_mixture(fgraph, node):
     mixture_rvs, join_axis = get_stack_mixture_vars(node)

     # We don't support symbolic join axis
-    if mixture_rvs is None or not isinstance(join_axis, (NoneTypeT, Constant)):
+    if mixture_rvs is None or not isinstance(join_axis, NoneTypeT | Constant):
         return None

     if rv_map_feature.request_measurable(mixture_rvs) != mixture_rvs:
@@ -326,9 +326,7 @@ def find_measurable_index_mixture(fgraph, node):

 @_logprob.register(MixtureRV)
-def logprob_MixtureRV(
-    op, values, *inputs: Optional[Union[TensorVariable, slice]], name=None, **kwargs
-):
+def logprob_MixtureRV(op, values, *inputs: TensorVariable | slice | None, name=None, **kwargs):
     (value,) = values

     join_axis = cast(Variable, inputs[0])
@@ -408,7 +406,7 @@ class MeasurableSwitchMixture(MeasurableElemwise):

 @node_rewriter([switch])
 def find_measurable_switch_mixture(fgraph, node):
-    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)

     if rv_map_feature is None:
         return None  # pragma: no cover
@@ -499,7 +497,7 @@ def useless_ifelse_outputs(fgraph, node):

 @node_rewriter([IfElse])
 def find_measurable_ifelse_mixture(fgraph, node):
-    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)

     if rv_map_feature is None:
         return None  # pragma: no cover
diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py
index 0dc78d0b0d..b46562c82d 100644
--- a/pymc/logprob/order.py
+++ b/pymc/logprob/order.py
@@ -34,7 +34,6 @@
 #   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 #   SOFTWARE.
-from typing import Optional

 import pytensor.tensor as pt
@@ -73,7 +72,7 @@ class MeasurableMaxDiscrete(Max):

 @node_rewriter([Max])
-def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]:
+def find_measurable_max(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
     rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
     if rv_map_feature is None:
         return None  # pragma: no cover
@@ -174,7 +173,7 @@ class MeasurableDiscreteMaxNeg(Max):

 @node_rewriter(tracks=[Max])
-def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]:
+def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
     rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

     if rv_map_feature is None:
diff --git a/pymc/logprob/rewriting.py b/pymc/logprob/rewriting.py
index 3e202ae45a..055516d197 100644
--- a/pymc/logprob/rewriting.py
+++ b/pymc/logprob/rewriting.py
@@ -37,7 +37,6 @@

 from collections import deque
 from collections.abc import Sequence
-from typing import Optional

 import pytensor.tensor as pt
@@ -101,7 +100,7 @@ class MeasurableEquilibriumGraphRewriter(EquilibriumGraphRewriter):
     """

     def apply(self, fgraph):
-        rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+        rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
         if not rv_map_feature:
             return None
@@ -232,7 +231,7 @@ def update_rv_maps(
         self,
         old_rv: TensorVariable,
         new_value: TensorVariable,
-        new_rv: Optional[TensorVariable] = None,
+        new_rv: TensorVariable | None = None,
     ):
         """Update mappings for a random variable.
@@ -333,7 +332,7 @@ def incsubtensor_rv_replace(fgraph, node):

     This provides a means of specifying "missing data", for instance.
     """
-    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)

     if rv_map_feature is None:
         return None  # pragma: no cover
@@ -399,7 +398,7 @@ def incsubtensor_rv_replace(fgraph, node):

 def construct_ir_fgraph(
     rv_values: dict[Variable, Variable],
-    ir_rewriter: Optional[GraphRewriter] = None,
+    ir_rewriter: GraphRewriter | None = None,
 ) -> tuple[FunctionGraph, dict[Variable, Variable], dict[Variable, Variable]]:
     r"""Construct a `FunctionGraph` in measurable IR form for the keys in `rv_values`.
diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py
index 44ac31a0c3..84b2722b1a 100644
--- a/pymc/logprob/scan.py
+++ b/pymc/logprob/scan.py
@@ -34,9 +34,9 @@
 #   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 #   SOFTWARE.

-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
 from copy import copy
-from typing import Callable, Optional, cast
+from typing import cast

 import numpy as np
 import pytensor
@@ -365,7 +365,7 @@ def find_measurable_scans(fgraph, node):
     if not hasattr(fgraph, "shape_feature"):
         return None  # pragma: no cover

-    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)

     if rv_map_feature is None:
         return None  # pragma: no cover
diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py
index 9cbf456b7b..2c1443d293 100644
--- a/pymc/logprob/tensor.py
+++ b/pymc/logprob/tensor.py
@@ -34,7 +34,6 @@
 #   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 #   SOFTWARE.
-from typing import Optional, Union

 import pytensor
@@ -89,7 +88,7 @@ def naive_bcast_rv_lift(fgraph, node):
         return None  # pragma: no cover

     # Do not replace RV if it is associated with a value variable
-    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
     if rv_map_feature is not None and rv_var in rv_map_feature.rv_values:
         return None
@@ -198,10 +197,10 @@ def logprob_join(op, values, axis, *base_rvs, **kwargs):

 @node_rewriter([MakeVector, Join])
-def find_measurable_stacks(fgraph, node) -> Optional[list[TensorVariable]]:
+def find_measurable_stacks(fgraph, node) -> list[TensorVariable] | None:
     r"""Finds `Joins`\s and `MakeVector`\s for which a `logprob` can be computed."""

-    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)

     if rv_map_feature is None:
         return None  # pragma: no cover
@@ -247,7 +246,7 @@ def logprob_dimshuffle(op, values, base_var, **kwargs):

     # Reverse the effects of dimshuffle on the value variable
     # First, drop any augmented dimensions and reinsert any dropped dimensions
-    undo_ds: list[Union[int, str]] = [i for i, o in enumerate(op.new_order) if o != "x"]
+    undo_ds: list[int | str] = [i for i, o in enumerate(op.new_order) if o != "x"]
     dropped_dims = tuple(sorted(set(op.transposition) - set(op.shuffle)))
     for dropped_dim in dropped_dims:
         undo_ds.insert(dropped_dim, "x")
@@ -272,10 +271,10 @@ def logprob_dimshuffle(op, values, base_var, **kwargs):

 @node_rewriter([DimShuffle])
-def find_measurable_dimshuffles(fgraph, node) -> Optional[list[TensorVariable]]:
+def find_measurable_dimshuffles(fgraph, node) -> list[TensorVariable] | None:
     r"""Finds `Dimshuffle`\s for which a `logprob` can be computed."""

-    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)

     if rv_map_feature is None:
         return None  # pragma: no cover
diff --git a/pymc/logprob/transform_value.py b/pymc/logprob/transform_value.py
index 966d4b069a..ee043219cc 100644
--- a/pymc/logprob/transform_value.py
+++ b/pymc/logprob/transform_value.py
@@ -13,7 +13,6 @@
 #   limitations under the License.

 from collections.abc import Sequence
-from typing import Optional, Union

 import numpy as np
@@ -145,7 +144,7 @@ def transformed_value_logprob(op, values, *rv_outs, use_jacobian=True, **kwargs)

 @node_rewriter(tracks=None)
-def transform_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[Apply]]:
+def transform_values(fgraph: FunctionGraph, node: Apply) -> list[Apply] | None:
     """Apply transforms to value variables.

     It is assumed that the input value variables correspond to forward
@@ -157,8 +156,8 @@ def transform_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[Apply]
     ``Y`` on the natural scale.
""" - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - values_to_transforms: Optional[TransformValuesMapping] = getattr( + rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None) + values_to_transforms: TransformValuesMapping | None = getattr( fgraph, "values_to_transforms", None ) @@ -213,7 +212,7 @@ def transform_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[Apply] @node_rewriter(tracks=[Scan]) -def transform_scan_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[Apply]]: +def transform_scan_values(fgraph: FunctionGraph, node: Apply) -> list[Apply] | None: """Apply transforms to Scan value variables. This specialized rewrite is needed because Scan replaces the original value variables @@ -221,8 +220,8 @@ def transform_scan_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[A in this subgraph, leaving the rest intact """ - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - values_to_transforms: Optional[TransformValuesMapping] = getattr( + rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None) + values_to_transforms: TransformValuesMapping | None = getattr( fgraph, "values_to_transforms", None ) @@ -320,7 +319,7 @@ class TransformValuesRewrite(GraphRewriter): def __init__( self, - values_to_transforms: dict[TensorVariable, Union[Transform, None]], + values_to_transforms: dict[TensorVariable, Transform | None], ): """ Parameters diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 3702a97550..ed0cfd7960 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -35,7 +35,7 @@ # SOFTWARE. import abc -from typing import Callable, Optional, Union +from collections.abc import Callable import numpy as np import pytensor.tensor as pt @@ -134,7 +134,7 @@ def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable: @abc.abstractmethod def backward( self, value: TensorVariable, *inputs: Variable - ) -> Union[TensorVariable, tuple[TensorVariable, ...]]: + ) -> TensorVariable | tuple[TensorVariable, ...]: """Invert the transformation. 
         Multiple values may be returned when the transformation is not 1-to-1"""
@@ -423,14 +423,14 @@ def measurable_power_exponent_to_exp(fgraph, node):
         erfcx,
     ]
 )
-def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[list[Node]]:
+def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> list[Node] | None:
     """Find measurable transformations from Elemwise operators."""

     # Node was already converted
     if isinstance(node.op, MeasurableVariable):
         return None  # pragma: no cover

-    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
     if rv_map_feature is None:
         return None  # pragma: no cover
@@ -779,7 +779,7 @@ class PowerTransform(Transform):
     name = "power"

     def __init__(self, power=None):
-        if not isinstance(power, (int, float)):
+        if not isinstance(power, int | float):
             raise TypeError(f"Power must be integer or float, got {type(power)}")
         if power == 0:
             raise ValueError("Power cannot be 0")
@@ -821,7 +821,7 @@ def log_jac_det(self, value, *inputs):
 class IntervalTransform(Transform):
     name = "interval"

-    def __init__(self, args_fn: Callable[..., tuple[Optional[Variable], Optional[Variable]]]):
+    def __init__(self, args_fn: Callable[..., tuple[Variable | None, Variable | None]]):
         """
         Parameters
diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py
index 49827f7a61..4c2db767c1 100644
--- a/pymc/logprob/utils.py
+++ b/pymc/logprob/utils.py
@@ -37,7 +37,6 @@

 import warnings

 from collections.abc import Container, Sequence
-from typing import Optional, Union

 import numpy as np
 import pytensor
@@ -68,7 +67,7 @@ def replace_rvs_by_values(
     graphs: Sequence[TensorVariable],
     *,
     rvs_to_values: dict[TensorVariable, TensorVariable],
-    rvs_to_transforms: Optional[dict[TensorVariable, "Transform"]] = None,
+    rvs_to_transforms: dict[TensorVariable, "Transform"] | None = None,
 ) -> list[TensorVariable]:
     """Clone and replace random variables in graphs with their value variables.
@@ -132,7 +131,7 @@ def populate_replacements(var):
     return replace_vars_in_graphs(graphs, replacements)


-def rvs_in_graph(vars: Union[Variable, Sequence[Variable]]) -> set[Variable]:
+def rvs_in_graph(vars: Variable | Sequence[Variable]) -> set[Variable]:
     """Assert that there are no `MeasurableVariable` nodes in a graph."""

     def expand(r):
@@ -148,7 +147,7 @@ def expand(r):
     return {
         node
         for node in walk(makeiter(vars), expand, False)
-        if node.owner and isinstance(node.owner.op, (RandomVariable, MeasurableVariable))
+        if node.owner and isinstance(node.owner.op, RandomVariable | MeasurableVariable)
     }
diff --git a/pymc/model/core.py b/pymc/model/core.py
index 0ca9e57c26..cac340f7f4 100644
--- a/pymc/model/core.py
+++ b/pymc/model/core.py
@@ -18,16 +18,14 @@
 import types
 import warnings

-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
 from sys import modules
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
     Literal,
     Optional,
     TypeVar,
-    Union,
     cast,
 )
@@ -132,18 +130,18 @@ def __exit__(self, typ, value, traceback):

     # FIXME: is there a more elegant way to automatically add methods to the class that
     #        are instance methods instead of class methods?
-    def __init__(cls, name, bases, nmspc, context_class: Optional[type] = None, **kwargs):
+    def __init__(cls, name, bases, nmspc, context_class: type | None = None, **kwargs):
         """Add ``__enter__`` and ``__exit__`` methods to the new class automatically."""
         if context_class is not None:
             cls._context_class = context_class
         super().__init__(name, bases, nmspc)

-    def get_context(cls, error_if_none=True, allow_block_model_access=False) -> Optional[T]:
+    def get_context(cls, error_if_none=True, allow_block_model_access=False) -> T | None:
         """Return the most recently pushed context object of type ``cls``
         on the stack, or ``None``. If ``error_if_none`` is True (default),
         raise a ``TypeError`` instead of returning ``None``."""
         try:
-            candidate: Optional[T] = cls.get_contexts()[-1]
+            candidate: T | None = cls.get_contexts()[-1]
         except IndexError:
             # Calling code expects to get a TypeError if the entity
             # is unfound, and there's too much to fix.
@@ -184,7 +182,7 @@ def get_contexts(cls) -> list[T]:
     # than a class.
     @property
     def context_class(cls) -> type:
-        def resolve_type(c: Union[type, str]) -> type:
+        def resolve_type(c: type | str) -> type:
             if isinstance(c, str):
                 c = getattr(modules[cls.__module__], c)
             if isinstance(c, type):
@@ -194,7 +192,7 @@ def resolve_type(c: Union[type, str]) -> type:
         assert cls is not None
         if isinstance(cls._context_class, str):
             cls._context_class = resolve_type(cls._context_class)
-        if not isinstance(cls._context_class, (str, type)):
+        if not isinstance(cls._context_class, str | type):
             raise ValueError(
                 f"Context class for {cls.__name__}, {cls._context_class}, is not of the right type"
             )
@@ -615,7 +613,7 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):

     def compile_logp(
         self,
-        vars: Optional[Union[Variable, Sequence[Variable]]] = None,
+        vars: Variable | Sequence[Variable] | None = None,
         jacobian: bool = True,
         sum: bool = True,
         **compile_kwargs,
@@ -637,7 +635,7 @@ def compile_logp(

     def compile_dlogp(
         self,
-        vars: Optional[Union[Variable, Sequence[Variable]]] = None,
+        vars: Variable | Sequence[Variable] | None = None,
         jacobian: bool = True,
         **compile_kwargs,
     ) -> PointFunc:
@@ -655,7 +653,7 @@ def compile_dlogp(

     def compile_d2logp(
         self,
-        vars: Optional[Union[Variable, Sequence[Variable]]] = None,
+        vars: Variable | Sequence[Variable] | None = None,
         jacobian: bool = True,
         **compile_kwargs,
     ) -> PointFunc:
@@ -673,10 +671,10 @@ def compile_d2logp(

     def logp(
         self,
-        vars: Optional[Union[Variable, Sequence[Variable]]] = None,
+        vars: Variable | Sequence[Variable] | None = None,
         jacobian: bool = True,
         sum: bool = True,
-    ) -> Union[Variable, list[Variable]]:
+    ) -> Variable | list[Variable]:
         """Elemwise log-probability of the model.

         Parameters
@@ -697,7 +695,7 @@ def logp(
         varlist: list[TensorVariable]
         if vars is None:
             varlist = self.free_RVs + self.observed_RVs + self.potentials
-        elif not isinstance(vars, (list, tuple)):
+        elif not isinstance(vars, list | tuple):
             varlist = [vars]
         else:
             varlist = cast(list[TensorVariable], vars)
@@ -752,7 +750,7 @@ def logp(

     def dlogp(
         self,
-        vars: Optional[Union[Variable, Sequence[Variable]]] = None,
+        vars: Variable | Sequence[Variable] | None = None,
         jacobian: bool = True,
     ) -> Variable:
         """Gradient of the models log-probability w.r.t. ``vars``.
@@ -772,7 +770,7 @@ def dlogp(
         if vars is None:
             value_vars = None
         else:
-            if not isinstance(vars, (list, tuple)):
+            if not isinstance(vars, list | tuple):
                 vars = [vars]

             value_vars = []
@@ -791,7 +789,7 @@ def dlogp(

     def d2logp(
         self,
-        vars: Optional[Union[Variable, Sequence[Variable]]] = None,
+        vars: Variable | Sequence[Variable] | None = None,
         jacobian: bool = True,
     ) -> Variable:
         """Hessian of the models log-probability w.r.t. ``vars``.
@@ -811,7 +809,7 @@ def d2logp(
         if vars is None:
             value_vars = None
         else:
-            if not isinstance(vars, (list, tuple)):
+            if not isinstance(vars, list | tuple):
                 vars = [vars]

             value_vars = []
@@ -926,7 +924,7 @@ def unobserved_RVs(self):
         return self.free_RVs + self.deterministics

     @property
-    def coords(self) -> dict[str, Union[tuple, None]]:
+    def coords(self) -> dict[str, tuple | None]:
         """Coordinate values for model dimensions."""
         return self._coords
@@ -956,10 +954,10 @@ def shape_from_dims(self, dims):
     def add_coord(
         self,
         name: str,
-        values: Optional[Sequence] = None,
-        mutable: Optional[bool] = None,
+        values: Sequence | None = None,
+        mutable: bool | None = None,
         *,
-        length: Optional[Union[int, Variable]] = None,
+        length: int | Variable | None = None,
     ):
         """Registers a dimension coordinate with the model.
@@ -1000,7 +998,7 @@ def add_coord(
         if name in self.coords:
             if not np.array_equal(values, self.coords[name]):
                 raise ValueError(f"Duplicate and incompatible coordinate: {name}.")
-        if length is not None and not isinstance(length, (int, Variable)):
+        if length is not None and not isinstance(length, int | Variable):
             raise ValueError(
                 f"The `length` passed for the '{name}' coord must be an int, PyTensor Variable or None."
             )
@@ -1014,9 +1012,9 @@ def add_coord(

     def add_coords(
         self,
-        coords: dict[str, Optional[Sequence]],
+        coords: dict[str, Sequence | None],
         *,
-        lengths: Optional[dict[str, Optional[Union[int, Variable]]]] = None,
+        lengths: dict[str, int | Variable | None] | None = None,
     ):
         """Vectorized version of ``Model.add_coord``."""
         if coords is None:
@@ -1026,7 +1024,7 @@ def add_coords(
         for name, values in coords.items():
             self.add_coord(name, values, length=lengths.get(name, None))

-    def set_dim(self, name: str, new_length: int, coord_values: Optional[Sequence] = None):
+    def set_dim(self, name: str, new_length: int, coord_values: Sequence | None = None):
         """Update a mutable dimension.

         Parameters
@@ -1072,7 +1070,7 @@ def initial_point(self, random_seed: SeedSequenceSeed = None) -> dict[str, np.nd
     def set_initval(self, rv_var, initval):
         """Sets an initial value (strategy) for a random variable."""
-        if initval is not None and not isinstance(initval, (Variable, str)):
+        if initval is not None and not isinstance(initval, Variable | str):
             # Convert scalars or array-like inputs to ndarrays
             initval = rv_var.type.filter(initval)
@@ -1081,8 +1079,8 @@ def set_initval(self, rv_var, initval):
     def set_data(
         self,
         name: str,
-        values: Union[Sequence, np.ndarray],
-        coords: Optional[dict[str, Sequence]] = None,
+        values: Sequence | np.ndarray,
+        coords: dict[str, Sequence] | None = None,
     ):
         """Changes the values of a data variable in the model.
@@ -1288,8 +1286,8 @@ def make_obs_var(
         rv_var: TensorVariable,
         data: np.ndarray,
         dims,
-        transform: Union[Any, None],
-        total_size: Union[int, None],
+        transform: Any | None,
+        total_size: int | None,
     ) -> TensorVariable:
         """Create a `TensorVariable` for an observed random variable.
@@ -1371,7 +1369,7 @@ def make_obs_var(
         return rv_var

     def create_value_var(
-        self, rv_var: TensorVariable, transform: Any, value_var: Optional[Variable] = None
+        self, rv_var: TensorVariable, transform: Any, value_var: Variable | None = None
     ) -> TensorVariable:
         """Create a ``TensorVariable`` that will be used as the random
         variable's "value" in log-likelihood graphs.
@@ -1429,7 +1427,7 @@ def create_value_var(

         return value_var

-    def add_named_variable(self, var, dims: Optional[tuple[Union[str, None], ...]] = None):
+    def add_named_variable(self, var, dims: tuple[str | None, ...] | None = None):
         """Add a random graph variable to the named variables of the model.

         This can include several types of variables such basic_RVs, Data, Deterministics,
@@ -1528,13 +1526,13 @@ def replace_rvs_by_values(

     def compile_fn(
         self,
-        outs: Union[Variable, Sequence[Variable]],
+        outs: Variable | Sequence[Variable],
         *,
-        inputs: Optional[Sequence[Variable]] = None,
+        inputs: Sequence[Variable] | None = None,
         mode=None,
         point_fn: bool = True,
         **kwargs,
-    ) -> Union[PointFunc, Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]]:
+    ) -> PointFunc | Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]:
         """Compiles an PyTensor function

         Parameters
@@ -1724,7 +1722,7 @@ def point_logps(self, point=None, round_vals=2):

     def debug(
         self,
-        point: Optional[dict[str, np.ndarray]] = None,
+        point: dict[str, np.ndarray] | None = None,
         fn: Literal["logp", "dlogp", "random"] = "logp",
         verbose: bool = False,
     ):
@@ -1871,10 +1869,10 @@ def debug_parameters(rv):
     def to_graphviz(
         self,
         *,
-        var_names: Optional[Iterable[VarName]] = None,
+        var_names: Iterable[VarName] | None = None,
         formatting: str = "plain",
-        save: Optional[str] = None,
-        figsize: Optional[tuple[int, int]] = None,
+        save: str | None = None,
+        figsize: tuple[int, int] | None = None,
         dpi: int = 300,
     ):
         """Produce a graphviz Digraph from a PyMC model.
@@ -2039,14 +2037,14 @@ def set_data(new_data, model=None, *, coords=None):

 def compile_fn(
-    outs: Union[Variable, Sequence[Variable]],
+    outs: Variable | Sequence[Variable],
     *,
-    inputs: Optional[Sequence[Variable]] = None,
+    inputs: Sequence[Variable] | None = None,
     mode=None,
     point_fn: bool = True,
-    model: Optional[Model] = None,
+    model: Model | None = None,
     **kwargs,
-) -> Union[PointFunc, Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]]:
+) -> PointFunc | Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]:
     """Compiles an PyTensor function

     Parameters
diff --git a/pymc/model/fgraph.py b/pymc/model/fgraph.py
index 49d7d9cbe4..05d8fe4200 100644
--- a/pymc/model/fgraph.py
+++ b/pymc/model/fgraph.py
@@ -12,7 +12,6 @@
 #   See the License for the specific language governing permissions and
 #   limitations under the License.
 from copy import copy, deepcopy
-from typing import Optional

 import pytensor
@@ -58,7 +57,7 @@ def perform(self, *args, **kwargs):
 class ModelValuedVar(ModelVar):
     __props__ = ("transform",)

-    def __init__(self, transform: Optional[Transform] = None):
+    def __init__(self, transform: Transform | None = None):
         if transform is not None and not isinstance(transform, Transform):
             raise TypeError(f"transform must be None or RVTransform type, got {type(transform)}")
         self.transform = transform
@@ -261,7 +260,7 @@ def fgraph_from_model(
     inverse_memo = {v: k for k, v in memo.items()}
     for var, model_var in replacements:
         if not inlined_views and (
-            model_var.owner and isinstance(model_var.owner.op, (ModelDeterministic, ModelNamed))
+            model_var.owner and isinstance(model_var.owner.op, ModelDeterministic | ModelNamed)
         ):
             # Ignore extra identity that will be removed at the end
             var = var.owner.inputs[0]
diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py
index 0ef83397d5..994e99dbf8 100644
--- a/pymc/model/transform/basic.py
+++ b/pymc/model/transform/basic.py
@@ -12,7 +12,6 @@
 #   See the License for the specific language governing permissions and
 #   limitations under the License.
 from collections.abc import Sequence
-from typing import Union

 from pytensor import Variable
 from pytensor.graph import ancestors
@@ -25,7 +24,7 @@
     model_from_fgraph,
 )

-ModelVariable = Union[Variable, str]
+ModelVariable = Variable | str


 def prune_vars_detached_from_observed(model: Model) -> Model:
@@ -54,8 +53,8 @@ def prune_vars_detached_from_observed(model: Model) -> Model:
     return model_from_fgraph(fgraph)


-def parse_vars(model: Model, vars: Union[ModelVariable, Sequence[ModelVariable]]) -> list[Variable]:
-    if isinstance(vars, (list, tuple)):
+def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> list[Variable]:
+    if isinstance(vars, list | tuple):
         vars_seq = vars
     else:
         vars_seq = (vars,)
diff --git a/pymc/model/transform/conditioning.py b/pymc/model/transform/conditioning.py
index b321007c68..8fc184a96c 100644
--- a/pymc/model/transform/conditioning.py
+++ b/pymc/model/transform/conditioning.py
@@ -14,7 +14,7 @@
 import warnings

 from collections.abc import Mapping, Sequence
-from typing import Any, Optional, Union
+from typing import Any, Union

 from pytensor.graph import ancestors
 from pytensor.tensor import TensorVariable
@@ -106,7 +106,7 @@ def observe(
         model_var = memo[var]

         # Just a sanity check
-        assert isinstance(model_var.owner.op, (ModelFreeRV, ModelDeterministic))
+        assert isinstance(model_var.owner.op, ModelFreeRV | ModelDeterministic)
         assert model_var in fgraph.variables

         var = model_var.owner.inputs[0]
@@ -223,7 +223,7 @@ def do(

 def change_value_transforms(
     model: Model,
-    vars_to_transforms: Mapping[ModelVariable, Union[Transform, None]],
+    vars_to_transforms: Mapping[ModelVariable, Transform | None],
 ) -> Model:
     """Change the value variables transforms in the model
@@ -307,7 +307,7 @@ def change_value_transforms(

 def remove_value_transforms(
     model: Model,
-    vars: Optional[Sequence[ModelVariable]] = None,
+    vars: Sequence[ModelVariable] | None = None,
 ) -> Model:
     """Remove the value variables transforms in the model
diff --git a/pymc/model_graph.py b/pymc/model_graph.py
index d9189d552f..e29d371f04 100644
--- a/pymc/model_graph.py
+++ b/pymc/model_graph.py
@@ -16,7 +16,6 @@
 from collections import defaultdict
 from collections.abc import Iterable, Sequence
 from os import path
-from typing import Optional

 from pytensor import function
 from pytensor.graph import Apply
@@ -83,7 +82,7 @@ def _expand(x):

         return parents

-    def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> list[VarName]:
+    def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarName]:
         if var_names is None:
             return self._all_var_names
@@ -114,7 +113,7 @@ def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> list[Va
         return [get_var_name(var) for var in selected_ancestors]

     def make_compute_graph(
-        self, var_names: Optional[Iterable[VarName]] = None
+        self, var_names: Iterable[VarName] | None = None
     ) -> dict[VarName, set[VarName]]:
         """Get map of var_name -> set(input var names) for the model"""
         input_map: dict[VarName, set[VarName]] = defaultdict(set)
@@ -194,7 +193,7 @@ def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: st
         else:
             graph.node(var_name.replace(":", "&"), **kwargs)

-    def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> dict[str, set[VarName]]:
+    def get_plates(self, var_names: Iterable[VarName] | None = None) -> dict[str, set[VarName]]:
         """Rough but surprisingly accurate plate detection.

         Just groups by the shape of the underlying distribution.  Will be wrong
@@ -236,7 +235,7 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> dict[str,
     def make_graph(
         self,
-        var_names: Optional[Iterable[VarName]] = None,
+        var_names: Iterable[VarName] | None = None,
         formatting: str = "plain",
         save=None,
         figsize=None,
@@ -288,9 +287,7 @@ def make_graph(

         return graph

-    def make_networkx(
-        self, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"
-    ):
+    def make_networkx(self, var_names: Iterable[VarName] | None = None, formatting: str = "plain"):
         """Make networkx Digraph of PyMC model

         Returns
@@ -347,7 +344,7 @@ def make_networkx(

 def model_to_networkx(
     model=None,
     *,
-    var_names: Optional[Iterable[VarName]] = None,
+    var_names: Iterable[VarName] | None = None,
     formatting: str = "plain",
 ):
     """Produce a networkx Digraph from a PyMC model.
@@ -412,10 +409,10 @@ def model_to_networkx(

 def model_to_graphviz(
     model=None,
     *,
-    var_names: Optional[Iterable[VarName]] = None,
+    var_names: Iterable[VarName] | None = None,
     formatting: str = "plain",
-    save: Optional[str] = None,
-    figsize: Optional[tuple[int, int]] = None,
+    save: str | None = None,
+    figsize: tuple[int, int] | None = None,
     dpi: int = 300,
 ):
     """Produce a graphviz Digraph from a PyMC model.
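A second recurring pattern ends here: `isinstance(x, (A, B))` tuples become `isinstance(x, A | B)`, relying on Python 3.10+ accepting `types.UnionType` objects in `isinstance` checks. A small sketch, with a hypothetical `as_array` helper modeled on the `kmeans_inducing_points` hunk above:

import numpy as np

def as_array(X) -> np.ndarray:
    # Hypothetical coercion helper, not from the patch. On Python 3.10+,
    # np.ndarray | tuple | list is a types.UnionType and is accepted by
    # isinstance, replacing the (np.ndarray, tuple, list) tuple form.
    if isinstance(X, np.ndarray | tuple | list):
        return np.asarray(X)
    raise TypeError(f"Unsupported input type: {type(X)}")

assert as_array([1.0, 2.0]).shape == (2,)
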
diff --git a/pymc/ode/ode.py b/pymc/ode/ode.py
index 600f30632e..c38f6cf8bf 100644
--- a/pymc/ode/ode.py
+++ b/pymc/ode/ode.py
@@ -149,9 +149,9 @@ def make_node(self, y0, theta):
         return Apply(self, inputs, (states, sens))

     def __call__(self, y0, theta, return_sens=False, **kwargs):
-        if isinstance(y0, (list, tuple)) and not len(y0) == self.n_states:
+        if isinstance(y0, list | tuple) and not len(y0) == self.n_states:
             raise ShapeError("Length of y0 is wrong.", actual=(len(y0),), expected=(self.n_states,))
-        if isinstance(theta, (list, tuple)) and not len(theta) == self.n_theta:
+        if isinstance(theta, list | tuple) and not len(theta) == self.n_theta:
             raise ShapeError(
                 "Length of theta is wrong.", actual=(len(theta),), expected=(self.n_theta,)
             )
diff --git a/pymc/ode/utils.py b/pymc/ode/utils.py
index 8bf4f7deb3..1ccf7e5ba3 100644
--- a/pymc/ode/utils.py
+++ b/pymc/ode/utils.py
@@ -107,7 +107,7 @@ def augment_system(ode_func, n_states, n_theta):
         t_yhat = pt.atleast_1d(yhat)
     else:
         # Stack the results of the ode_func into a single tensor variable
-        if not isinstance(yhat, (list, tuple)):
+        if not isinstance(yhat, list | tuple):
             raise TypeError(
                 f"Unexpected type, {type(yhat)}, returned by ode_func. TensorVariable, list or tuple is expected."
             )
diff --git a/pymc/printing.py b/pymc/printing.py
index 9fe7d056cf..13361741e6 100644
--- a/pymc/printing.py
+++ b/pymc/printing.py
@@ -12,7 +12,6 @@
 #   See the License for the specific language governing permissions and
 #   limitations under the License.

-from typing import Union

 from pytensor.compile import SharedVariable
 from pytensor.graph.basic import Constant, walk
@@ -49,7 +48,7 @@ def str_for_dist(
         dist_args = [
             _str_for_input_var(x, formatting=formatting)
             for x in dist.owner.inputs
-            if not isinstance(x, (RandomStateSharedVariable, RandomGeneratorSharedVariable))
+            if not isinstance(x, RandomStateSharedVariable | RandomGeneratorSharedVariable)
         ]

         print_name = dist.name
@@ -169,10 +168,10 @@ def _is_potential_or_deterministic(var: Variable) -> bool:
         # in case other code overrides str_repr, fallback
         return False

-    if isinstance(var, (Constant, SharedVariable)):
+    if isinstance(var, Constant | SharedVariable):
         return _str_for_constant(var, formatting)
     elif isinstance(
-        var.owner.op, (RandomVariable, SymbolicRandomVariable)
+        var.owner.op, RandomVariable | SymbolicRandomVariable
     ) or _is_potential_or_deterministic(var):
         # show the names for RandomVariables, Deterministics, and Potentials, rather
         # than the full expression
@@ -195,7 +194,7 @@ def _str_for_input_rv(var: Variable, formatting: str) -> str:
     return _str


-def _str_for_constant(var: Union[Constant, SharedVariable], formatting: str) -> str:
+def _str_for_constant(var: Constant | SharedVariable, formatting: str) -> str:
     if isinstance(var, Constant):
         var_data = var.data
         var_type = "constant"
@@ -219,13 +218,13 @@ def _str_for_expression(var: Variable, formatting: str) -> str:

     # construct a string like f(a1, ..., aN) listing all random variables a as arguments
     def _expand(x):
-        if x.owner and (not isinstance(x.owner.op, (RandomVariable, SymbolicRandomVariable))):
+        if x.owner and (not isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable)):
             return reversed(x.owner.inputs)

     parents = [
         x
         for x in walk(nodes=var.owner.inputs, expand=_expand)
-        if x.owner and isinstance(x.owner.op, (RandomVariable, SymbolicRandomVariable))
+        if x.owner and isinstance(x.owner.op, RandomVariable | SymbolicRandomVariable)
     ]
     names = [x.name for x in parents]
@@ -254,7 +253,7 @@ def _latex_escape(text: str) -> str:
     return text.replace("$", r"\$")


-def _default_repr_pretty(obj: Union[TensorVariable, Model], p, cycle):
+def _default_repr_pretty(obj: TensorVariable | Model, p, cycle):
     """Handy plug-in method to instruct IPython-like REPLs to use our str_repr above."""
     # we know that our str_repr does not recurse, so we can ignore cycle
     try:
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index ef4aa2a157..a68cd64d8a 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -13,12 +13,7 @@
 # limitations under the License.
 import warnings

-from collections.abc import Generator, Iterable, Sequence
-from typing import (
-    Callable,
-    Optional,
-    Union,
-)
+from collections.abc import Callable, Generator, Iterable, Sequence

 import numpy as np
 import pandas as pd
@@ -53,13 +48,13 @@
 from pytensor.tensor.rewriting.shape import ShapeFeature
 from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
 from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
-from pytensor.tensor.variable import TensorConstant, TensorVariable
+from pytensor.tensor.variable import TensorVariable

 from pymc.exceptions import NotConstantValueError
 from pymc.util import makeiter
 from pymc.vartypes import continuous_types, isgenerator, typefilter

-PotentialShapeType = Union[int, np.ndarray, Sequence[Union[int, Variable]], TensorVariable]
+PotentialShapeType = int | np.ndarray | Sequence[int | Variable] | TensorVariable


 __all__ = [
@@ -157,7 +152,7 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
     if x.owner and isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op.scalar_op, Cast):
         array_data = extract_obs_data(x.owner.inputs[0])
         return array_data.astype(x.type.dtype)
-    if x.owner and isinstance(x.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1)):
+    if x.owner and isinstance(x.owner.op, AdvancedIncSubtensor | AdvancedIncSubtensor1):
         array_data = extract_obs_data(x.owner.inputs[0])
         mask_idx = tuple(extract_obs_data(i) for i in x.owner.inputs[2:])
         mask = np.zeros_like(array_data)
@@ -169,7 +164,7 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:

 def walk_model(
     graphs: Iterable[TensorVariable],
-    stop_at_vars: Optional[set[TensorVariable]] = None,
+    stop_at_vars: set[TensorVariable] | None = None,
     expand_fn: Callable[[TensorVariable], Iterable[TensorVariable]] = lambda var: [],
 ) -> Generator[TensorVariable, None, None]:
     """Walk model graphs and yield their nodes.
@@ -245,7 +240,7 @@ def inputvars(a):
     return [
         v
         for v in graph_inputs(makeiter(a))
-        if isinstance(v, TensorVariable) and not isinstance(v, TensorConstant)
+        if isinstance(v, Variable) and not isinstance(v, Constant | SharedVariable)
     ]
@@ -434,7 +429,7 @@ def join_nonshared_inputs(
     point: dict[str, np.ndarray],
     outputs: list[TensorVariable],
     inputs: list[TensorVariable],
-    shared_inputs: Optional[dict[TensorVariable, TensorSharedVariable]] = None,
+    shared_inputs: dict[TensorVariable, TensorSharedVariable] | None = None,
     make_inputs_shared: bool = False,
 ) -> tuple[list[TensorVariable], TensorVariable]:
     """
@@ -732,12 +727,12 @@ def largest_common_dtype(tensors):

 def find_rng_nodes(
     variables: Iterable[Variable],
-) -> list[Union[RandomStateSharedVariable, RandomGeneratorSharedVariable]]:
+) -> list[RandomStateSharedVariable | RandomGeneratorSharedVariable]:
     """Return RNG variables in a graph"""
     return [
         node
         for node in graph_inputs(variables)
-        if isinstance(node, (RandomStateSharedVariable, RandomGeneratorSharedVariable))
+        if isinstance(node, RandomStateSharedVariable | RandomGeneratorSharedVariable)
     ]
@@ -754,7 +749,7 @@ def replace_rng_nodes(outputs: Sequence[TensorVariable]) -> Sequence[TensorVaria
         return outputs

     graph = FunctionGraph(outputs=outputs, clone=False)
-    new_rng_nodes: list[Union[np.random.RandomState, np.random.Generator]] = []
+    new_rng_nodes: list[np.random.RandomState | np.random.Generator] = []
     for rng_node in rng_nodes:
         rng_cls: type
         if isinstance(rng_node, pt.random.var.RandomStateSharedVariable):
@@ -766,7 +761,7 @@ def replace_rng_nodes(outputs: Sequence[TensorVariable]) -> Sequence[TensorVaria
     return graph.outputs


-SeedSequenceSeed = Optional[Union[int, Sequence[int], np.ndarray, np.random.SeedSequence]]
+SeedSequenceSeed = None | int | Sequence[int] | np.ndarray | np.random.SeedSequence


 def reseed_rngs(
@@ -778,7 +773,7 @@ def reseed_rngs(
         np.random.PCG64(sub_seed) for sub_seed in np.random.SeedSequence(seed).spawn(len(rngs))
     ]
     for rng, bit_generator in zip(rngs, bit_generators):
-        new_rng: Union[np.random.RandomState, np.random.Generator]
+        new_rng: np.random.RandomState | np.random.Generator
         if isinstance(rng, pt.random.var.RandomStateSharedVariable):
             new_rng = np.random.RandomState(bit_generator)
         else:
@@ -789,7 +784,7 @@ def reseed_rngs(
 def collect_default_updates(
     outputs: Sequence[Variable],
     *,
-    inputs: Optional[Sequence[Variable]] = None,
+    inputs: Sequence[Variable] | None = None,
     must_be_shared: bool = True,
 ) -> dict[Variable, Variable]:
     """Collect default update expression for shared-variable RNGs used by RVs between inputs and outputs.
@@ -834,7 +829,7 @@ def scan_step(xtm1):
     # Avoid circular import
     from pymc.distributions.distribution import SymbolicRandomVariable

-    def find_default_update(clients, rng: Variable) -> Union[None, Variable]:
+    def find_default_update(clients, rng: Variable) -> None | Variable:
         rng_clients = clients.get(rng, None)

         # Root case, RNG is not used elsewhere
diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py
index 1a8116ecd2..814ed8de92 100644
--- a/pymc/sampling/forward.py
+++ b/pymc/sampling/forward.py
@@ -17,12 +17,10 @@
 import logging
 import warnings

-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
 from typing import (
     Any,
-    Callable,
-    Optional,
-    Union,
+    TypeAlias,
     cast,
 )
@@ -48,7 +46,6 @@
 from rich.console import Console
 from rich.progress import Progress
 from rich.theme import Theme
-from typing_extensions import TypeAlias

 import pymc as pm
@@ -73,7 +70,7 @@
     "sample_posterior_predictive",
 )

-ArrayLike: TypeAlias = Union[np.ndarray, list[float]]
+ArrayLike: TypeAlias = np.ndarray | list[float]
 PointList: TypeAlias = list[PointType]

 _log = logging.getLogger(__name__)
@@ -93,12 +90,12 @@ def get_vars_in_point_list(trace, model):
 def compile_forward_sampling_function(
     outputs: list[Variable],
     vars_in_trace: list[Variable],
-    basic_rvs: Optional[list[Variable]] = None,
-    givens_dict: Optional[dict[Variable, Any]] = None,
-    constant_data: Optional[dict[str, np.ndarray]] = None,
-    constant_coords: Optional[set[str]] = None,
+    basic_rvs: list[Variable] | None = None,
+    givens_dict: dict[Variable, Any] | None = None,
+    constant_data: dict[str, np.ndarray] | None = None,
+    constant_coords: set[str] | None = None,
     **kwargs,
-) -> tuple[Callable[..., Union[np.ndarray, list[np.ndarray]]], set[Variable]]:
+) -> tuple[Callable[..., np.ndarray | list[np.ndarray]], set[Variable]]:
     """Compile a function to draw samples, conditioned on the values of some variables.
     The goal of this function is to walk the pytensor computational graph from the list
@@ -210,7 +207,7 @@ def shared_value_matches(var):
             or node in givens_dict
             or (  # SharedVariables, except RandomState/Generators
                 isinstance(node, SharedVariable)
-                and not isinstance(node, (RandomStateSharedVariable, RandomGeneratorSharedVariable))
+                and not isinstance(node, RandomStateSharedVariable | RandomGeneratorSharedVariable)
                 and not shared_value_matches(node)
             )
             or (  # Basic RVs that are not in the trace
@@ -230,7 +227,7 @@ def shared_value_matches(var):
     def expand(node):
         if (
             (
-                node.owner is None and not isinstance(node, (Constant, SharedVariable))
+                node.owner is None and not isinstance(node, Constant | SharedVariable)
             )  # Variables without owners that are not constant or shared
             or node in vars_in_trace  # Variables in the trace
         ) and node not in volatile_nodes:
@@ -249,7 +246,7 @@ def expand(node):
         (
             node,
             value
-            if isinstance(value, (Variable, Apply))
+            if isinstance(value, Variable | Apply)
            else pt.constant(value, dtype=getattr(node, "dtype", None), name=node.name),
         )
         for node, value in givens_dict.items()
@@ -262,11 +259,11 @@

 def draw(
-    vars: Union[Variable, Sequence[Variable]],
+    vars: Variable | Sequence[Variable],
     draws: int = 1,
     random_seed: RandomState = None,
     **kwargs,
-) -> Union[np.ndarray, list[np.ndarray]]:
+) -> np.ndarray | list[np.ndarray]:
     """Draw samples for one variable or a list of variables

     Parameters
@@ -318,7 +315,7 @@ def draw(
         return draw_fn()

     # Single variable output
-    if not isinstance(vars, (list, tuple)):
+    if not isinstance(vars, list | tuple):
         cast(Callable[[], np.ndarray], draw_fn)
         return np.stack([draw_fn() for _ in range(draws)])
@@ -342,13 +339,13 @@ def observed_dependent_deterministics(model: Model):

 def sample_prior_predictive(
     samples: int = 500,
-    model: Optional[Model] = None,
-    var_names: Optional[Iterable[str]] = None,
+    model: Model | None = None,
+    var_names: Iterable[str] | None = None,
     random_seed: RandomState = None,
     return_inferencedata: bool = True,
-    idata_kwargs: Optional[dict] = None,
-    compile_kwargs: Optional[dict] = None,
-) -> Union[InferenceData, dict[str, np.ndarray]]:
+    idata_kwargs: dict | None = None,
+    compile_kwargs: dict | None = None,
+) -> InferenceData | dict[str, np.ndarray]:
     """Generate samples from the prior predictive distribution.

     Parameters
@@ -439,18 +436,18 @@ def sample_prior_predictive(

 def sample_posterior_predictive(
     trace,
-    model: Optional[Model] = None,
-    var_names: Optional[list[str]] = None,
-    sample_dims: Optional[list[str]] = None,
+    model: Model | None = None,
+    var_names: list[str] | None = None,
+    sample_dims: list[str] | None = None,
     random_seed: RandomState = None,
     progressbar: bool = True,
-    progressbar_theme: Optional[Theme] = default_progress_theme,
+    progressbar_theme: Theme | None = default_progress_theme,
     return_inferencedata: bool = True,
     extend_inferencedata: bool = False,
     predictions: bool = False,
-    idata_kwargs: Optional[dict] = None,
-    compile_kwargs: Optional[dict] = None,
-) -> Union[InferenceData, dict[str, np.ndarray]]:
+    idata_kwargs: dict | None = None,
+    compile_kwargs: dict | None = None,
+) -> InferenceData | dict[str, np.ndarray]:
     """Generate forward samples for `var_names`, conditioned on the posterior samples of variables found in the `trace`.

     This method can be used to perform different kinds of model predictions, including posterior predictive checks.
@@ -726,7 +723,7 @@ def sample_posterior_predictive( """ - _trace: Union[MultiTrace, PointList] + _trace: MultiTrace | PointList nchain: int if idata_kwargs is None: idata_kwargs = {} @@ -738,7 +735,7 @@ def sample_posterior_predictive( trace_coords: dict[str, np.ndarray] = {} if "coords" not in idata_kwargs: idata_kwargs["coords"] = {} - idata: Optional[InferenceData] = None + idata: InferenceData | None = None stacked_dims = None if isinstance(trace, InferenceData): _constant_data = getattr(trace, "constant_data", None) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index f048cc2938..4f305139ea 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -15,10 +15,10 @@ import os import re -from collections.abc import Sequence +from collections.abc import Callable, Sequence from datetime import datetime from functools import partial -from typing import Any, Callable, Literal, Optional, Union +from typing import Any, Literal import arviz as az import jax @@ -120,8 +120,8 @@ def _replace_shared_variables(graph: list[TensorVariable]) -> list[TensorVariabl def get_jaxified_graph( - inputs: Optional[list[TensorVariable]] = None, - outputs: Optional[list[TensorVariable]] = None, + inputs: list[TensorVariable] | None = None, + outputs: list[TensorVariable] | None = None, ) -> list[TensorVariable]: """Compile an PyTensor graph into an optimized JAX function""" @@ -161,7 +161,7 @@ def logp_fn_wrap(x): def _get_log_likelihood( model: Model, samples, - backend: Optional[Literal["cpu", "gpu"]] = None, + backend: Literal["cpu", "gpu"] | None = None, postprocessing_vectorize: Literal["vmap", "scan"] = "scan", ) -> dict: """Compute log-likelihood for all observations""" @@ -180,7 +180,7 @@ def _device_put(input, device: str): def _postprocess_samples( jax_fn: Callable, raw_mcmc_samples: list[TensorVariable], - postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None, + postprocessing_backend: Literal["cpu", "gpu"] | None = None, postprocessing_vectorize: Literal["vmap", "scan"] = "scan", ) -> list[TensorVariable]: if postprocessing_vectorize == "scan": @@ -201,11 +201,11 @@ def _postprocess_samples( def _get_batched_jittered_initial_points( model: Model, chains: int, - initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]], + initvals: StartDict | Sequence[StartDict | None] | None, random_seed: RandomSeed, jitter: bool = True, jitter_max_retries: int = 10, -) -> Union[np.ndarray, list[np.ndarray]]: +) -> np.ndarray | list[np.ndarray]: """Get jittered initial point in format expected by NumPyro MCMC kernel Returns @@ -309,7 +309,7 @@ def _sample_blackjax_nuts( tune: int, draws: int, chains: int, - chain_method: Optional[str], + chain_method: str | None, progressbar: bool, random_seed: int, initial_points, @@ -429,7 +429,7 @@ def _numpyro_stats_to_dict(posterior): } data = {} for stat, value in posterior.get_extra_fields(group_by_chain=True).items(): - if isinstance(value, (dict, tuple)): + if isinstance(value, dict | tuple): continue name = rename_key.get(stat, stat) value = value.copy() @@ -445,7 +445,7 @@ def _sample_numpyro_nuts( tune: int, draws: int, chains: int, - chain_method: Optional[str], + chain_method: str | None, progressbar: bool, random_seed: int, initial_points, @@ -505,19 +505,19 @@ def sample_jax_nuts( tune: int = 1000, chains: int = 4, target_accept: float = 0.8, - random_seed: Optional[RandomState] = None, - initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, + random_seed: RandomState | None = None, + initvals: StartDict | 
     jitter: bool = True,
-    model: Optional[Model] = None,
-    var_names: Optional[Sequence[str]] = None,
-    nuts_kwargs: Optional[dict] = None,
+    model: Model | None = None,
+    var_names: Sequence[str] | None = None,
+    nuts_kwargs: dict | None = None,
     progressbar: bool = True,
     keep_untransformed: bool = False,
     chain_method: str = "parallel",
-    postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None,
+    postprocessing_backend: Literal["cpu", "gpu"] | None = None,
     postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
     postprocessing_chunks=None,
-    idata_kwargs: Optional[dict] = None,
+    idata_kwargs: dict | None = None,
     compute_convergence_checks: bool = True,
     nuts_sampler: Literal["numpyro", "blackjax"],
 ) -> az.InferenceData:
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index 1241b10b86..7f750090f8 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -24,8 +24,7 @@
 from typing import (
     Any,
     Literal,
-    Optional,
-    Union,
+    TypeAlias,
     overload,
 )

@@ -38,7 +37,7 @@
 from rich.console import Console
 from rich.progress import Progress
 from rich.theme import Theme
-from typing_extensions import Protocol, TypeAlias
+from typing_extensions import Protocol

 import pymc as pm

@@ -81,7 +80,7 @@
     "init_nuts",
 ]

-Step: TypeAlias = Union[BlockedStep, CompoundStep]
+Step: TypeAlias = BlockedStep | CompoundStep


 class SamplingIteratorCallback(Protocol):
@@ -98,8 +97,8 @@ def instantiate_steppers(
     model: Model,
     steps: list[Step],
     selected_steps: Mapping[type[BlockedStep], list[Any]],
-    step_kwargs: Optional[dict[str, dict]] = None,
-) -> Union[Step, list[Step]]:
+    step_kwargs: dict[str, dict] | None = None,
+) -> Step | list[Step]:
     """Instantiate steppers assigned to the model variables.

     This function is intended to be called automatically from ``sample()``, but
@@ -156,10 +155,10 @@ def instantiate_steppers(

 def assign_step_methods(
     model: Model,
-    step: Optional[Union[Step, Sequence[Step]]] = None,
-    methods: Optional[Sequence[type[BlockedStep]]] = None,
-    step_kwargs: Optional[dict[str, Any]] = None,
-) -> Union[Step, list[Step]]:
+    step: Step | Sequence[Step] | None = None,
+    methods: Sequence[type[BlockedStep]] | None = None,
+    step_kwargs: dict[str, Any] | None = None,
+) -> Step | list[Step]:
     """Assign model variables to appropriate step methods.
     Passing a specified model will auto-assign its constituent stochastic
@@ -193,7 +192,7 @@ def assign_step_methods(
     assigned_vars: set[Variable] = set()

     if step is not None:
-        if isinstance(step, (BlockedStep, CompoundStep)):
+        if isinstance(step, BlockedStep | CompoundStep):
             steps.append(step)
         else:
             steps.extend(step)
@@ -264,13 +263,13 @@ def _sample_external_nuts(
     tune: int,
     chains: int,
     target_accept: float,
-    random_seed: Union[RandomState, None],
-    initvals: Union[StartDict, Sequence[Optional[StartDict]], None],
+    random_seed: RandomState | None,
+    initvals: StartDict | Sequence[StartDict | None] | None,
     model: Model,
-    var_names: Optional[Sequence[str]],
+    var_names: Sequence[str] | None,
     progressbar: bool,
-    idata_kwargs: Optional[dict],
-    nuts_sampler_kwargs: Optional[dict],
+    idata_kwargs: dict | None,
+    nuts_sampler_kwargs: dict | None,
     **kwargs,
 ):
     if nuts_sampler_kwargs is None:
@@ -376,25 +375,25 @@ def sample(
     draws: int = 1000,
     *,
     tune: int = 1000,
-    chains: Optional[int] = None,
-    cores: Optional[int] = None,
+    chains: int | None = None,
+    cores: int | None = None,
     random_seed: RandomState = None,
     progressbar: bool = True,
-    progressbar_theme: Optional[Theme] = default_progress_theme,
+    progressbar_theme: Theme | None = default_progress_theme,
     step=None,
-    var_names: Optional[Sequence[str]] = None,
+    var_names: Sequence[str] | None = None,
     nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
-    initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
+    initvals: StartDict | Sequence[StartDict | None] | None = None,
     init: str = "auto",
     jitter_max_retries: int = 10,
     n_init: int = 200_000,
-    trace: Optional[TraceOrBackend] = None,
+    trace: TraceOrBackend | None = None,
     discard_tuned_samples: bool = True,
     compute_convergence_checks: bool = True,
     keep_warning_stat: bool = False,
     return_inferencedata: Literal[True] = True,
-    idata_kwargs: Optional[dict[str, Any]] = None,
-    nuts_sampler_kwargs: Optional[dict[str, Any]] = None,
+    idata_kwargs: dict[str, Any] | None = None,
+    nuts_sampler_kwargs: dict[str, Any] | None = None,
     callback=None,
     mp_ctx=None,
     **kwargs,
@@ -406,28 +405,28 @@ def sample(
     draws: int = 1000,
     *,
     tune: int = 1000,
-    chains: Optional[int] = None,
-    cores: Optional[int] = None,
+    chains: int | None = None,
+    cores: int | None = None,
     random_seed: RandomState = None,
     progressbar: bool = True,
-    progressbar_theme: Optional[Theme] = default_progress_theme,
+    progressbar_theme: Theme | None = default_progress_theme,
     step=None,
-    var_names: Optional[Sequence[str]] = None,
+    var_names: Sequence[str] | None = None,
     nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
-    initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
+    initvals: StartDict | Sequence[StartDict | None] | None = None,
     init: str = "auto",
     jitter_max_retries: int = 10,
     n_init: int = 200_000,
-    trace: Optional[TraceOrBackend] = None,
+    trace: TraceOrBackend | None = None,
     discard_tuned_samples: bool = True,
     compute_convergence_checks: bool = True,
     keep_warning_stat: bool = False,
     return_inferencedata: Literal[False],
-    idata_kwargs: Optional[dict[str, Any]] = None,
-    nuts_sampler_kwargs: Optional[dict[str, Any]] = None,
+    idata_kwargs: dict[str, Any] | None = None,
+    nuts_sampler_kwargs: dict[str, Any] | None = None,
     callback=None,
     mp_ctx=None,
-    model: Optional[Model] = None,
+    model: Model | None = None,
     **kwargs,
 ) -> MultiTrace: ...
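Reviewer note: the `Optional[X]` to `X | None` rewrites throughout these overloads are purely a spelling change; at runtime `typing.Optional[X]` and the PEP 604 union compare equal, and defaults behave identically. A small illustration with hypothetical names:

    from typing import Optional

    def scale(x: float | None = None) -> float:
        # `float | None` is the PEP 604 spelling of Optional[float]
        return 1.0 if x is None else x

    assert (float | None) == Optional[float]
    assert scale() == 1.0 and scale(2.5) == 2.5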
@@ -436,30 +435,30 @@ def sample(
     draws: int = 1000,
     *,
     tune: int = 1000,
-    chains: Optional[int] = None,
-    cores: Optional[int] = None,
+    chains: int | None = None,
+    cores: int | None = None,
     random_seed: RandomState = None,
     progressbar: bool = True,
-    progressbar_theme: Optional[Theme] = default_progress_theme,
+    progressbar_theme: Theme | None = default_progress_theme,
     step=None,
-    var_names: Optional[Sequence[str]] = None,
+    var_names: Sequence[str] | None = None,
     nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
-    initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
+    initvals: StartDict | Sequence[StartDict | None] | None = None,
     init: str = "auto",
     jitter_max_retries: int = 10,
     n_init: int = 200_000,
-    trace: Optional[TraceOrBackend] = None,
+    trace: TraceOrBackend | None = None,
     discard_tuned_samples: bool = True,
     compute_convergence_checks: bool = True,
     keep_warning_stat: bool = False,
     return_inferencedata: bool = True,
-    idata_kwargs: Optional[dict[str, Any]] = None,
-    nuts_sampler_kwargs: Optional[dict[str, Any]] = None,
+    idata_kwargs: dict[str, Any] | None = None,
+    nuts_sampler_kwargs: dict[str, Any] | None = None,
     callback=None,
     mp_ctx=None,
-    model: Optional[Model] = None,
+    model: Model | None = None,
     **kwargs,
-) -> Union[InferenceData, MultiTrace]:
+) -> InferenceData | MultiTrace:
     r"""Draw samples from the posterior using the given step methods.

     Multiple step methods are supported via compound step methods.
@@ -846,7 +845,7 @@ def sample(

 def _sample_return(
     *,
-    run: Optional[RunType],
+    run: RunType | None,
     traces: Sequence[IBaseTrace],
     tune: int,
     t_sampling: float,
@@ -856,7 +855,7 @@ def _sample_return(
     keep_warning_stat: bool,
     idata_kwargs: dict[str, Any],
     model: Model,
-) -> Union[InferenceData, MultiTrace]:
+) -> InferenceData | MultiTrace:
     """Final step of `pm.sampler` that picks/slices chains,
     runs diagnostics and converts to the desired return type."""
     # Pick and slice chains to keep the maximum number of samples
@@ -945,9 +944,9 @@ def _sample_many(
     chains: int,
     traces: Sequence[IBaseTrace],
     start: Sequence[PointType],
-    random_seed: Optional[Sequence[RandomSeed]],
+    random_seed: Sequence[RandomSeed] | None,
     step: Step,
-    callback: Optional[SamplingIteratorCallback] = None,
+    callback: SamplingIteratorCallback | None = None,
     **kwargs,
 ):
     """Samples all chains sequentially.
@@ -989,8 +988,8 @@ def _sample(
     step: Step,
     trace: IBaseTrace,
     tune: int,
-    model: Optional[Model] = None,
-    progressbar_theme: Optional[Theme] = default_progress_theme,
+    model: Model | None = None,
+    progressbar_theme: Theme | None = default_progress_theme,
     callback=None,
     **kwargs,
 ) -> None:
@@ -1056,9 +1055,9 @@ def _iter_sample(
     trace: IBaseTrace,
     chain: int = 0,
     tune: int = 0,
-    model: Optional[Model] = None,
+    model: Model | None = None,
     random_seed: RandomSeed = None,
-    callback: Optional[SamplingIteratorCallback] = None,
+    callback: SamplingIteratorCallback | None = None,
 ) -> Iterator[bool]:
     """Generator for sampling one chain. (Used in singleprocess sampling.)
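Reviewer note: unlike quoted string annotations, these bare `X | None` annotations are evaluated when the module is imported, which is why they hard-require 3.10; on 3.9 the same source would only parse with deferred annotation evaluation. Illustrative sketch (hypothetical names):

    # Without this future import, `int | None` in an annotation raises
    # TypeError at import time on Python 3.9; on 3.10+ both variants work.
    from __future__ import annotations

    def describe(n: int | None = None) -> str | None:
        return None if n is None else str(n)

    assert describe(3) == "3" and describe() is None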
@@ -1138,10 +1137,10 @@ def _mp_sample(
     random_seed: Sequence[RandomSeed],
     start: Sequence[PointType],
     progressbar: bool = True,
-    progressbar_theme: Optional[Theme] = default_progress_theme,
+    progressbar_theme: Theme | None = default_progress_theme,
     traces: Sequence[IBaseTrace],
-    model: Optional[Model] = None,
-    callback: Optional[SamplingIteratorCallback] = None,
+    model: Model | None = None,
+    callback: SamplingIteratorCallback | None = None,
     mp_ctx=None,
     **kwargs,
 ) -> None:
@@ -1222,8 +1221,8 @@ def _mp_sample(

 def _init_jitter(
     model: Model,
-    initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]],
-    seeds: Union[Sequence[int], np.ndarray],
+    initvals: StartDict | Sequence[StartDict | None] | None,
+    seeds: Sequence[int] | np.ndarray,
     jitter: bool,
     jitter_max_retries: int,
 ) -> list[PointType]:
@@ -1279,12 +1278,12 @@ def init_nuts(
     init: str = "auto",
     chains: int = 1,
     n_init: int = 500_000,
-    model: Optional[Model] = None,
+    model: Model | None = None,
     random_seed: RandomSeed = None,
     progressbar=True,
     jitter_max_retries: int = 10,
-    tune: Optional[int] = None,
-    initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
+    tune: int | None = None,
+    initvals: StartDict | Sequence[StartDict | None] | None = None,
     **kwargs,
 ) -> tuple[Sequence[PointType], NUTS]:
     """Set up the mass matrix initialization for NUTS.
diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py
index 29ccc1a0d0..cc6908647e 100644
--- a/pymc/sampling/parallel.py
+++ b/pymc/sampling/parallel.py
@@ -22,7 +22,6 @@

 from collections import namedtuple
 from collections.abc import Sequence
-from typing import Optional

 import cloudpickle
 import numpy as np
@@ -378,7 +377,7 @@ def __init__(
         start_points: Sequence[dict[str, np.ndarray]],
         step_method,
         progressbar: bool = True,
-        progressbar_theme: Optional[Theme] = default_progress_theme,
+        progressbar_theme: Theme | None = default_progress_theme,
         mp_ctx=None,
     ):
         if any(len(arg) != chains for arg in [seeds, start_points]):
diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py
index 2a0db2ecfa..1627bb8de7 100644
--- a/pymc/sampling/population.py
+++ b/pymc/sampling/population.py
@@ -19,13 +19,12 @@

 from collections.abc import Iterator, Sequence
 from copy import copy
-from typing import Union
+from typing import TypeAlias

 import cloudpickle
 import numpy as np

 from rich.progress import BarColumn, Progress, TimeRemainingColumn
-from typing_extensions import TypeAlias

 from pymc.backends.base import BaseTrace
 from pymc.initial_point import PointType
@@ -43,7 +42,7 @@

 __all__ = ()

-Step: TypeAlias = Union[BlockedStep, CompoundStep]
+Step: TypeAlias = BlockedStep | CompoundStep


 _log = logging.getLogger(__name__)
@@ -55,7 +54,7 @@ def _sample_population(
     draws: int,
     start: Sequence[PointType],
     random_seed: RandomSeed,
-    step: Union[BlockedStep, CompoundStep],
+    step: BlockedStep | CompoundStep,
     tune: int,
     model: Model,
     progressbar: bool = True,
@@ -112,7 +111,7 @@ def _sample_population(

 def warn_population_size(
     *,
-    step: Union[BlockedStep, CompoundStep],
+    step: BlockedStep | CompoundStep,
     initial_points: Sequence[PointType],
     model: Model,
     chains: int,
diff --git a/pymc/smc/kernels.py b/pymc/smc/kernels.py
index 3e0a3f3e47..61aa403c89 100644
--- a/pymc/smc/kernels.py
+++ b/pymc/smc/kernels.py
@@ -16,7 +16,7 @@
 import warnings

 from abc import ABC
-from typing import Union, cast
+from typing import TypeAlias, cast

 import numpy as np
 import pytensor.tensor as pt
@@ -24,7 +24,6 @@
 from pytensor.graph.replace import clone_replace
 from scipy.special import logsumexp
 from scipy.stats import multivariate_normal
-from typing_extensions import TypeAlias

 from pymc.backends.ndarray import NDArray
 from pymc.blocking import DictToArrayBijection
@@ -40,8 +39,8 @@
 from pymc.step_methods.metropolis import MultivariateNormalProposal
 from pymc.vartypes import discrete_types

-SMCStats: TypeAlias = dict[str, Union[int, float]]
-SMCSettings: TypeAlias = dict[str, Union[int, float]]
+SMCStats: TypeAlias = dict[str, int | float]
+SMCSettings: TypeAlias = dict[str, int | float]


 class SMC_KERNEL(ABC):
diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py
index e5129e8fce..d9b76f211c 100644
--- a/pymc/smc/sampling.py
+++ b/pymc/smc/sampling.py
@@ -19,7 +19,7 @@

 from collections import defaultdict
 from concurrent.futures import ProcessPoolExecutor
-from typing import Any, Optional, Union
+from typing import Any

 import cloudpickle
 import numpy as np
@@ -52,7 +52,7 @@ def sample_smc(
     idata_kwargs=None,
     progressbar=True,
     **kernel_kwargs,
-) -> Union[InferenceData, MultiTrace]:
+) -> InferenceData | MultiTrace:
     r"""
     Sequential Monte Carlo based sampling.

@@ -253,7 +253,7 @@ def _save_sample_stats(
     _t_sampling,
     idata_kwargs,
     model: Model,
-) -> tuple[Optional[Any], Optional[InferenceData]]:
+) -> tuple[Any | None, InferenceData | None]:
     sample_settings_dict = sample_settings[0]
     sample_settings_dict["_t_sampling"] = _t_sampling
     sample_stats_dict = sample_stats[0]
@@ -266,7 +266,7 @@ def _save_sample_stats(
             value_list.append(chain_sample_stats[stat])
         sample_stats_dict[stat] = value_list

-    idata: Optional[InferenceData] = None
+    idata: InferenceData | None = None
     if not return_inferencedata:
         for stat, value in sample_stats_dict.items():
             setattr(trace.report, stat, value)
diff --git a/pymc/stats/convergence.py b/pymc/stats/convergence.py
index 6c7c3c8f61..47359365eb 100644
--- a/pymc/stats/convergence.py
+++ b/pymc/stats/convergence.py
@@ -16,7 +16,7 @@
 import logging

 from collections.abc import Sequence
-from typing import Any, Optional
+from typing import Any

 import arviz

@@ -53,12 +53,12 @@ class SamplerWarning:
     kind: WarningType
     message: str
     level: str
-    step: Optional[int] = None
-    exec_info: Optional[Any] = None
-    extra: Optional[Any] = None
-    divergence_point_source: Optional[dict] = None
-    divergence_point_dest: Optional[dict] = None
-    divergence_info: Optional[Any] = None
+    step: int | None = None
+    exec_info: Any | None = None
+    extra: Any | None = None
+    divergence_point_source: dict | None = None
+    divergence_point_dest: dict | None = None
+    divergence_info: Any | None = None


 def run_convergence_checks(idata: arviz.InferenceData, model) -> list[SamplerWarning]:
diff --git a/pymc/stats/log_density.py b/pymc/stats/log_density.py
index daf172342f..b4a6ca742b 100644
--- a/pymc/stats/log_density.py
+++ b/pymc/stats/log_density.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections.abc import Sequence
-from typing import Optional, cast
+from typing import cast

 from arviz import InferenceData, dict_to_dataset
 from rich.console import Console
@@ -31,9 +31,9 @@ def compute_log_likelihood(
     idata: InferenceData,
     *,
-    var_names: Optional[Sequence[str]] = None,
+    var_names: Sequence[str] | None = None,
     extend_inferencedata: bool = True,
-    model: Optional[Model] = None,
+    model: Model | None = None,
     sample_dims: Sequence[str] = ("chain", "draw"),
     progressbar=True,
 ):
@@ -70,9 +70,9 @@ def compute_log_likelihood(

 def compute_log_prior(
     idata: InferenceData,
-    var_names: Optional[Sequence[str]] = None,
+    var_names: Sequence[str] | None = None,
     extend_inferencedata: bool = True,
-    model: Optional[Model] = None,
+    model: Model | None = None,
     sample_dims: Sequence[str] = ("chain", "draw"),
     progressbar=True,
 ):
@@ -110,9 +110,9 @@ def compute_log_prior(
 def compute_log_density(
     idata: InferenceData,
     *,
-    var_names: Optional[Sequence[str]] = None,
+    var_names: Sequence[str] | None = None,
     extend_inferencedata: bool = True,
-    model: Optional[Model] = None,
+    model: Model | None = None,
     kind="likelihood",
     sample_dims: Sequence[str] = ("chain", "draw"),
     progressbar=True,
diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py
index b2b73bbea4..ca6036ecc6 100644
--- a/pymc/step_methods/arraystep.py
+++ b/pymc/step_methods/arraystep.py
@@ -13,7 +13,8 @@
 # limitations under the License.

 from abc import abstractmethod
-from typing import Callable, Union, cast
+from collections.abc import Callable
+from typing import cast

 import numpy as np

@@ -47,7 +48,7 @@ def __init__(self, vars, fs, allvars=False, blocked=True):
         self.blocked = blocked

     def step(self, point: PointType) -> tuple[PointType, StatsType]:
-        partial_funcs_and_point: list[Union[Callable, PointType]] = [
+        partial_funcs_and_point: list[Callable | PointType] = [
             DictToArrayBijection.mapf(x, start_point=point) for x in self.fs
         ]
         if self.allvars:
diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py
index 403b14e8dd..7c0d8563ca 100644
--- a/pymc/step_methods/compound.py
+++ b/pymc/step_methods/compound.py
@@ -23,7 +23,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from enum import IntEnum, unique
-from typing import Any, Union
+from typing import Any

 import numpy as np

@@ -126,7 +126,7 @@ def __new__(cls, *args, **kwargs):
         else:
             # Assume all model variables
             vars = model.value_vars
-        if not isinstance(vars, (tuple, list)):
+        if not isinstance(vars, tuple | list):
             vars = [vars]

         if len(vars) == 0:
@@ -251,7 +251,7 @@ def vars(self) -> list[Variable]:
         return [var for method in self.methods for var in method.vars]


-def flatten_steps(step: Union[BlockedStep, CompoundStep]) -> list[BlockedStep]:
+def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]:
     """Flatten a hierarchy of step methods to a list."""
     if isinstance(step, BlockedStep):
         return [step]
@@ -263,7 +263,7 @@ def flatten_steps(step: Union[BlockedStep, CompoundStep]) -> list[BlockedStep]:
     return steps


-def check_step_emits_tune(step: Union[CompoundStep, BlockedStep]):
+def check_step_emits_tune(step: CompoundStep | BlockedStep):
     if isinstance(step, BlockedStep) and "tune" not in step.stats_dtypes_shapes:
         raise TypeError(f"{type(step)} does not emit the required 'tune' stat.")
     elif isinstance(step, CompoundStep):
diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py
index 4a595d3700..6c3f2b8a09 100644
--- a/pymc/step_methods/metropolis.py
+++ b/pymc/step_methods/metropolis.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, Optional
+from collections.abc import Callable

 import numpy as np
 import numpy.random as nr
@@ -67,29 +67,29 @@ def __init__(self, s):


 class NormalProposal(Proposal):
-    def __call__(self, rng: Optional[np.random.Generator] = None):
+    def __call__(self, rng: np.random.Generator | None = None):
         return (rng or nr).normal(scale=self.s)


 class UniformProposal(Proposal):
-    def __call__(self, rng: Optional[np.random.Generator] = None):
+    def __call__(self, rng: np.random.Generator | None = None):
         return (rng or nr).uniform(low=-self.s, high=self.s, size=len(self.s))


 class CauchyProposal(Proposal):
-    def __call__(self, rng: Optional[np.random.Generator] = None):
+    def __call__(self, rng: np.random.Generator | None = None):
         return (rng or nr).standard_cauchy(size=np.size(self.s)) * self.s


 class LaplaceProposal(Proposal):
-    def __call__(self, rng: Optional[np.random.Generator] = None):
+    def __call__(self, rng: np.random.Generator | None = None):
         size = np.size(self.s)
         r = rng or nr
         return (r.standard_exponential(size=size) - r.standard_exponential(size=size)) * self.s


 class PoissonProposal(Proposal):
-    def __call__(self, rng: Optional[np.random.Generator] = None):
+    def __call__(self, rng: np.random.Generator | None = None):
         return (rng or nr).poisson(lam=self.s, size=np.size(self.s)) - self.s


@@ -101,7 +101,7 @@ def __init__(self, s):
         self.n = n
         self.chol = scipy.linalg.cholesky(s, lower=True)

-    def __call__(self, num_draws=None, rng: Optional[np.random.Generator] = None):
+    def __call__(self, num_draws=None, rng: np.random.Generator | None = None):
         rng_ = rng or nr
         if num_draws is not None:
             b = rng_.normal(size=(self.n, num_draws))
@@ -767,7 +767,7 @@ def __init__(
         proposal_dist=None,
         lamb=None,
         scaling=0.001,
-        tune: Optional[str] = "scaling",
+        tune: str | None = "scaling",
         tune_interval=100,
         model=None,
         mode=None,
@@ -910,7 +910,7 @@ def __init__(
         proposal_dist=None,
         lamb=None,
         scaling=0.001,
-        tune: Optional[str] = "scaling",
+        tune: str | None = "scaling",
         tune_interval=100,
         tune_drop_fraction: float = 0.9,
         model=None,
diff --git a/pymc/testing.py b/pymc/testing.py
index f7704c9540..74b581196e 100644
--- a/pymc/testing.py
+++ b/pymc/testing.py
@@ -15,8 +15,8 @@
 import itertools as it
 import warnings

-from collections.abc import Sequence
-from typing import Any, Callable, Optional, Union
+from collections.abc import Callable, Sequence
+from typing import Any

 import numpy as np
 import pytensor
@@ -250,7 +250,7 @@ def build_model(distfam, valuedomain, vardomains, extra_args=None):
 def create_dist_from_paramdomains(
     pymc_dist: Distribution,
     paramdomains: dict[str, Domain],
-    extra_args: Optional[dict[str, Any]] = None,
+    extra_args: dict[str, Any] | None = None,
 ) -> TensorVariable:
     """Create a PyMC distribution from a dictionary of parameter domains.

@@ -273,7 +273,7 @@ def create_dist_from_paramdomains(

 def find_invalid_scalar_params(
     paramdomains: dict["str", Domain],
-) -> dict["str", tuple[Union[None, float], Union[None, float]]]:
+) -> dict["str", tuple[None | float, None | float]]:
     """Find invalid parameter values from bounded scalar parameter domains.

     For use in `check_logp`-like testing helpers.
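Reviewer note: several files in this part of the diff (metropolis.py, arraystep.py, testing.py) move `Callable` from `typing` to `collections.abc`, which has supported subscripting in annotations since Python 3.9; the typing documentation lists `typing.Callable` among the deprecated aliases (PEP 585). A minimal sketch with hypothetical names:

    from collections.abc import Callable

    def apply_twice(fn: Callable[[int], int], x: int) -> int:
        # Subscription syntax is identical to typing.Callable
        return fn(fn(x))

    assert apply_twice(lambda v: v + 1, 0) == 2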
@@ -304,10 +304,10 @@ def check_logp(
     domain: Domain,
     paramdomains: dict[str, Domain],
     scipy_logp: Callable,
-    decimal: Optional[int] = None,
+    decimal: int | None = None,
     n_samples: int = 100,
-    extra_args: Optional[dict[str, Any]] = None,
-    scipy_args: Optional[dict[str, Any]] = None,
+    extra_args: dict[str, Any] | None = None,
+    scipy_args: dict[str, Any] | None = None,
     skip_paramdomain_outside_edge_test: bool = False,
 ) -> None:
     """
@@ -410,7 +410,7 @@ def check_logcdf(
     domain: Domain,
     paramdomains: dict[str, Domain],
     scipy_logcdf: Callable,
-    decimal: Optional[int] = None,
+    decimal: int | None = None,
     n_samples: int = 100,
     skip_paramdomain_inside_edge_test: bool = False,
     skip_paramdomain_outside_edge_test: bool = False,
@@ -524,7 +524,7 @@ def check_icdf(
     paramdomains: dict[str, Domain],
     scipy_icdf: Callable,
     skip_paramdomain_outside_edge_test=False,
-    decimal: Optional[int] = None,
+    decimal: int | None = None,
     n_samples: int = 100,
 ) -> None:
     """
@@ -619,7 +619,7 @@ def check_selfconsistency_discrete_logcdf(
     distribution: Distribution,
     domain: Domain,
     paramdomains: dict[str, Domain],
-    decimal: Optional[int] = None,
+    decimal: int | None = None,
     n_samples: int = 100,
 ) -> None:
     """
@@ -842,17 +842,17 @@ class BaseTestDistributionRandom:

     """

-    pymc_dist: Optional[Callable] = None
-    pymc_dist_params: Optional[dict] = None
-    reference_dist: Optional[Callable] = None
-    reference_dist_params: Optional[dict] = None
-    expected_rv_op_params: Optional[dict] = None
+    pymc_dist: Callable | None = None
+    pymc_dist_params: dict | None = None
+    reference_dist: Callable | None = None
+    reference_dist_params: dict | None = None
+    expected_rv_op_params: dict | None = None
     checks_to_run: list[str] = []
     size = 15
     decimal = select_by_precision(float64=6, float32=3)

-    sizes_to_check: Optional[list] = None
-    sizes_expected: Optional[list] = None
+    sizes_to_check: list | None = None
+    sizes_expected: list | None = None
     repeated_params_shape = 5
     random_state = None

diff --git a/pymc/tuning/starting.py b/pymc/tuning/starting.py
index 90f56d19be..129d6f8973 100644
--- a/pymc/tuning/starting.py
+++ b/pymc/tuning/starting.py
@@ -22,7 +22,6 @@
 import warnings

 from collections.abc import Sequence
-from typing import Optional

 import numpy as np
 import pytensor.gradient as tg
@@ -46,7 +45,7 @@

 def find_MAP(
     start=None,
-    vars: Optional[Sequence[Variable]] = None,
+    vars: Sequence[Variable] | None = None,
     method="L-BFGS-B",
     return_raw=False,
     include_transformed=True,
@@ -55,7 +54,7 @@ def find_MAP(
     maxeval=5000,
     model=None,
     *args,
-    seed: Optional[int] = None,
+    seed: int | None = None,
     **kwargs,
 ):
     """Finds the local maximum a posteriori point given a model.
diff --git a/pymc/util.py b/pymc/util.py
index b72e17e0ae..a3f45e889c 100644
--- a/pymc/util.py
+++ b/pymc/util.py
@@ -16,7 +16,7 @@
 import warnings

 from collections.abc import Sequence
-from typing import Any, NewType, Optional, Union, cast
+from typing import Any, NewType, cast

 import arviz
 import cloudpickle
@@ -248,7 +248,7 @@ def enhanced(*args, **kwargs):


 def dataset_to_point_list(
-    ds: Union[xarray.Dataset, dict[str, xarray.DataArray]], sample_dims: Sequence[str]
+    ds: xarray.Dataset | dict[str, xarray.DataArray], sample_dims: Sequence[str]
 ) -> tuple[list[dict[str, np.ndarray]], dict[str, Any]]:
     # All keys of the dataset must be a str
     var_names = cast(list[str], list(ds.keys()))
@@ -284,7 +284,7 @@ def drop_warning_stat(idata: arviz.InferenceData) -> arviz.InferenceData:
     return nidata


-def chains_and_samples(data: Union[xarray.Dataset, arviz.InferenceData]) -> tuple[int, int]:
+def chains_and_samples(data: xarray.Dataset | arviz.InferenceData) -> tuple[int, int]:
     """Extract and return number of chains and samples in xarray or arviz traces."""
     dataset: xarray.Dataset
     if isinstance(data, xarray.Dataset):
@@ -312,7 +312,7 @@ def hashable(a=None) -> int:
         # first hash the keys and values with hashable
         # then hash the tuple of int-tuples with the builtin
         return hash(tuple((hashable(k), hashable(v)) for k, v in a.items()))
-    if isinstance(a, (tuple, list)):
+    if isinstance(a, tuple | list):
         # lists are mutable and not hashable by default
         # for memoization, we need the hash to depend on the items
         return hash(tuple(hashable(i) for i in a))
@@ -405,14 +405,14 @@ def wrapped(**kwargs):
     return wrapped


-RandomSeed = Optional[Union[int, Sequence[int], np.ndarray]]
-RandomState = Union[RandomSeed, np.random.RandomState, np.random.Generator]
+RandomSeed = None | int | Sequence[int] | np.ndarray
+RandomState = RandomSeed | np.random.RandomState | np.random.Generator


 def _get_seeds_per_chain(
     random_state: RandomState,
     chains: int,
-) -> Union[Sequence[int], np.ndarray]:
+) -> Sequence[int] | np.ndarray:
     """Obtain or validate specified integer seeds per chain.

     This function process different possible sources of seeding and returns one integer
@@ -448,7 +448,7 @@ def _get_unique_seeds_per_chain(integers_fn):
     if isinstance(random_state, np.random.RandomState):
         return _get_unique_seeds_per_chain(random_state.randint)

-    if not isinstance(random_state, (list, tuple, np.ndarray)):
+    if not isinstance(random_state, list | tuple | np.ndarray):
         raise ValueError(f"The `seeds` must be array-like. Got {type(random_state)} instead.")

     if len(random_state) != chains:
@@ -459,9 +459,7 @@ def _get_unique_seeds_per_chain(integers_fn):
     return random_state


-def get_value_vars_from_user_vars(
-    vars: Union[Variable, Sequence[Variable]], model
-) -> list[Variable]:
+def get_value_vars_from_user_vars(vars: Variable | Sequence[Variable], model) -> list[Variable]:
     """Converts user "vars" input into value variables.

     More often than not, users will pass random variables, and we will extract the
@@ -527,7 +525,7 @@ def _add_future_warning_tag(var) -> None:


 def makeiter(a):
-    if isinstance(a, (tuple, list)):
+    if isinstance(a, tuple | list):
         return a
     else:
         return [a]
diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py
index feb0a3a925..d2ff5df970 100644
--- a/pymc/variational/approximations.py
+++ b/pymc/variational/approximations.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional

 import numpy as np
 import pytensor
@@ -333,7 +332,7 @@ def sample_approx(approx, draws=100, include_transformed=True):
 class SingleGroupApproximation(Approximation):
     """Base class for Single Group Approximation"""

-    _group_class: Optional[type] = None
+    _group_class: type | None = None

     def __init__(self, *args, **kwargs):
         groups = [self._group_class(None, *args, **kwargs)]
diff --git a/pymc/variational/callbacks.py b/pymc/variational/callbacks.py
index 3c5313deb7..3c911e1ba2 100644
--- a/pymc/variational/callbacks.py
+++ b/pymc/variational/callbacks.py
@@ -14,7 +14,7 @@

 import collections

-from typing import Callable
+from collections.abc import Callable

 import numpy as np

diff --git a/pymc/variational/minibatch_rv.py b/pymc/variational/minibatch_rv.py
index 5b6539a9e0..435cac9fbb 100644
--- a/pymc/variational/minibatch_rv.py
+++ b/pymc/variational/minibatch_rv.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections.abc import Sequence
-from typing import Any, Union, cast
+from typing import Any, cast

 import pytensor.tensor as pt

@@ -51,7 +51,7 @@ def perform(self, node, inputs, output_storage):

 def create_minibatch_rv(
     rv: TensorVariable,
-    total_size: Union[int, None, Sequence[Union[int, EllipsisType, None]]],
+    total_size: int | None | Sequence[int | EllipsisType | None],
 ) -> TensorVariable:
     """Create variable whose logp is rescaled by total_size."""
     if isinstance(total_size, int):
@@ -60,7 +60,7 @@ def create_minibatch_rv(
         else:
             missing_ndims = rv.ndim - 1
             total_size = [total_size] + [None] * missing_ndims
-    elif isinstance(total_size, (list, tuple)):
+    elif isinstance(total_size, list | tuple):
         total_size = list(total_size)
         if Ellipsis in total_size:
             # Replace Ellipsis by None
diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py
index f605f01fd4..35f924c1a7 100644
--- a/pymc/variational/opvi.py
+++ b/pymc/variational/opvi.py
@@ -122,7 +122,7 @@ def _known_scan_ignored_inputs(terms):
     return [
         n.owner.inputs[0]
         for n in pytensor.graph.ancestors(terms)
-        if n.owner is not None and isinstance(n.owner.op, (MinibatchIndexRV, SimulatorRV))
+        if n.owner is not None and isinstance(n.owner.op, MinibatchIndexRV | SimulatorRV)
     ]


@@ -163,9 +163,9 @@ def try_to_set_test_value(node_in, node_out, s):
     if s is None:
         s = 1
     s = pytensor.compile.view_op(pt.as_tensor(s))
-    if not isinstance(node_in, (list, tuple)):
+    if not isinstance(node_in, list | tuple):
         node_in = [node_in]
-    if not isinstance(node_out, (list, tuple)):
+    if not isinstance(node_out, list | tuple):
         node_out = [node_out]
     for i, o in zip(node_in, node_out):
         if hasattr(i.tag, "test_value"):
@@ -1482,10 +1482,10 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No
         node_in = node
         if more_replacements:
             node = graph_replace(node, more_replacements, strict=False)
-        if not isinstance(node, (list, tuple)):
+        if not isinstance(node, list | tuple):
             node = [node]
         node = self.model.replace_rvs_by_values(node)
-        if not isinstance(node_in, (list, tuple)):
+        if not isinstance(node_in, list | tuple):
             node = node[0]
         if size is None:
             node_out = self.symbolic_single_sample(node)
diff --git a/pyproject.toml b/pyproject.toml
index 417178a516..83afadc67b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@
 addopts = ["--color=yes"]

 [tool.ruff]
 line-length = 100
-target-version = "py39"
+target-version = "py310"
 exclude = ["versioneer.py"]

 [tool.ruff.lint]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 56077f3a6b..c5344b0b3d 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -17,7 +17,7 @@ numpydoc
 pandas>=0.24.0
 polyagamma
 pre-commit>=2.8.0
-pytensor>=2.19,<2.20
+pytensor>=2.20,<2.21
 pytest-cov>=2.5
 pytest>=3.0
 rich>=13.7.1
diff --git a/requirements.txt b/requirements.txt
index 370dcbd41e..50bbd8ae8a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ cachetools>=4.2.1
 cloudpickle
 numpy>=1.15.0
 pandas>=0.24.0
-pytensor>=2.19,<2.20
+pytensor>=2.20,<2.21
 rich>=13.7.1
 scipy>=1.4.1
 typing-extensions>=3.7.4
diff --git a/setup.py b/setup.py
index 2b5e03dd2c..7369c7ba3a 100755
--- a/setup.py
+++ b/setup.py
@@ -30,7 +30,6 @@
     "Development Status :: 5 - Production/Stable",
     "Programming Language :: Python",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
@@ -71,7 +70,7 @@
     # Also see MANIFEST.in
     # package_data={'docs': ['*']},
     classifiers=classifiers,
-    python_requires=">=3.9",
+    python_requires=">=3.10",
     install_requires=install_reqs,
     tests_require=test_reqs,
 )
diff --git a/tests/backends/test_arviz.py b/tests/backends/test_arviz.py
index adce0c6f5b..55eb00cc1d 100644
--- a/tests/backends/test_arviz.py
+++ b/tests/backends/test_arviz.py
@@ -372,7 +372,7 @@ def test_mv_missing_data_model(self):
         )

         # make sure that data is really missing
-        assert isinstance(y.owner.inputs[0].owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1))
+        assert isinstance(y.owner.inputs[0].owner.op, AdvancedIncSubtensor | AdvancedIncSubtensor1)

         test_dict = {
             "posterior": ["mu", "chol_cov"],
diff --git a/tests/distributions/test_discrete.py b/tests/distributions/test_discrete.py
index 8f66db6b17..4996ff63ec 100644
--- a/tests/distributions/test_discrete.py
+++ b/tests/distributions/test_discrete.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import functools as ft
+import itertools
 import sys
 import warnings

@@ -76,7 +77,7 @@ def invlogit(x, eps=sys.float_info.epsilon):

 def orderedlogistic_logpdf(value, eta, cutpoints):
     c = np.concatenate(([-np.inf], cutpoints, [np.inf]))
-    ps = np.array([invlogit(eta - cc) - invlogit(eta - cc1) for cc, cc1 in zip(c[:-1], c[1:])])
+    ps = np.array([invlogit(eta - cc) - invlogit(eta - cc1) for cc, cc1 in itertools.pairwise(c)])
     p = ps[value]
     return np.where(np.all(ps >= 0), np.log(p), -np.inf)

@@ -87,7 +88,7 @@ def invprobit(x):

 def orderedprobit_logpdf(value, eta, cutpoints):
     c = np.concatenate(([-np.inf], cutpoints, [np.inf]))
-    ps = np.array([invprobit(eta - cc) - invprobit(eta - cc1) for cc, cc1 in zip(c[:-1], c[1:])])
+    ps = np.array([invprobit(eta - cc) - invprobit(eta - cc1) for cc, cc1 in itertools.pairwise(c)])
     p = ps[value]
     return np.where(np.all(ps >= 0), np.log(p), -np.inf)

diff --git a/tests/distributions/test_simulator.py b/tests/distributions/test_simulator.py
index 1ea201c39f..928582968a 100644
--- a/tests/distributions/test_simulator.py
+++ b/tests/distributions/test_simulator.py
@@ -257,7 +257,7 @@ def test_upstream_rngs_not_in_compiled_logp(self, seeded_test):
         shared_rng_vars = [
             node
             for node in ancestors(compiled_graph)
-            if isinstance(node, (RandomStateSharedVariable, RandomGeneratorSharedVariable))
+            if isinstance(node, RandomStateSharedVariable | RandomGeneratorSharedVariable)
         ]
         assert len(shared_rng_vars) == 1

diff --git a/tests/logprob/test_basic.py b/tests/logprob/test_basic.py
index 0ab40828a2..c2f9635fa9 100644
--- a/tests/logprob/test_basic.py
+++ b/tests/logprob/test_basic.py
@@ -301,7 +301,7 @@ def test_joint_logp_incsubtensor(indices, size):

     a_idx = pt.set_subtensor(a[indices], data)

-    assert isinstance(a_idx.owner.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1))
+    assert isinstance(a_idx.owner.op, IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1)

     a_idx_value_var = a_idx.type()
     a_idx_value_var.name = "a_idx_value"
diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py
index b3e5c5656e..fa0c53831e 100644
--- a/tests/logprob/test_mixture.py
+++ b/tests/logprob/test_mixture.py
@@ -1136,7 +1136,7 @@ def test_joint_logprob_subtensor():
     # (e.g., at least one of the advanced indexes has non-repeating values)
     A_idx = A_rv[I_rv, pt.ogrid[A_rv.shape[-1] :]]

-    assert isinstance(A_idx.owner.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1))
+    assert isinstance(A_idx.owner.op, Subtensor | AdvancedSubtensor | AdvancedSubtensor1)

     A_idx_value_var = A_idx.type()
     A_idx_value_var.name = "A_idx_value"
diff --git a/tests/logprob/test_rewriting.py b/tests/logprob/test_rewriting.py
index 66c28b102d..fd1747d43d 100644
--- a/tests/logprob/test_rewriting.py
+++ b/tests/logprob/test_rewriting.py
@@ -131,7 +131,7 @@ def test_joint_logprob_incsubtensor(indices, size):
     y_value_var = Y_rv.clone()
     y_value_var.name = "y"

-    assert isinstance(Y_rv.owner.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1))
+    assert isinstance(Y_rv.owner.op, IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1)

     Y_rv_logp = conditional_logp({Y_rv: y_value_var})
     Y_rv_logp_combined = pt.add(*Y_rv_logp.values())
diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py
index 5b53f98642..1d9f68c267 100644
--- a/tests/sampling/test_jax.py
+++ b/tests/sampling/test_jax.py
@@ -15,7 +15,8 @@
 import re
 import warnings

-from typing import Any, Callable, Optional
+from collections.abc import Callable
+from typing import Any
 from unittest import mock

 import arviz as az

@@ -297,9 +298,9 @@ def test_idata_kwargs(
         model_test_idata_kwargs: pm.Model,
         sampler: Callable[..., az.InferenceData],
         idata_kwargs: dict[str, Any],
-        postprocessing_backend: Optional[str],
+        postprocessing_backend: str | None,
     ):
-        idata: Optional[az.InferenceData] = None
+        idata: az.InferenceData | None = None
         with model_test_idata_kwargs:
             idata = sampler(
                 tune=50,
diff --git a/tests/test_model_graph.py b/tests/test_model_graph.py
index fb95b5b3fb..05b8b99e47 100644
--- a/tests/test_model_graph.py
+++ b/tests/test_model_graph.py
@@ -274,7 +274,7 @@ def test_inputs(self):
         for child, parents_in_plot in self.compute_graph.items():
             var = self.model[child]
             parents_in_graph = self.model_graph.get_parent_names(var)
-            if isinstance(var, (SharedVariable, TensorConstant)):
+            if isinstance(var, SharedVariable | TensorConstant):
                 # observed data also doesn't have parents in the compute graph!
                 # But for the visualization we like them to become descendants of the
                 # RVs that these observations belong to.
diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py
index 29f53808e8..9dcaaf94c3 100644
--- a/tests/test_pytensorf.py
+++ b/tests/test_pytensorf.py
@@ -152,7 +152,7 @@ def test_extract_obs_data():
     constant = pt.as_tensor(data_m.filled())
     z_at = pt.set_subtensor(constant[mask.nonzero()], missing_values)

-    assert isinstance(z_at.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1))
+    assert isinstance(z_at.owner.op, AdvancedIncSubtensor | AdvancedIncSubtensor1)

     res = extract_obs_data(z_at)

@@ -169,7 +169,7 @@ def test_extract_obs_data():
     constant = pt.as_tensor(data_m.filled())
     z_at = pt.set_subtensor(constant[mask.nonzero()], missing_values)

-    assert isinstance(z_at.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1))
+    assert isinstance(z_at.owner.op, AdvancedIncSubtensor | AdvancedIncSubtensor1)

     res = extract_obs_data(z_at)
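Reviewer note: the substitution in tests/distributions/test_discrete.py above relies on `itertools.pairwise`, added in Python 3.10; it yields the same consecutive pairs as the old double-slice `zip` idiom while working on arbitrary iterables and avoiding the two slice copies. A quick equivalence check:

    import itertools

    c = [-float("inf"), 0.5, 1.5, float("inf")]
    assert list(itertools.pairwise(c)) == list(zip(c[:-1], c[1:]))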