Skip to content

Commit e15eb73

Browse files
committed
Add test
Implement moment and fix some shape bugs.
1 parent 2954ca1 commit e15eb73

File tree

2 files changed

+130
-55
lines changed

2 files changed

+130
-55
lines changed

pymc/distributions/timeseries.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
from aeppl.abstract import _get_measurable_outputs
2323
from aeppl.logprob import _logprob
24-
from aesara import scan
2524
from aesara.graph import FunctionGraph, rewrite_graph
2625
from aesara.graph.basic import Node, clone_replace
2726
from aesara.raise_op import Assert
@@ -230,7 +229,7 @@ def random_walk_moment(op, rv, init_dist, innovation_dist, steps):
230229

231230
@_logprob.register(RandomWalkRV)
232231
def random_walk_logp(op, values, *inputs, **kwargs):
233-
# ALthough Aeppl can derive the logprob of random walks, it does not collapse
232+
# Although Aeppl can derive the logprob of random walks, it does not collapse
234233
# what PyMC considers the core dimension of steps. We do it manually here.
235234
(value,) = values
236235
# Recreate RV and obtain inner graph
@@ -681,15 +680,15 @@ def dist(cls, omega, alpha_1, beta_1, initial_vol, *, steps=None, **kwargs):
681680

682681
return super().dist([omega, alpha_1, beta_1, initial_vol, init_dist, steps], **kwargs)
683682

684-
685683
@classmethod
686684
def rv_op(cls, omega, alpha_1, beta_1, initial_vol, init_dist, steps, size=None):
687685
if size is not None:
688686
batch_size = size
689687
else:
690688
# In this case the size of the init_dist depends on the parameters shape
691-
batch_size = at.broadcast_shape(omega, alpha_1, beta_1, init_dist)
689+
batch_size = at.broadcast_shape(omega, alpha_1, beta_1, initial_vol)
692690
init_dist = change_dist_size(init_dist, batch_size)
691+
# initial_vol = initial_vol * at.ones(batch_size)
693692

694693
# Create OpFromGraph representing random draws form AR process
695694
# Variables with underscore suffix are dummy inputs into the OpFromGraph
@@ -712,13 +711,16 @@ def step(*args):
712711

713712
(y_t, _), innov_updates_ = aesara.scan(
714713
fn=step,
715-
outputs_info=[init_, initial_vol_],
714+
outputs_info=[init_, initial_vol_ * at.ones(batch_size)],
716715
non_sequences=[omega_, alpha_1_, beta_1_, noise_rng],
717716
n_steps=steps_,
718717
strict=True,
719718
)
720719
(noise_next_rng,) = tuple(innov_updates_.values())
721-
garch11_ = at.concatenate([at.atleast_1d(init_), y_t.T], axis=-1)
720+
721+
garch11_ = at.concatenate([init_[None, ...], y_t], axis=0).dimshuffle(
722+
tuple(range(1, y_t.ndim)) + (0,)
723+
)
722724

723725
garch11_op = GARCH11RV(
724726
inputs=[omega_, alpha_1_, beta_1_, initial_vol_, init_, steps_],
@@ -748,22 +750,30 @@ def garch11_logp(
748750
op, values, omega, alpha_1, beta_1, initial_vol, init_dist, steps, noise_rng, **kwargs
749751
):
750752
(value,) = values
753+
value_dimswapped = value.dimshuffle((value.ndim - 1,) + tuple(range(0, value.ndim - 1)))
754+
initial_vol = initial_vol * at.ones_like(value_dimswapped[0])
751755

752756
def volatility_update(x, vol, w, a, b):
753757
return at.sqrt(w + a * at.square(x) + b * at.square(vol))
754758

755-
vol, _ = scan(
759+
vol, _ = aesara.scan(
756760
fn=volatility_update,
757-
sequences=[value[:-1]],
761+
sequences=[value_dimswapped[:-1]],
758762
outputs_info=[initial_vol],
759763
non_sequences=[omega, alpha_1, beta_1],
760764
)
761765
sigma_t = at.concatenate([[initial_vol], vol])
762766
# Compute and collapse logp across time dimension
763-
innov_logp = at.sum(logp(Normal.dist(0, sigma_t), value), axis=-1)
767+
innov_logp = at.sum(logp(Normal.dist(0, sigma_t), value_dimswapped), axis=-1)
764768
return innov_logp
765769

766770

771+
@_moment.register(GARCH11RV)
772+
def garch11_moment(op, rv, omega, alpha_1, beta_1, initial_vol, init_dist, steps, noise_rng):
773+
# GARCH(1,1) mean is zero
774+
return at.zeros_like(rv)
775+
776+
767777
class EulerMaruyama(distribution.Continuous):
768778
r"""
769779
Stochastic differential equation discretized with the Euler-Maruyama method.

pymc/tests/distributions/test_timeseries.py

Lines changed: 111 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -481,55 +481,120 @@ def test_change_dist_size(self):
481481
assert new_dist.eval().shape == (4, 3, 10)
482482

483483

484-
def test_GARCH11():
485-
# test data ~ N(0, 1)
486-
data = np.array(
484+
class TestGARCH11:
485+
def test_logp(self):
486+
# test data ~ N(0, 1)
487+
data = np.array(
488+
[
489+
-1.35078362,
490+
-0.81254164,
491+
0.28918551,
492+
-2.87043544,
493+
-0.94353337,
494+
0.83660719,
495+
-0.23336562,
496+
-0.58586298,
497+
-1.36856736,
498+
-1.60832975,
499+
-1.31403141,
500+
0.05446936,
501+
-0.97213128,
502+
-0.18928725,
503+
1.62011258,
504+
-0.95978616,
505+
-2.06536047,
506+
0.6556103,
507+
-0.27816645,
508+
-1.26413397,
509+
]
510+
)
511+
omega = 0.6
512+
alpha_1 = 0.4
513+
beta_1 = 0.5
514+
initial_vol = np.float64(0.9)
515+
vol = np.empty_like(data)
516+
vol[0] = initial_vol
517+
for i in range(len(data) - 1):
518+
vol[i + 1] = np.sqrt(omega + beta_1 * vol[i] ** 2 + alpha_1 * data[i] ** 2)
519+
520+
with Model() as t:
521+
y = GARCH11(
522+
"y",
523+
omega=omega,
524+
alpha_1=alpha_1,
525+
beta_1=beta_1,
526+
initial_vol=initial_vol,
527+
shape=data.shape,
528+
)
529+
z = Normal("z", mu=0, sigma=vol, shape=data.shape)
530+
garch_like = t.compile_logp(y)({"y": data})
531+
reg_like = t.compile_logp(z)({"z": data})
532+
decimal = select_by_precision(float64=7, float32=4)
533+
np.testing.assert_allclose(garch_like, reg_like, 10 ** (-decimal))
534+
535+
@pytest.mark.parametrize(
536+
"arg_name",
537+
["omega", "alpha_1", "beta_1", "initial_vol"],
538+
)
539+
def test_batched_size(self, arg_name):
540+
steps, batch_size = 100, 5
541+
param_val = np.square(np.random.randn(batch_size))
542+
init_kwargs = dict(
543+
omega=1.25,
544+
alpha_1=0.5,
545+
beta_1=0.45,
546+
initial_vol=2.5,
547+
)
548+
kwargs0 = init_kwargs.copy()
549+
kwargs0[arg_name] = init_kwargs[arg_name] * param_val
550+
with Model() as t0:
551+
y = GARCH11("y", shape=(batch_size, steps), **kwargs0)
552+
553+
y_eval = draw(y, draws=2)
554+
assert y_eval[0].shape == (batch_size, steps)
555+
assert not np.any(np.isclose(y_eval[0], y_eval[1]))
556+
557+
kwargs1 = init_kwargs.copy()
558+
with Model() as t1:
559+
for i in range(batch_size):
560+
kwargs1[arg_name] = init_kwargs[arg_name] * param_val[i]
561+
GARCH11(f"y_{i}", shape=steps, **kwargs1)
562+
563+
np.testing.assert_allclose(
564+
t0.compile_logp()(t0.initial_point()),
565+
t1.compile_logp()(t1.initial_point()),
566+
)
567+
568+
@pytest.mark.parametrize(
569+
"size, expected",
487570
[
488-
-1.35078362,
489-
-0.81254164,
490-
0.28918551,
491-
-2.87043544,
492-
-0.94353337,
493-
0.83660719,
494-
-0.23336562,
495-
-0.58586298,
496-
-1.36856736,
497-
-1.60832975,
498-
-1.31403141,
499-
0.05446936,
500-
-0.97213128,
501-
-0.18928725,
502-
1.62011258,
503-
-0.95978616,
504-
-2.06536047,
505-
0.6556103,
506-
-0.27816645,
507-
-1.26413397,
508-
]
571+
(None, np.zeros((2, 8))),
572+
((5, 2), np.zeros((5, 2, 8))),
573+
],
509574
)
510-
omega = 0.6
511-
alpha_1 = 0.4
512-
beta_1 = 0.5
513-
initial_vol = np.float64(0.9)
514-
vol = np.empty_like(data)
515-
vol[0] = initial_vol
516-
for i in range(len(data) - 1):
517-
vol[i + 1] = np.sqrt(omega + beta_1 * vol[i] ** 2 + alpha_1 * data[i] ** 2)
518-
519-
with Model() as t:
520-
y = GARCH11(
521-
"y",
522-
omega=omega,
523-
alpha_1=alpha_1,
524-
beta_1=beta_1,
525-
initial_vol=initial_vol,
526-
shape=data.shape,
575+
def test_moment(self, size, expected):
576+
with Model() as model:
577+
GARCH11(
578+
"x",
579+
omega=1.25,
580+
alpha_1=0.5,
581+
beta_1=0.45,
582+
initial_vol=np.ones(2),
583+
steps=7,
584+
size=size,
585+
)
586+
assert_moment_is_expected(model, expected, check_finite_logp=False)
587+
588+
def test_change_dist_size(self):
589+
base_dist = pm.GARCH11.dist(
590+
omega=1.25, alpha_1=0.5, beta_1=0.45, initial_vol=1.0, shape=(3, 10)
527591
)
528-
z = Normal("z", mu=0, sigma=vol, shape=data.shape)
529-
garch_like = t.compile_logp(y)({"y": data})
530-
reg_like = t.compile_logp(z)({"z": data})
531-
decimal = select_by_precision(float64=7, float32=4)
532-
np.testing.assert_allclose(garch_like, reg_like, 10 ** (-decimal))
592+
593+
new_dist = change_dist_size(base_dist, (4,))
594+
assert new_dist.eval().shape == (4, 10)
595+
596+
new_dist = change_dist_size(base_dist, (4,), expand=True)
597+
assert new_dist.eval().shape == (4, 3, 10)
533598

534599

535600
def _gen_sde_path(sde, pars, dt, n, x0):

0 commit comments

Comments
 (0)