diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4fe668ee0b..8d33a7b898 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -358,12 +358,12 @@ jobs: - name: Cache conda uses: actions/cache@v3 env: - # Increase this value to reset cache if environment-test.yml has not changed + # Increase this value to reset cache if environment-jax.yml has not changed CACHE_NUMBER: 0 with: path: ~/conda_pkgs_dir key: ${{ runner.os }}-py${{matrix.python-version}}-conda-${{ env.CACHE_NUMBER }}-${{ - hashFiles('conda-envs/environment-test.yml') }} + hashFiles('conda-envs/environment-jax.yml') }} - name: Cache multiple paths uses: actions/cache@v3 env: @@ -383,7 +383,7 @@ jobs: mamba-version: "*" activate-environment: pymc-test channel-priority: strict - environment-file: conda-envs/environment-test.yml + environment-file: conda-envs/environment-jax.yml python-version: ${{matrix.python-version}} use-mamba: true use-only-tar-bz2: false # IMPORTANT: This may break caching of conda packages! See https://github.com/conda-incubator/setup-miniconda/issues/267 @@ -392,10 +392,6 @@ jobs: conda activate pymc-test pip install -e . python --version - - name: Install external samplers - run: | - conda activate pymc-test - pip install "numpyro>=0.8.0" "blackjax>=1.0.0" - name: Run tests run: | python -m pytest -vv --cov=pymc --cov-report=xml --no-cov-on-fail --cov-report term --durations=50 $TEST_SUBSET diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index b9a533a464..c57e6bb7ff 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -36,7 +36,7 @@ dependencies: - watermark - polyagamma - sphinx-remove-toctrees -- mypy=0.990 +- mypy=1.5.1 - types-cachetools - pip: - git+https://github.com/pymc-devs/pymc-sphinx-theme diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml new file mode 100644 index 0000000000..542d0ae27d --- /dev/null +++ b/conda-envs/environment-jax.yml @@ -0,0 +1,38 @@ +# "test" conda envs are used to set up our CI environment in GitHub actions +name: pymc-test +channels: +- conda-forge +- defaults +dependencies: +# Base dependencies +- arviz>=0.13.0 +- blas +- cachetools>=4.2.1 +- cloudpickle +- fastprogress>=0.2.0 +- h5py>=2.7 +# Jaxlib version must not be greater than jax version! +- blackjax>=1.0.0 +- jaxlib==0.4.14 +- jax==0.4.16 +- libblas=*=*mkl +- mkl-service +- numpy>=1.15.0 +- numpyro>=0.8.0 +- pandas>=0.24.0 +- pip +- pytensor>=2.17.0,<2.18 +- python-graphviz +- networkx +- scipy>=1.4.1 +- typing-extensions>=3.7.4 +# Extra dependencies for testing +- ipython>=7.16 +- pre-commit>=2.8.0 +- pytest-cov>=2.5 +- pytest>=3.0 +- mypy=1.5.1 +- types-cachetools +- pip: + - numdifftools>=0.9.40 + - mcbackend>=0.4.0 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 01ad9c4f31..ee9885d170 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -27,7 +27,7 @@ dependencies: - pre-commit>=2.8.0 - pytest-cov>=2.5 - pytest>=3.0 -- mypy=0.990 +- mypy=1.5.1 - types-cachetools - pip: - numdifftools>=0.9.40 diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index 2c86d9af49..a6d1ce27f5 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -33,7 +33,7 @@ dependencies: - sphinx>=1.5 - watermark - sphinx-remove-toctrees -- mypy=0.990 +- mypy=1.5.1 - types-cachetools - pip: - git+https://github.com/pymc-devs/pymc-sphinx-theme diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 9a0699752f..32905a49ab 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -27,7 +27,7 @@ dependencies: - pre-commit>=2.8.0 - pytest-cov>=2.5 - pytest>=3.0 -- mypy=0.990 +- mypy=1.5.1 - types-cachetools - pip: - numdifftools>=0.9.40 diff --git a/pymc/gp/hsgp_approx.py b/pymc/gp/hsgp_approx.py index adca16f600..0e13c77e25 100644 --- a/pymc/gp/hsgp_approx.py +++ b/pymc/gp/hsgp_approx.py @@ -189,7 +189,9 @@ def __init__( self._drop_first = drop_first self._m = m self._m_star = int(np.prod(self._m)) - self._L = L + self._L: Optional[pt.TensorVariable] = None + if L is not None: + self._L = pt.as_tensor(L) self._c = c super().__init__(mean_func=mean_func, cov_func=cov_func) @@ -198,13 +200,13 @@ def __add__(self, other): raise NotImplementedError("Additive HSGPs aren't supported.") @property - def L(self): + def L(self) -> pt.TensorVariable: if self._L is None: raise RuntimeError("Boundaries `L` required but still unset.") return self._L @L.setter - def L(self, value): + def L(self, value: TensorLike): self._L = pt.as_tensor_variable(value) def prior_linearized(self, Xs: TensorLike): @@ -290,9 +292,7 @@ def prior_linearized(self, Xs: TensorLike): # If not provided, use Xs and c to set L if self._L is None: assert isinstance(self._c, (numbers.Real, np.ndarray, pt.TensorVariable)) - self.L = set_boundary(Xs, self._c) - else: - self.L = self._L + self._L = pt.as_tensor(set_boundary(Xs, self._c)) eigvals = calc_eigenvalues(self.L, self._m, tl=pt) phi = calc_eigenvectors(Xs, self.L, eigvals, self._m, tl=pt) diff --git a/pymc/model/core.py b/pymc/model/core.py index 01a655679a..9fe8626d3f 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -66,7 +66,7 @@ from pymc.initial_point import make_initial_point_fn from pymc.logprob.basic import transformed_conditional_logp from pymc.logprob.utils import ParameterValueError -from pymc.model_graph import VarName, model_to_graphviz +from pymc.model_graph import model_to_graphviz from pymc.pytensorf import ( PointFunc, SeedSequenceSeed, @@ -80,6 +80,7 @@ ) from pymc.util import ( UNSET, + VarName, WithMemoization, _add_future_warning_tag, get_transformed_name, @@ -2061,7 +2062,7 @@ def compile_fn( ) -def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]: +def Point(*args, filter_model_vars=False, **kwargs) -> Dict[VarName, np.ndarray]: """Build a point. Uses same args as dict() does. Filters out variables not in the model. All keys are strings. diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py index 849156946d..154384fda7 100644 --- a/pymc/model/transform/basic.py +++ b/pymc/model/transform/basic.py @@ -54,6 +54,8 @@ def prune_vars_detached_from_observed(model: Model) -> Model: def parse_vars(model: Model, vars: Union[ModelVariable, Sequence[ModelVariable]]) -> List[Variable]: - if not isinstance(vars, (list, tuple)): - vars = (vars,) - return [model[var] if isinstance(var, str) else var for var in vars] + if isinstance(vars, (list, tuple)): + vars_seq = vars + else: + vars_seq = (vars,) + return [model[var] if isinstance(var, str) else var for var in vars_seq] diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 39d9360aea..5f998c2c1a 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -14,7 +14,7 @@ import warnings from collections import defaultdict -from typing import Dict, Iterable, List, NewType, Optional, Sequence, Set +from typing import Dict, Iterable, List, Optional, Sequence, Set from pytensor import function from pytensor.compile.sharedvalue import SharedVariable @@ -28,10 +28,7 @@ import pymc as pm -from pymc.util import get_default_varnames, get_var_name - -VarName = NewType("VarName", str) - +from pymc.util import VarName, get_default_varnames, get_var_name __all__ = ( "ModelGraph", @@ -76,12 +73,12 @@ def _expand(x): return reversed(_filter_non_parameter_inputs(x)) return [] - parents = { - VarName(get_var_name(x)) - for x in walk(nodes=_filter_non_parameter_inputs(var), expand=_expand) + parents = set() + for x in walk(nodes=_filter_non_parameter_inputs(var), expand=_expand): # Only consider nodes that are in the named model variables. - if x.name and x.name in self._all_var_names - } + vname = getattr(x, "name", None) + if isinstance(vname, str) and vname in self._all_var_names: + parents.add(VarName(vname)) return parents @@ -113,7 +110,7 @@ def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> List[Va selected_ancestors.add(self.model.rvs_to_values[var]) # ordering of self._all_var_names is important - return [VarName(var.name) for var in selected_ancestors] + return [get_var_name(var) for var in selected_ancestors] def make_compute_graph( self, var_names: Optional[Iterable[VarName]] = None diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index de6d598342..519960b7a5 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -336,7 +336,7 @@ def sample_blackjax_nuts( var_names: Optional[Sequence[str]] = None, keep_untransformed: bool = False, chain_method: str = "parallel", - postprocessing_backend: Literal["cpu", "gpu"] | None = None, + postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None, postprocessing_vectorize: Literal["vmap", "scan"] = "scan", idata_kwargs: Optional[Dict[str, Any]] = None, postprocessing_chunks=None, # deprecated @@ -546,7 +546,7 @@ def sample_numpyro_nuts( progressbar: bool = True, keep_untransformed: bool = False, chain_method: str = "parallel", - postprocessing_backend: Literal["cpu", "gpu"] | None = None, + postprocessing_backend: Optional[Literal["cpu", "gpu"]] = None, postprocessing_vectorize: Literal["vmap", "scan"] = "scan", idata_kwargs: Optional[Dict] = None, nuts_kwargs: Optional[Dict] = None, diff --git a/pymc/util.py b/pymc/util.py index 1ec58d4cee..8271fd612e 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -15,7 +15,7 @@ import functools import warnings -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, Dict, List, NewType, Optional, Sequence, Tuple, Union, cast import arviz import cloudpickle @@ -29,6 +29,8 @@ from pymc.exceptions import BlockModelAccessError +VarName = NewType("VarName", str) + class _UnsetType: """Type for the `UNSET` object to make it look nice in `help(...)` outputs.""" @@ -207,9 +209,9 @@ def get_default_varnames(var_iterator, include_transformed): return [var for var in var_iterator if not is_transformed_name(get_var_name(var))] -def get_var_name(var) -> str: +def get_var_name(var) -> VarName: """Get an appropriate, plain variable name for a variable.""" - return str(getattr(var, "name", var)) + return VarName(str(getattr(var, "name", var))) def get_transformed(z): diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index f81a6d6a23..99261f026e 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -51,7 +51,7 @@ import itertools import warnings -from typing import Any +from typing import Any, overload import numpy as np import pytensor @@ -980,17 +980,29 @@ def symbolic_random(self): """ raise NotImplementedError - @pytensor.config.change_flags(compute_test_value="off") + @overload + def set_size_and_deterministic( + self, node: Variable, s, d: bool, more_replacements: dict | None = None + ) -> Variable: + ... + + @overload def set_size_and_deterministic( - self, node: Variable, s, d: bool, more_replacements: dict = None + self, node: list[Variable], s, d: bool, more_replacements: dict | None = None ) -> list[Variable]: + ... + + @pytensor.config.change_flags(compute_test_value="off") + def set_size_and_deterministic( + self, node: Variable | list[Variable], s, d: bool, more_replacements: dict | None = None + ) -> Variable | list[Variable]: """*Dev* - after node is sampled via :func:`symbolic_sample_over_posterior` or :func:`symbolic_single_sample` new random generator can be allocated and applied to node Parameters ---------- - node: :class:`Variable` - PyTensor node with symbolically applied VI replacements + node + PyTensor node(s) with symbolically applied VI replacements s: scalar desired number of samples d: bool or int @@ -1000,7 +1012,7 @@ def set_size_and_deterministic( Returns ------- - :class:`Variable` with applied replacements, ready to use + :class:`Variable` or list with applied replacements, ready to use """ flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements) diff --git a/requirements-dev.txt b/requirements-dev.txt index 18eede6a60..f99b1484f3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,7 +10,7 @@ h5py>=2.7 ipython>=7.16 jupyter-sphinx mcbackend>=0.4.0 -mypy==0.990 +mypy==1.5.1 myst-nb numdifftools>=0.9.40 numpy>=1.15.0 diff --git a/scripts/generate_pip_deps_from_conda.py b/scripts/generate_pip_deps_from_conda.py index cbdc7791fa..69bdbb49f2 100755 --- a/scripts/generate_pip_deps_from_conda.py +++ b/scripts/generate_pip_deps_from_conda.py @@ -54,6 +54,7 @@ "networkx", "blas", "jax", + "jaxlib", } RENAME = {} diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 77da8e8a0f..1a93f8a6a7 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -41,13 +41,10 @@ pymc/logprob/utils.py pymc/model/core.py pymc/model/fgraph.py -pymc/model/transform/basic.py pymc/model/transform/conditioning.py -pymc/model_graph.py pymc/printing.py pymc/pytensorf.py pymc/sampling/jax.py -pymc/variational/opvi.py """ @@ -105,7 +102,6 @@ def check_no_unexpected_results(mypy_lines: Iterator[str]): Exits the process with non-zero exit code upon unexpected results. """ df = mypy_to_pandas(mypy_lines) - all_files = { str(fp).replace(str(DP_ROOT), "").strip(os.sep).replace(os.sep, "/") for fp in DP_ROOT.glob("pymc/**/*.py")