
Commit 614bb06

Obtain step information from dims and observed
1 parent 135fb37 commit 614bb06

File tree: 2 files changed, +154 / -35 lines


pymc/distributions/timeseries.py

Lines changed: 78 additions & 34 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 import warnings

-from typing import Optional, Tuple, Union
+from typing import Any, Optional, Tuple, Union

 import aesara
 import aesara.tensor as at
@@ -31,13 +31,20 @@
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.random.utils import normalize_size_param

-from pymc.aesaraf import change_rv_size, floatX, intX
+from pymc.aesaraf import change_rv_size, convert_observed_data, floatX, intX
 from pymc.distributions import distribution, multivariate
 from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import SymbolicDistribution, _moment, moment
 from pymc.distributions.logprob import ignore_logprob, logp
-from pymc.distributions.shape_utils import Shape, rv_size_is_none, to_tuple
+from pymc.distributions.shape_utils import (
+    Dims,
+    Shape,
+    convert_dims,
+    rv_size_is_none,
+    to_tuple,
+)
+from pymc.model import modelcontext
 from pymc.util import check_dist_not_registered

 __all__ = [
@@ -50,51 +57,61 @@
 ]


-def get_steps_from_shape(
+def get_steps(
     steps: Optional[Union[int, np.ndarray, TensorVariable]],
-    shape: Optional[Shape],
+    *,
+    shape: Optional[Shape] = None,
+    dims: Optional[Dims] = None,
+    observed: Optional[Any] = None,
     step_shape_offset: int = 0,
 ):
-    """Extract number of steps from shape information
+    """Extract number of steps from shape / dims / observed information

     Parameters
     ----------
     steps:
         User specified steps for timeseries distribution
     shape:
         User specified shape for timeseries distribution
+    dims:
+        User specified dims for timeseries distribution
+    observed:
+        User specified observed data from timeseries distribution
     step_shape_offset:
         Difference between last shape dimension and number of steps in timeseries
         distribution, defaults to 0

-    Raises
-    ------
-    ValueError
-        If neither shape nor steps are provided
-
     Returns
     -------
     steps
         Steps, if specified directly by user, or inferred from the last dimension of
-        shape. When both steps and shape are provided, a symbolic Assert is added
-        to make sure they are consistent.
+        shape / dims / observed. When two sources of step information are provided,
+        a symbolic Assert is added to ensure they are consistent.
     """
-    steps_from_shape = None
+    inferred_steps = None
     if shape is not None:
         shape = to_tuple(shape)
         if shape[-1] is not ...:
-            steps_from_shape = shape[-1] - step_shape_offset
-    if steps is None:
-        if steps_from_shape is not None:
-            steps = steps_from_shape
-        else:
-            raise ValueError("Must specify steps or shape parameter")
-    elif steps_from_shape is not None:
-        # Assert that steps and shape are consistent
-        steps = Assert(msg="Steps do not match last shape dimension")(
-            steps, at.eq(steps, steps_from_shape)
+            inferred_steps = shape[-1] - step_shape_offset
+
+    if inferred_steps is None and dims is not None:
+        dims = convert_dims(dims)
+        if dims[-1] is not ...:
+            model = modelcontext(None)
+            inferred_steps = model.dim_lengths[dims[-1]] - step_shape_offset
+
+    if inferred_steps is None and observed is not None:
+        observed = convert_observed_data(observed)
+        inferred_steps = observed.shape[-1] - step_shape_offset
+
+    if inferred_steps is None:
+        inferred_steps = steps
+    # If there are two sources of information for the steps, assert they are consistent
+    elif steps is not None:
+        inferred_steps = Assert(msg="Steps do not match last shape dimension")(
+            inferred_steps, at.eq(inferred_steps, steps)
         )
-    return steps
+    return inferred_steps


 class GaussianRandomWalkRV(RandomVariable):
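
For reference, a minimal usage sketch of the new get_steps helper, based on the behaviour shown in this diff and its tests (the coordinate name "year" below is illustrative):

import numpy as np
import pymc as pm
from aesara.tensor import TensorVariable
from pymc.distributions.timeseries import get_steps

# From an explicit shape: last dimension minus the offset -> 5 - 1 = 4
print(get_steps(steps=None, shape=(10, 5), step_shape_offset=1))

# From observed data: 12 observations with offset 1 -> 11
print(get_steps(steps=None, observed=np.zeros(12), step_shape_offset=1))

# From dims: needs a model context that knows the dim lengths
with pm.Model(coords={"year": range(20)}):
    steps = get_steps(steps=None, dims=("year",), step_shape_offset=1)
    # The result may be symbolic, so evaluate it the way the new tests do
    print(steps.eval() if isinstance(steps, TensorVariable) else steps)  # 19
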
@@ -212,26 +229,38 @@ class GaussianRandomWalk(distribution.Continuous):

     .. warning:: init will be cloned, rendering them independent of the ones passed as input.

-    steps : int
-        Number of steps in Gaussian Random Walks (steps > 0).
+    steps : int, optional
+        Number of steps in Gaussian Random Walk (steps > 0). Only needed if size is
+        used to specify distribution
     """

     rv_op = gaussianrandomwalk

-    def __new__(cls, name, mu=0.0, sigma=1.0, init=None, steps=None, **kwargs):
-        if init is not None:
-            check_dist_not_registered(init)
-        return super().__new__(cls, name, mu, sigma, init, steps, **kwargs)
+    def __new__(cls, *args, steps=None, **kwargs):
+        steps = get_steps(
+            steps=steps,
+            shape=None,  # Shape will be checked in `cls.dist`
+            dims=kwargs.get("dims", None),
+            observed=kwargs.get("observed", None),
+            step_shape_offset=1,
+        )
+        return super().__new__(cls, *args, steps=steps, **kwargs)

     @classmethod
     def dist(
-        cls, mu=0.0, sigma=1.0, init=None, steps=None, size=None, **kwargs
+        cls, mu=0.0, sigma=1.0, *, init=None, steps=None, size=None, **kwargs
     ) -> at.TensorVariable:

         mu = at.as_tensor_variable(floatX(mu))
         sigma = at.as_tensor_variable(floatX(sigma))

-        steps = get_steps_from_shape(steps, kwargs.get("shape", None), step_shape_offset=1)
+        steps = get_steps(
+            steps=steps,
+            shape=kwargs.get("shape", None),
+            step_shape_offset=1,
+        )
+        if steps is None:
+            raise ValueError("Must specify steps or shape parameter")
         steps = at.as_tensor_variable(intX(steps))

         # If no scalar distribution is passed then initialize with a Normal of same mu and sigma
@@ -245,6 +274,7 @@ def dist(
             and init.owner.op.ndim_supp == 0
         ):
             raise TypeError("init must be a univariate distribution variable")
+        check_dist_not_registered(init)

         # Ignores logprob of init var because that's accounted for in the logp method
         init = ignore_logprob(init)
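
A short sketch of what this change enables for GaussianRandomWalk, mirroring the new tests in the second file; the coordinate name "year" and the variable names are illustrative. With dims or observed carrying the length information, steps no longer has to be passed explicitly:

import numpy as np
import pymc as pm

with pm.Model(coords={"year": range(20)}):
    # 20 coordinate values imply 19 steps: the walk includes its initial
    # value, hence step_shape_offset=1 in __new__ above.
    grw = pm.GaussianRandomWalk("grw", mu=0.0, sigma=1.0, dims="year")

with pm.Model():
    # 10 observations likewise imply 9 steps.
    grw_obs = pm.GaussianRandomWalk("grw_obs", observed=np.zeros(10))
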
@@ -340,6 +370,9 @@ class AR(SymbolicDistribution):
     ar_order: int, optional
         Order of the AR process. Inferred from length of the last dimension of rho, if
         possible. ar_order = rho.shape[-1] if constant else rho.shape[-1] - 1
+    steps : int, optional
+        Number of steps in AR process (steps > 0). Only needed if size is used to
+        specify distribution

     Notes
     -----
@@ -360,6 +393,15 @@ class AR(SymbolicDistribution):

     """

+    def __new__(cls, *args, steps=None, **kwargs):
+        steps = get_steps(
+            steps=steps,
+            shape=None,  # Shape will be checked in `cls.dist`
+            dims=kwargs.get("dims", None),
+            observed=kwargs.get("observed", None),
+        )
+        return super().__new__(cls, *args, steps=steps, **kwargs)
+
     @classmethod
     def dist(
         cls,
@@ -384,7 +426,9 @@ def dist(
             )
             init_dist = kwargs["init"]

-        steps = get_steps_from_shape(steps, kwargs.get("shape", None))
+        steps = get_steps(steps=steps, shape=kwargs.get("shape", None))
+        if steps is None:
+            raise ValueError("Must specify steps or shape parameter")
         steps = at.as_tensor_variable(intX(steps), ndim=0)

         if ar_order is None:
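
The same applies to AR: step information can now come from observed (or dims) instead of an explicit steps/shape argument. A hedged sketch, assuming the init_dist keyword referenced elsewhere in this file; the data and names are illustrative:

import numpy as np
import pymc as pm

with pm.Model():
    data = np.random.normal(size=100)
    # ar_order=1 is inferred from the length of rho; steps are inferred from
    # the 100 observations, so neither steps nor shape is needed here.
    ar = pm.AR("ar", rho=[0.3], sigma=1.0, init_dist=pm.Normal.dist(), observed=data)
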

pymc/tests/test_distributions_timeseries.py

Lines changed: 76 additions & 1 deletion
@@ -16,21 +16,84 @@
 import pytest
 import scipy.stats

+from aesara.tensor import TensorVariable
+
 import pymc as pm

 from pymc.aesaraf import floatX
 from pymc.distributions.continuous import Flat, HalfNormal, Normal
 from pymc.distributions.discrete import Constant
 from pymc.distributions.logprob import logp
 from pymc.distributions.multivariate import Dirichlet
-from pymc.distributions.timeseries import AR, GARCH11, EulerMaruyama, GaussianRandomWalk
+from pymc.distributions.timeseries import (
+    AR,
+    GARCH11,
+    EulerMaruyama,
+    GaussianRandomWalk,
+    get_steps,
+)
 from pymc.model import Model
 from pymc.sampling import draw, sample, sample_posterior_predictive
 from pymc.tests.helpers import select_by_precision
 from pymc.tests.test_distributions_moments import assert_moment_is_expected
 from pymc.tests.test_distributions_random import BaseTestDistributionRandom


+@pytest.mark.parametrize(
+    "steps, shape, step_shape_offset, expected_steps, consistent",
+    [
+        (10, None, 0, 10, True),
+        (10, None, 1, 10, True),
+        (None, (10,), 0, 10, True),
+        (None, (10,), 1, 9, True),
+        (None, (10, 5), 0, 5, True),
+        (None, (10, ...), 0, None, True),
+        (None, None, 0, None, True),
+        (10, (10,), 0, 10, True),
+        (10, (11,), 1, 10, True),
+        (10, (5, ...), 1, 10, True),
+        (10, (5, 5), 0, 5, False),
+        (10, (5, 10), 1, 9, False),
+    ],
+)
+@pytest.mark.parametrize("info_source", ("shape", "dims", "observed"))
+def test_get_steps(info_source, steps, shape, step_shape_offset, expected_steps, consistent):
+    if info_source == "shape":
+        inferred_steps = get_steps(steps=steps, shape=shape, step_shape_offset=step_shape_offset)
+
+    elif info_source == "dims":
+        if shape is None:
+            dims = None
+            coords = {}
+        else:
+            dims = tuple(str(i) if shape is not ... else ... for i, shape in enumerate(shape))
+            coords = {str(i): range(shape) for i, shape in enumerate(shape) if shape is not ...}
+        with Model(coords=coords):
+            inferred_steps = get_steps(steps=steps, dims=dims, step_shape_offset=step_shape_offset)
+
+    elif info_source == "observed":
+        if shape is None:
+            observed = None
+        else:
+            if ... in shape:
+                # There is no equivalent to implied dims in observed
+                return
+            observed = np.zeros(shape)
+        inferred_steps = get_steps(
+            steps=steps, observed=observed, step_shape_offset=step_shape_offset
+        )
+
+    if not isinstance(inferred_steps, TensorVariable):
+        assert inferred_steps == expected_steps
+    else:
+        if consistent:
+            assert inferred_steps.eval() == expected_steps
+        else:
+            assert inferred_steps.owner.inputs[0].eval() == expected_steps
+            with pytest.raises(AssertionError, match="Steps do not match"):
+                inferred_steps.eval()
+
+
 class TestGaussianRandomWalk:
     class TestGaussianRandomWalkRandom(BaseTestDistributionRandom):
         # Override default size for test class
@@ -127,6 +190,18 @@ def test_inconsistent_steps_and_shape(self):
         with pytest.raises(AssertionError, match="Steps do not match last shape dimension"):
             x = GaussianRandomWalk.dist(steps=12, shape=45)

+    def test_inferred_steps_from_dims(self):
+        with pm.Model(coords={"batch": range(5), "steps": range(20)}):
+            x = GaussianRandomWalk("x", dims=("batch", "steps"))
+        steps = x.owner.inputs[-1]
+        assert steps.eval() == 19
+
+    def test_inferred_steps_from_observed(self):
+        with pm.Model():
+            x = GaussianRandomWalk("x", observed=np.zeros(10))
+        steps = x.owner.inputs[-1]
+        assert steps.eval() == 9
+
     @pytest.mark.parametrize(
         "init",
         [
