Skip to content

Commit 4fbbfeb

Browse files
Bump PyMC version requirement (#431)
* Use new `method` argument for `MvNormal` to defined `MvNormalSVD` * Can't use `if Variable` anymore * Bump PyMC version pin * Ignore numpy warning from pymc
1 parent ec46270 commit 4fbbfeb

File tree

6 files changed

+10
-28
lines changed

6 files changed

+10
-28
lines changed

conda-envs/environment-test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ channels:
33
- conda-forge
44
- nodefaults
55
dependencies:
6-
- pymc>=5.20
6+
- pymc>=5.21
77
- pytest-cov>=2.5
88
- pytest>=3.0
99
- dask

conda-envs/windows-environment-test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ dependencies:
1010
- xhistogram
1111
- statsmodels
1212
- numba<=0.60.0
13-
- pymc>=5.20
13+
- pymc>=5.21
1414
- pip:
1515
- blackjax
1616
- scikit-learn

pymc_extras/statespace/core/statespace.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ def _insert_random_variables(self):
707707
with pymc_model:
708708
for param_name in self.param_names:
709709
param = getattr(pymc_model, param_name, None)
710-
if param:
710+
if param is not None:
711711
found_params.append(param.name)
712712

713713
missing_params = list(set(self.param_names) - set(found_params))
@@ -746,7 +746,7 @@ def _insert_data_variables(self):
746746
with pymc_model:
747747
for data_name in data_names:
748748
data = getattr(pymc_model, data_name, None)
749-
if data:
749+
if data is not None:
750750
found_data.append(data.name)
751751

752752
missing_data = list(set(data_names) - set(found_data))

pymc_extras/statespace/filters/distributions.py

+2-23
Original file line numberDiff line numberDiff line change
@@ -62,29 +62,8 @@ class MvNormalSVD(MvNormal):
6262
A JAX MvNormal robust to low-rank covariance matrices
6363
"""
6464

65-
rv_op = MvNormalSVDRV()
66-
67-
68-
try:
69-
import jax.random
70-
71-
from pytensor.link.jax.dispatch.random import jax_sample_fn
72-
73-
@jax_sample_fn.register(MvNormalSVDRV)
74-
def jax_sample_fn_mvnormal_svd(op, node):
75-
def sample_fn(rng, size, dtype, *parameters):
76-
rng_key = rng["jax_state"]
77-
rng_key, sampling_key = jax.random.split(rng_key, 2)
78-
sample = jax.random.multivariate_normal(
79-
sampling_key, *parameters, shape=size, dtype=dtype, method="svd"
80-
)
81-
rng["jax_state"] = rng_key
82-
return (rng, sample)
83-
84-
return sample_fn
85-
86-
except ImportError:
87-
pass
65+
# TODO: Remove this entirely on next PyMC release; method will be exposed directly in MvNormal
66+
rv_op = MvNormalSVDRV(method="svd")
8867

8968

9069
class LinearGaussianStateSpaceRV(SymbolicRandomVariable):

pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ filterwarnings =[
2020

2121
# Warning coming from blackjax
2222
'ignore:jax\.tree_map is deprecated:DeprecationWarning',
23+
24+
# Ignore PyMC use of numpy.core
25+
'ignore:numpy\.core\.numeric is deprecated:DeprecationWarning',
2326
]
2427

2528
[tool.coverage.report]

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
pymc>=5.20
1+
pymc>=5.21
22
scikit-learn
33
better-optimize

0 commit comments

Comments
 (0)