Skip to content

Commit 69e967e

Browse files
committed
Update logp calculation and tests
1 parent 74081b2 commit 69e967e

File tree

2 files changed

+26
-12
lines changed

2 files changed

+26
-12
lines changed

pymc/distributions/timeseries.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,8 @@ def logp(
289289
"""
290290
Calculate log-probability of Gaussian Random Walk distribution at specified value.
291291
292+
293+
292294
Parameters
293295
----------
294296
x: numeric
@@ -307,17 +309,15 @@ def normal_logp(value, mu, sigma):
307309
)
308310
return logp
309311

310-
# Create logp calculation graph the inital time point
311-
init_logp = normal_logp(value[0] - init, mu, sigma)
312-
313-
# Create logp calculation graph for innovations
314-
stationary_vals = at.diff(value)
315-
innov_logp = normal_logp(stationary_vals, mu, sigma)
316-
317-
# Return both calculation logps in a vector
318-
total_logp = at.concatenate([init_logp, innov_logp])
312+
# Calculate initialization logp
313+
init_logp = normal_logp(at.expand_dims(value[0], 0), init, sigma)
319314

315+
# Make time series stationary around the mean value
316+
stationary_series = at.diff(value)
317+
series_logp = normal_logp(stationary_series, mu, sigma)
318+
total_logp = at.concatenate([init_logp, series_logp])
320319
total_logp = check_parameters(total_logp, sigma > 0, msg="sigma > 0")
320+
321321
return total_logp
322322

323323

pymc/tests/test_distributions_timeseries.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def test_grw_rv_op():
4444

4545

4646
def test_grw_log():
47-
vals = [1, 2]
48-
mu = 0
47+
vals = [0, 1, 2]
48+
mu = 1
4949
sd = 1
5050
init = 0
5151

@@ -58,7 +58,21 @@ def test_grw_log():
5858

5959
logp = pm.logp(grw, vals)
6060

61-
assert logp
61+
logp_vals = logp.eval()
62+
63+
# Calculate logp from scipy
64+
from scipy import stats
65+
66+
# Calculate logp in explicit loop for testing obviousness
67+
init_val = vals[0]
68+
init_logp = stats.norm(0, 1).logpdf(init_val)
69+
70+
logp_reference = [init_logp]
71+
for x_minus_one_val, x_val in zip(vals, vals[1:]):
72+
logp_point = stats.norm(x_minus_one_val + mu, sd).logpdf(x_val)
73+
logp_reference.append(logp_point)
74+
75+
np.testing.assert_almost_equal(logp_vals, logp_reference)
6276

6377

6478
@pytest.mark.xfail(reason="Timeseries not refactored")

0 commit comments

Comments
 (0)