Skip to content

Remove MvNormalSVD Class #432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
)
from pymc_extras.statespace.filters.distributions import (
LinearGaussianStateSpace,
MvNormalSVD,
SequenceMvNormal,
)
from pymc_extras.statespace.filters.utilities import stabilize
Expand Down Expand Up @@ -2233,7 +2232,9 @@ def impulse_response_function(
if shock_trajectory is None:
shock_trajectory = pt.zeros((n_steps, self.k_posdef))
if Q is not None:
init_shock = MvNormalSVD("initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM])
init_shock = pm.MvNormal(
"initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM], method="svd"
)
else:
init_shock = pm.Deterministic(
"initial_shock",
Expand Down
33 changes: 9 additions & 24 deletions pymc_extras/statespace/filters/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
from pymc import intX
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import Continuous, SymbolicRandomVariable
from pymc.distributions.multivariate import MvNormal
from pymc.distributions.shape_utils import get_support_shape_1d
from pymc.logprob.abstract import _logprob
from pytensor.graph.basic import Node
from pytensor.tensor.random.basic import MvNormalRV

floatX = pytensor.config.floatX
COV_ZERO_TOL = 0
Expand Down Expand Up @@ -49,23 +47,6 @@ def make_signature(sequence_names):
return f"{signature},[rng]->[rng],({time},{state_and_obs})"


class MvNormalSVDRV(MvNormalRV):
name = "multivariate_normal"
signature = "(n),(n,n)->(n)"
dtype = "floatX"
_print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")


class MvNormalSVD(MvNormal):
"""Dummy distribution intended to be rewritten into a JAX multivariate_normal with method="svd".

A JAX MvNormal robust to low-rank covariance matrices
"""

# TODO: Remove this entirely on next PyMC release; method will be exposed directly in MvNormal
rv_op = MvNormalSVDRV(method="svd")


class LinearGaussianStateSpaceRV(SymbolicRandomVariable):
default_output = 1
_print_name = ("LinearGuassianStateSpace", "\\operatorname{LinearGuassianStateSpace}")
Expand Down Expand Up @@ -223,8 +204,12 @@ def step_fn(*args):
k = T.shape[0]
a = state[:k]

middle_rng, a_innovation = MvNormalSVD.dist(mu=0, cov=Q, rng=rng).owner.outputs
next_rng, y_innovation = MvNormalSVD.dist(mu=0, cov=H, rng=middle_rng).owner.outputs
middle_rng, a_innovation = pm.MvNormal.dist(
mu=0, cov=Q, rng=rng, method="svd"
).owner.outputs
next_rng, y_innovation = pm.MvNormal.dist(
mu=0, cov=H, rng=middle_rng, method="svd"
).owner.outputs

a_mu = c + T @ a
a_next = a_mu + R @ a_innovation
Expand All @@ -239,8 +224,8 @@ def step_fn(*args):
Z_init = Z_ if Z_ in non_sequences else Z_[0]
H_init = H_ if H_ in non_sequences else H_[0]

init_x_ = MvNormalSVD.dist(a0_, P0_, rng=rng)
init_y_ = MvNormalSVD.dist(Z_init @ init_x_, H_init, rng=rng)
init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng, method="svd")
init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng, method="svd")

init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)

Expand Down Expand Up @@ -400,7 +385,7 @@ def rv_op(cls, mus, covs, logp, size=None):
rng = pytensor.shared(np.random.default_rng())

def step(mu, cov, rng):
new_rng, mvn = MvNormalSVD.dist(mu=mu, cov=cov, rng=rng).owner.outputs
new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method="svd").owner.outputs
return mvn, {rng: new_rng}

mvn_seq, updates = pytensor.scan(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ filterwarnings =[
# Warning coming from blackjax
'ignore:jax\.tree_map is deprecated:DeprecationWarning',

# Ignore PyMC use of numpy.core
# PyMC uses numpy.core functions, which emits an warning as of numpy>2.0
'ignore:numpy\.core\.numeric is deprecated:DeprecationWarning',
]

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pymc>=5.21
pymc>=5.21.1
scikit-learn
better-optimize