Skip to content

Commit 6d38385

Browse files
Remove MvNormalSVD Class (#432)
* Improve numpy.core deprecation warning comment * Remove MvNormalSVD and replace with method argument * Require pymc>=5.21.1 for the `method` argument in MvNormal
1 parent 49e2818 commit 6d38385

File tree

4 files changed

+14
-28
lines changed

4 files changed

+14
-28
lines changed

pymc_extras/statespace/core/statespace.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
)
2929
from pymc_extras.statespace.filters.distributions import (
3030
LinearGaussianStateSpace,
31-
MvNormalSVD,
3231
SequenceMvNormal,
3332
)
3433
from pymc_extras.statespace.filters.utilities import stabilize
@@ -2233,7 +2232,9 @@ def impulse_response_function(
22332232
if shock_trajectory is None:
22342233
shock_trajectory = pt.zeros((n_steps, self.k_posdef))
22352234
if Q is not None:
2236-
init_shock = MvNormalSVD("initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM])
2235+
init_shock = pm.MvNormal(
2236+
"initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM], method="svd"
2237+
)
22372238
else:
22382239
init_shock = pm.Deterministic(
22392240
"initial_shock",

pymc_extras/statespace/filters/distributions.py

+9-24
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66
from pymc import intX
77
from pymc.distributions.dist_math import check_parameters
88
from pymc.distributions.distribution import Continuous, SymbolicRandomVariable
9-
from pymc.distributions.multivariate import MvNormal
109
from pymc.distributions.shape_utils import get_support_shape_1d
1110
from pymc.logprob.abstract import _logprob
1211
from pytensor.graph.basic import Node
13-
from pytensor.tensor.random.basic import MvNormalRV
1412

1513
floatX = pytensor.config.floatX
1614
COV_ZERO_TOL = 0
@@ -49,23 +47,6 @@ def make_signature(sequence_names):
4947
return f"{signature},[rng]->[rng],({time},{state_and_obs})"
5048

5149

52-
class MvNormalSVDRV(MvNormalRV):
53-
name = "multivariate_normal"
54-
signature = "(n),(n,n)->(n)"
55-
dtype = "floatX"
56-
_print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
57-
58-
59-
class MvNormalSVD(MvNormal):
60-
"""Dummy distribution intended to be rewritten into a JAX multivariate_normal with method="svd".
61-
62-
A JAX MvNormal robust to low-rank covariance matrices
63-
"""
64-
65-
# TODO: Remove this entirely on next PyMC release; method will be exposed directly in MvNormal
66-
rv_op = MvNormalSVDRV(method="svd")
67-
68-
6950
class LinearGaussianStateSpaceRV(SymbolicRandomVariable):
7051
default_output = 1
7152
_print_name = ("LinearGuassianStateSpace", "\\operatorname{LinearGuassianStateSpace}")
@@ -223,8 +204,12 @@ def step_fn(*args):
223204
k = T.shape[0]
224205
a = state[:k]
225206

226-
middle_rng, a_innovation = MvNormalSVD.dist(mu=0, cov=Q, rng=rng).owner.outputs
227-
next_rng, y_innovation = MvNormalSVD.dist(mu=0, cov=H, rng=middle_rng).owner.outputs
207+
middle_rng, a_innovation = pm.MvNormal.dist(
208+
mu=0, cov=Q, rng=rng, method="svd"
209+
).owner.outputs
210+
next_rng, y_innovation = pm.MvNormal.dist(
211+
mu=0, cov=H, rng=middle_rng, method="svd"
212+
).owner.outputs
228213

229214
a_mu = c + T @ a
230215
a_next = a_mu + R @ a_innovation
@@ -239,8 +224,8 @@ def step_fn(*args):
239224
Z_init = Z_ if Z_ in non_sequences else Z_[0]
240225
H_init = H_ if H_ in non_sequences else H_[0]
241226

242-
init_x_ = MvNormalSVD.dist(a0_, P0_, rng=rng)
243-
init_y_ = MvNormalSVD.dist(Z_init @ init_x_, H_init, rng=rng)
227+
init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng, method="svd")
228+
init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng, method="svd")
244229

245230
init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)
246231

@@ -400,7 +385,7 @@ def rv_op(cls, mus, covs, logp, size=None):
400385
rng = pytensor.shared(np.random.default_rng())
401386

402387
def step(mu, cov, rng):
403-
new_rng, mvn = MvNormalSVD.dist(mu=mu, cov=cov, rng=rng).owner.outputs
388+
new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method="svd").owner.outputs
404389
return mvn, {rng: new_rng}
405390

406391
mvn_seq, updates = pytensor.scan(

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ filterwarnings =[
2121
# Warning coming from blackjax
2222
'ignore:jax\.tree_map is deprecated:DeprecationWarning',
2323

24-
# Ignore PyMC use of numpy.core
24+
# PyMC uses numpy.core functions, which emits an warning as of numpy>2.0
2525
'ignore:numpy\.core\.numeric is deprecated:DeprecationWarning',
2626
]
2727

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
pymc>=5.21
1+
pymc>=5.21.1
22
scikit-learn
33
better-optimize

0 commit comments

Comments
 (0)