Skip to content

Commit fd31103

Browse files
ferrinetwiecki
authored andcommitted
Raise NotImplementedError in not yet refactored timeseries distributions
Co-authored-by: Thomas Wiecki <[email protected]>
1 parent 5703a9d commit fd31103

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

pymc/distributions/timeseries.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,13 @@ class GARCH11(distribution.Continuous):
648648
initial_vol >= 0, initial volatility, sigma_0
649649
"""
650650

651+
def __new__(cls, *args, **kwargs):
652+
raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
653+
654+
@classmethod
655+
def dist(cls, *args, **kwargs):
656+
raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
657+
651658
def __init__(self, omega, alpha_1, beta_1, initial_vol, *args, **kwargs):
652659
super().__init__(*args, **kwargs)
653660

@@ -705,6 +712,13 @@ class EulerMaruyama(distribution.Continuous):
705712
parameters of the SDE, passed as ``*args`` to ``sde_fn``
706713
"""
707714

715+
def __new__(cls, *args, **kwargs):
716+
raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
717+
718+
@classmethod
719+
def dist(cls, *args, **kwargs):
720+
raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
721+
708722
def __init__(self, dt, sde_fn, sde_pars, *args, **kwds):
709723
super().__init__(*args, **kwds)
710724
self.dt = dt = at.as_tensor_variable(dt)
@@ -757,6 +771,13 @@ class MvGaussianRandomWalk(distribution.Continuous):
757771
758772
"""
759773

774+
def __new__(cls, *args, **kwargs):
775+
raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
776+
777+
@classmethod
778+
def dist(cls, *args, **kwargs):
779+
raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
780+
760781
def __init__(
761782
self, mu=0.0, cov=None, tau=None, chol=None, lower=True, init=None, *args, **kwargs
762783
):
@@ -879,6 +900,13 @@ class MvStudentTRandomWalk(MvGaussianRandomWalk):
879900
distribution for initial value (Defaults to Flat())
880901
"""
881902

903+
def __new__(cls, *args, **kwargs):
904+
raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
905+
906+
@classmethod
907+
def dist(cls, *args, **kwargs):
908+
raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
909+
882910
def __init__(self, nu, *args, **kwargs):
883911
super().__init__(*args, **kwargs)
884912
self.nu = at.as_tensor_variable(nu)

pymc/tests/test_distributions_timeseries.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def test_moment(self, size, expected):
435435
assert_moment_is_expected(model, expected, check_finite_logp=False)
436436

437437

438-
@pytest.mark.xfail(reason="Timeseries not refactored")
438+
@pytest.mark.xfail(reason="Timeseries not refactored", raises=NotImplementedError)
439439
def test_GARCH11():
440440
# test data ~ N(0, 1)
441441
data = np.array(
@@ -496,7 +496,7 @@ def _gen_sde_path(sde, pars, dt, n, x0):
496496
return np.array(xs)
497497

498498

499-
@pytest.mark.xfail(reason="Timeseries not refactored")
499+
@pytest.mark.xfail(reason="Timeseries not refactored", raises=NotImplementedError)
500500
def test_linear():
501501
lam = -0.78
502502
sig2 = 5e-3

0 commit comments

Comments
 (0)