Skip to content

Commit d14cb3a

Browse files
Apply suggestions from code review
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent f6f97e1 commit d14cb3a

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

pymc/distributions/timeseries.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ def rv_op(cls, omega, alpha_1, beta_1, initial_vol, init_dist, steps, size=None)
690690
init_dist = change_dist_size(init_dist, batch_size)
691691
# initial_vol = initial_vol * at.ones(batch_size)
692692

693-
# Create OpFromGraph representing random draws form AR process
693+
# Create OpFromGraph representing random draws from GARCH11 process
694694
# Variables with underscore suffix are dummy inputs into the OpFromGraph
695695
init_ = init_dist.type()
696696
initial_vol_ = initial_vol.type()
@@ -701,8 +701,7 @@ def rv_op(cls, omega, alpha_1, beta_1, initial_vol, init_dist, steps, size=None)
701701

702702
noise_rng = aesara.shared(np.random.default_rng())
703703

704-
def step(*args):
705-
prev_y, prev_sigma, omega, alpha_1, beta_1, rng = args
704+
def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):
706705
new_sigma = at.sqrt(
707706
omega + alpha_1 * at.square(prev_y) + beta_1 * at.square(prev_sigma)
708707
)
@@ -761,6 +760,7 @@ def volatility_update(x, vol, w, a, b):
761760
sequences=[value_dimswapped[:-1]],
762761
outputs_info=[initial_vol],
763762
non_sequences=[omega, alpha_1, beta_1],
763+
strict = True,
764764
)
765765
sigma_t = at.concatenate([[initial_vol], vol])
766766
# Compute and collapse logp across time dimension

pymc/tests/distributions/test_timeseries.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -533,10 +533,11 @@ def test_logp(self):
533533
np.testing.assert_allclose(garch_like, reg_like, 10 ** (-decimal))
534534

535535
@pytest.mark.parametrize(
536-
"arg_name",
536+
"batched_param",
537537
["omega", "alpha_1", "beta_1", "initial_vol"],
538538
)
539-
def test_batched_size(self, arg_name):
539+
@pytest.mark.parametrize("explicit_shape", (True, False))
540+
def test_batched_size(self, explicit_shape, batched_param):
540541
steps, batch_size = 100, 5
541542
param_val = np.square(np.random.randn(batch_size))
542543
init_kwargs = dict(

0 commit comments

Comments
 (0)