Skip to content

Commit dbeb801

Browse files
larryshamalama authored and
lucianopaz committed
Use Aeppl for implementing GaussianRandomWalk
Co-authored-by: Ricardo Vieira <[email protected]> Co-authored-by: lucianopaz <[email protected]>
1 parent 906fcdc commit dbeb801

File tree

4 files changed

+192
-197
lines changed

4 files changed

+192
-197
lines changed

pymc/aesaraf.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,12 @@
3333
import scipy.sparse as sps
3434

3535
from aeppl.abstract import MeasurableVariable
36-
from aeppl.logprob import CheckParameterValue
36+
from aeppl.logprob import CheckParameterValue, _logprob, logprob
37+
from aeppl.tensor import MeasurableJoin
3738
from aesara import config, scalar
3839
from aesara.compile.mode import Mode, get_mode
3940
from aesara.gradient import grad
40-
from aesara.graph import local_optimizer
41+
from aesara.graph import local_optimizer, optimize_graph
4142
from aesara.graph.basic import (
4243
Apply,
4344
Constant,
@@ -52,6 +53,7 @@
5253
from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream
5354
from aesara.scalar.basic import Cast
5455
from aesara.tensor.basic import _as_tensor_variable
56+
from aesara.tensor.basic_opt import topo_constant_folding
5557
from aesara.tensor.elemwise import Elemwise
5658
from aesara.tensor.random.op import RandomVariable
5759
from aesara.tensor.random.var import (
@@ -875,6 +877,52 @@ def largest_common_dtype(tensors):
875877
return np.stack([np.ones((), dtype=dtype) for dtype in dtypes]).dtype
876878

877879

880+
@_logprob.register(MeasurableJoin)
def logprob_join_constant_shapes(op, values, axis, *base_vars, **kwargs):
    """Compute the log-likelihood graph for a `Join`.

    This overrides the implementation in Aeppl, to constant fold the shapes
    of the base vars so that RandomVariables do not show up in the logp graph.

    Parameters
    ----------
    op
        The `MeasurableJoin` op this implementation is registered for (not
        used in the body).
    values
        Sequence containing exactly one value variable for the joined output.
    axis
        Axis along which the base variables were joined.
    *base_vars
        The variables that were joined together.
    **kwargs
        Accepted for dispatch compatibility; not used in the body.

    Returns
    -------
    The logps of the individual base variables, concatenated along ``axis``
    shifted left by the number of support dimensions of the base variables.

    Raises
    ------
    ValueError
        If the logps of the base variables do not all have the same number of
        dimensions.
    """
    (value,) = values

    base_var_shapes = at.stack([base_var.shape[axis] for base_var in base_vars])

    # Constant fold the split sizes so that they no longer reference the
    # base variables themselves.
    base_var_shapes = optimize_graph(
        base_var_shapes,
        custom_opt=topo_constant_folding,
    )

    split_values = at.split(
        value,
        splits_size=[base_var_shape for base_var_shape in base_var_shapes],
        n_splits=len(base_vars),
        axis=axis,
    )

    logps = [
        logprob(base_var, split_value)
        for base_var, split_value in zip(base_vars, split_values)
    ]

    if len({logp.ndim for logp in logps}) != 1:
        raise ValueError(
            "Joined logps have different number of dimensions, this can happen when "
            "joining univariate and multivariate distributions",
        )

    # Support dimensions are collapsed in the logp, so the join axis of the
    # logp is offset by that many dimensions relative to the value.
    base_vars_ndim_supp = split_values[0].ndim - logps[0].ndim

    # Reuse the logp graphs built above instead of deriving each one a second time.
    join_logprob = at.concatenate(
        [at.atleast_1d(logp) for logp in logps],
        axis=axis - base_vars_ndim_supp,
    )

    return join_logprob
924+
925+
878926
@local_optimizer(tracks=[CheckParameterValue])
879927
def local_remove_check_parameter(fgraph, node):
880928
"""Rewrite that removes Aeppl's CheckParameterValue

pymc/distributions/timeseries.py

Lines changed: 65 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import warnings
1515

16-
from typing import Any, Optional, Tuple, Union
16+
from typing import Any, Optional, Union
1717

1818
import aesara
1919
import aesara.tensor as at
@@ -28,22 +28,15 @@
2828
from aesara.raise_op import Assert
2929
from aesara.tensor import TensorVariable
3030
from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding
31+
from aesara.tensor.extra_ops import CumOp
3132
from aesara.tensor.random.op import RandomVariable
32-
from aesara.tensor.random.utils import normalize_size_param
3333

3434
from pymc.aesaraf import change_rv_size, convert_observed_data, floatX, intX
3535
from pymc.distributions import distribution, multivariate
3636
from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
37-
from pymc.distributions.dist_math import check_parameters
3837
from pymc.distributions.distribution import SymbolicDistribution, _moment, moment
3938
from pymc.distributions.logprob import ignore_logprob, logp
40-
from pymc.distributions.shape_utils import (
41-
Dims,
42-
Shape,
43-
convert_dims,
44-
rv_size_is_none,
45-
to_tuple,
46-
)
39+
from pymc.distributions.shape_utils import Dims, Shape, convert_dims, to_tuple
4740
from pymc.model import modelcontext
4841
from pymc.util import check_dist_not_registered
4942

@@ -114,110 +107,8 @@ def get_steps(
114107
return inferred_steps
115108

116109

117-
class GaussianRandomWalkRV(RandomVariable):
118-
"""
119-
GaussianRandomWalk Random Variable
120-
"""
121-
122-
name = "GaussianRandomWalk"
123-
ndim_supp = 1
124-
ndims_params = [0, 0, 0, 0]
125-
dtype = "floatX"
126-
_print_name = ("GaussianRandomWalk", "\\operatorname{GaussianRandomWalk}")
127-
128-
def make_node(self, rng, size, dtype, mu, sigma, init_dist, steps):
129-
steps = at.as_tensor_variable(steps)
130-
if not steps.ndim == 0 or not steps.dtype.startswith("int"):
131-
raise ValueError("steps must be an integer scalar (ndim=0).")
132-
133-
mu = at.as_tensor_variable(mu)
134-
sigma = at.as_tensor_variable(sigma)
135-
init_dist = at.as_tensor_variable(init_dist)
136-
137-
# Resize init distribution
138-
size = normalize_size_param(size)
139-
# If not explicit, size is determined by the shapes of mu, sigma, and init
140-
init_dist_size = (
141-
size if not rv_size_is_none(size) else at.broadcast_shape(mu, sigma, init_dist)
142-
)
143-
init_dist = change_rv_size(init_dist, init_dist_size)
144-
145-
return super().make_node(rng, size, dtype, mu, sigma, init_dist, steps)
146-
147-
def _supp_shape_from_params(self, dist_params, reop_param_idx=0, param_shapes=None):
148-
steps = dist_params[3]
149-
150-
return (steps + 1,)
151-
152-
@classmethod
153-
def rng_fn(
154-
cls,
155-
rng: np.random.RandomState,
156-
mu: Union[np.ndarray, float],
157-
sigma: Union[np.ndarray, float],
158-
init_dist: Union[np.ndarray, float],
159-
steps: int,
160-
size: Tuple[int],
161-
) -> np.ndarray:
162-
"""Gaussian Random Walk generator.
163-
164-
The init value is drawn from the Normal distribution with the same sigma as the
165-
innovations.
166-
167-
Notes
168-
-----
169-
Currently does not support custom init distribution
170-
171-
Parameters
172-
----------
173-
rng: np.random.RandomState
174-
Numpy random number generator
175-
mu: array_like of float
176-
Random walk mean
177-
sigma: array_like of float
178-
Standard deviation of innovation (sigma > 0)
179-
init_dist: array_like of float
180-
Initialization value for GaussianRandomWalk
181-
steps: int
182-
Length of random walk, must be greater than 1. Returned array will be of size+1 to
183-
account as first value is initial value
184-
size: tuple of int
185-
The number of Random Walk time series generated
186-
187-
Returns
188-
-------
189-
ndarray
190-
"""
191-
192-
if steps < 1:
193-
raise ValueError("Steps must be greater than 0")
194-
195-
# If size is None then the returned series should be (*implied_dims, 1+steps)
196-
if size is None:
197-
# broadcast parameters with each other to find implied dims
198-
bcast_shape = np.broadcast_shapes(
199-
np.asarray(mu).shape,
200-
np.asarray(sigma).shape,
201-
np.asarray(init_dist).shape,
202-
)
203-
dist_shape = (*bcast_shape, int(steps))
204-
205-
# If size is None then the returned series should be (*size, 1+steps)
206-
else:
207-
dist_shape = (*size, int(steps))
208-
209-
# Add one dimension to the right, so that mu and sigma broadcast safely along
210-
# the steps dimension
211-
innovations = rng.normal(loc=mu[..., None], scale=sigma[..., None], size=dist_shape)
212-
grw = np.concatenate([init_dist[..., None], innovations], axis=-1)
213-
return np.cumsum(grw, axis=-1)
214-
215-
216-
gaussianrandomwalk = GaussianRandomWalkRV()
217-
218-
219-
class GaussianRandomWalk(distribution.Continuous):
220-
r"""Random Walk with Normal innovations.
110+
class GaussianRandomWalk(SymbolicDistribution):
111+
r"""Random Walk with Normal innovations
221112
222113
Parameters
223114
----------
@@ -236,8 +127,6 @@ class GaussianRandomWalk(distribution.Continuous):
236127
provided.
237128
"""
238129

239-
rv_op = gaussianrandomwalk
240-
241130
def __new__(cls, *args, steps=None, **kwargs):
242131
steps = get_steps(
243132
steps=steps,
@@ -269,6 +158,8 @@ def dist(cls, mu=0.0, sigma=1.0, *, init_dist=None, steps=None, **kwargs) -> at.
269158
FutureWarning,
270159
)
271160
init_dist = kwargs.pop("init")
161+
if not steps.ndim == 0:
162+
raise ValueError("steps must be an integer scalar (ndim=0).")
272163

273164
# If no scalar distribution is passed then initialize with a Normal of same mu and sigma
274165
if init_dist is None:
@@ -293,39 +184,67 @@ def dist(cls, mu=0.0, sigma=1.0, *, init_dist=None, steps=None, **kwargs) -> at.
293184

294185
return super().dist([mu, sigma, init_dist, steps], **kwargs)
295186

296-
def moment(rv, size, mu, sigma, init_dist, steps):
297-
grw_moment = at.zeros_like(rv)
298-
grw_moment = at.set_subtensor(grw_moment[..., 0], moment(init_dist))
299-
# Add one dimension to the right, so that mu broadcasts safely along the steps
300-
# dimension
301-
grw_moment = at.set_subtensor(grw_moment[..., 1:], mu[..., None])
302-
return at.cumsum(grw_moment, axis=-1)
303-
304-
def logp(
305-
value: at.Variable,
306-
mu: at.Variable,
307-
sigma: at.Variable,
308-
init_dist: at.Variable,
309-
steps: at.Variable,
310-
) -> at.TensorVariable:
311-
"""Calculate log-probability of Gaussian Random Walk distribution at specified value."""
312-
313-
# Calculate initialization logp
314-
init_logp = logp(init_dist, value[..., 0])
315-
316-
# Make time series stationary around the mean value
317-
stationary_series = value[..., 1:] - value[..., :-1]
318-
# Add one dimension to the right, so that mu and sigma broadcast safely along
319-
# the steps dimension
320-
series_logp = logp(Normal.dist(mu[..., None], sigma[..., None]), stationary_series)
321-
322-
return check_parameters(
323-
init_logp + series_logp.sum(axis=-1),
324-
steps > 0,
325-
msg="steps > 0",
187+
@classmethod
def ndim_supp(cls, *args):
    """Return the number of support dimensions.

    Always returns 1, whatever positional arguments are passed.
    """
    return 1
190+
191+
@classmethod
def rv_op(cls, mu, sigma, init_dist, steps, size=None):
    """Build the symbolic graph of a Gaussian random walk.

    The walk is the cumulative sum, along the last axis, of the (resized)
    ``init_dist`` followed by ``steps`` Normal innovations with parameters
    ``mu`` and ``sigma``.

    Parameters
    ----------
    mu, sigma
        Parameters of the Normal innovations; a trailing axis is added to
        each so that they broadcast along the steps dimension.
    init_dist
        Variable providing the initial value of the walk.
    steps
        Number of innovations appended after the initial value.
    size
        Explicit size of the initial distribution. When ``None`` it is taken
        from the broadcast shape of ``mu``, ``sigma`` and ``init_dist``.

    Returns
    -------
    The cumulative-sum variable, with the inputs, the innovation variable and
    an ``is_grw`` marker stored on its ``tag``.
    """
    # An explicit size (from the user or from .dist()) takes precedence;
    # otherwise the size is inferred from the parameters.
    init_size = size if size is not None else at.broadcast_shape(mu, sigma, init_dist)

    # TODO: extend for multivariate init
    init_dist = change_rv_size(init_dist, init_size)
    innovation_dist = Normal.dist(
        mu[..., None],
        sigma[..., None],
        size=(*init_size, steps),
    )

    # Put the initial value in front of the innovations and accumulate.
    increments = at.concatenate([init_dist[..., None], innovation_dist], axis=-1)
    rv_out = at.cumsum(increments, axis=-1)

    tag = rv_out.tag
    tag.mu = mu
    tag.sigma = sigma
    tag.init_dist = init_dist
    tag.innovation_dist = innovation_dist
    tag.steps = steps
    tag.is_grw = True  # for moment dispatching

    return rv_out
218+
219+
@classmethod
def change_size(cls, rv, new_size, expand=False):
    """Rebuild ``rv`` through ``rv_op`` with a different size.

    Parameters
    ----------
    rv
        Variable whose ``tag`` carries ``mu``, ``sigma``, ``init_dist`` and
        ``steps``.
    new_size
        The size to use for the rebuilt variable.
    expand
        When true, ``new_size`` is concatenated in front of ``rv.shape[:-1]``
        (the current shape without its last axis) instead of replacing it.

    Returns
    -------
    The variable returned by ``cls.rv_op`` for the stored parameters and the
    resulting size.
    """
    target_size = at.concatenate([new_size, rv.shape[:-1]]) if expand else new_size

    stored = rv.tag
    return cls.rv_op(
        mu=stored.mu,
        sigma=stored.sigma,
        init_dist=stored.init_dist,
        steps=stored.steps,
        size=target_size,
    )
327231

328232

233+
@_moment.register(CumOp)
def moment_grw(op, rv, dist_params):
    """Return the moment of a GaussianRandomWalk expressed as a `CumOp`.

    This moment dispatch is currently only applicable for a GaussianRandomWalk:
    the variable must carry ``tag.is_grw``, otherwise `NotImplementedError`
    is raised. The moment is the cumulative sum, along the last axis, of the
    moment of ``tag.init_dist`` followed by the moment of
    ``tag.innovation_dist``.

    TODO: Encapsulate GRW graph in an OpFromGraph so that we can dispatch
    the moment directly on it
    """
    tag = rv.tag
    if getattr(tag, "is_grw", False):
        init_dist, innovation_dist = tag.init_dist, tag.innovation_dist
        pieces = [moment(init_dist)[..., None], moment(innovation_dist)]
        return at.cumsum(at.concatenate(pieces, axis=-1), axis=-1)

    raise NotImplementedError("Moment not implemented for `CumOp`")
246+
247+
329248
class AutoRegressiveRV(OpFromGraph):
330249
"""A placeholder used to specify a log-likelihood for an AR sub-graph."""
331250

pymc/tests/test_distributions_moments.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
Exponential,
2929
Flat,
3030
Gamma,
31+
GaussianRandomWalk,
3132
Geometric,
3233
Gumbel,
3334
HalfCauchy,
@@ -1506,3 +1507,43 @@ def test_dirichlet_multinomial_moment(a, n, size, expected):
15061507
with Model() as model:
15071508
DirichletMultinomial("x", n=n, a=a, size=size)
15081509
assert_moment_is_expected(model, expected)
1510+
1511+
1512+
# 0, 2, 4, ..., 20 (11 entries) -- expected values for test_gaussianrandomwalk below
even_numbers = np.arange(0, 22, 2)


@pytest.mark.parametrize(
    "mu, sigma, init_dist, steps, size, expected",
    [
        # Constant parameters, size inferred
        (0, 3, StudentT.dist(5), 10, None, np.zeros(11)),
        # Scalar random parameters, size inferred
        (Normal.dist(2, 3), Gamma.dist(1, 1), StudentT.dist(5), 10, None, even_numbers),
        # Batched mu, size inferred from the parameters
        (
            Normal.dist(2, 3, size=(3, 5)),
            Gamma.dist(1, 1),
            StudentT.dist(5),
            10,
            None,
            np.broadcast_to(even_numbers, shape=(3, 5, 11)),
        ),
        # Broadcastable mu and sigma together with an explicit size
        (
            Normal.dist(2, 3, size=(3, 1)),
            Gamma.dist(1, 1, size=(1, 5)),
            StudentT.dist(5),
            10,
            (3, 5),
            np.broadcast_to(even_numbers, shape=(3, 5, 11)),
        ),
        # Scalar random parameters with an explicit size
        (
            Normal.dist(2, 3),
            Gamma.dist(1, 1),
            StudentT.dist(5),
            10,
            (3, 5),
            np.broadcast_to(even_numbers, shape=(3, 5, 11)),
        ),
    ],
)
def test_gaussianrandomwalk(mu, sigma, init_dist, steps, size, expected):
    """Check the moment of a GaussianRandomWalk against ``expected``."""
    grw_kwargs = dict(mu=mu, sigma=sigma, init_dist=init_dist, steps=steps, size=size)
    with Model() as model:
        GaussianRandomWalk("x", **grw_kwargs)
    assert_moment_is_expected(model, expected)

0 commit comments

Comments
 (0)