diff --git a/pymc/distributions/__init__.py b/pymc/distributions/__init__.py index 6679972f62..8680528682 100644 --- a/pymc/distributions/__init__.py +++ b/pymc/distributions/__init__.py @@ -102,7 +102,6 @@ from pymc.distributions.simulator import Simulator from pymc.distributions.timeseries import ( AR, - AR1, GARCH11, GaussianRandomWalk, MvGaussianRandomWalk, @@ -169,7 +168,6 @@ "WishartBartlett", "LKJCholeskyCov", "LKJCorr", - "AR1", "AR", "AsymmetricLaplace", "GaussianRandomWalk", diff --git a/pymc/distributions/censored.py b/pymc/distributions/censored.py index 8aaad5c106..a59965c857 100644 --- a/pymc/distributions/censored.py +++ b/pymc/distributions/censored.py @@ -76,6 +76,14 @@ def dist(cls, dist, lower, upper, **kwargs): check_dist_not_registered(dist) return super().dist([dist, lower, upper], **kwargs) + @classmethod + def num_rngs(cls, *args, **kwargs): + return 1 + + @classmethod + def ndim_supp(cls, *dist_params): + return 0 + @classmethod def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None): @@ -96,24 +104,12 @@ def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None): rv_out.tag.upper = upper if rngs is not None: - rv_out = cls.change_rngs(rv_out, rngs) + rv_out = cls._change_rngs(rv_out, rngs) return rv_out @classmethod - def ndim_supp(cls, *dist_params): - return 0 - - @classmethod - def change_size(cls, rv, new_size, expand=False): - dist = rv.tag.dist - lower = rv.tag.lower - upper = rv.tag.upper - new_dist = change_rv_size(dist, new_size, expand=expand) - return cls.rv_op(new_dist, lower, upper) - - @classmethod - def change_rngs(cls, rv, new_rngs): + def _change_rngs(cls, rv, new_rngs): (new_rng,) = new_rngs dist_node = rv.tag.dist.owner lower = rv.tag.lower @@ -123,8 +119,12 @@ def change_rngs(cls, rv, new_rngs): return cls.rv_op(new_dist, lower, upper) @classmethod - def graph_rvs(cls, rv): - return (rv.tag.dist,) + def change_size(cls, rv, new_size, expand=False): + dist = rv.tag.dist + lower = rv.tag.lower + upper = rv.tag.upper + new_dist = change_rv_size(dist, new_size, expand=expand) + return cls.rv_op(new_dist, lower, upper) @_moment.register(Clip) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index ac3541380b..138b3cd253 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -364,6 +364,40 @@ def dist( class SymbolicDistribution: + """Symbolic statistical distribution + + While traditional PyMC distributions are represented by a single RandomVariable + graph, Symbolic distributions correspond to a larger graph that contains one or + more RandomVariables and an arbitrary number of deterministic operations, which + represent their own kind of distribution. + + The graphs returned by symbolic distributions can be evaluated directly to + obtain valid draws and can further be parsed by Aeppl to derive the + corresponding logp at runtime. + + Check pymc.distributions.Censored for an example of a symbolic distribution. + + Symbolic distributions must implement the following classmethods: + cls.dist + Performs input validation and converts optional alternative parametrizations + to a canonical parametrization. 
It should call `super().dist()`, passing a + list with the default parameters as the first and only non keyword argument, + followed by other keyword arguments like size and rngs, and return the result + cls.num_rngs + Returns the number of rngs given the same arguments passed by the user when + calling the distribution + cls.ndim_supp + Returns the support of the symbolic distribution, given the default set of + parameters. This may not always be constant, for instance if the symbolic + distribution can be defined based on an arbitrary base distribution. + cls.rv_op + Returns a TensorVariable that represents the symbolic distribution + parametrized by a default set of parameters and a size and rngs arguments + cls.change_size + Returns an equivalent symbolic distribution with a different size. This is + analogous to `pymc.aesaraf.change_rv_size` for `RandomVariable`s. + """ + def __new__( cls, name: str, @@ -379,36 +413,6 @@ def __new__( """Adds a TensorVariable corresponding to a PyMC symbolic distribution to the current model. - While traditional PyMC distributions are represented by a single RandomVariable - graph, Symbolic distributions correspond to a larger graph that contains one or - more RandomVariables and an arbitrary number of deterministic operations, which - represent their own kind of distribution. - - The graphs returned by symbolic distributions can be evaluated directly to - obtain valid draws and can further be parsed by Aeppl to derive the - corresponding logp at runtime. - - Check pymc.distributions.Censored for an example of a symbolic distribution. - - Symbolic distributions must implement the following classmethods: - cls.dist - Performs input validation and converts optional alternative parametrizations - to a canonical parametrization. It should call `super().dist()`, passing a - list with the default parameters as the first and only non keyword argument, - followed by other keyword arguments like size and rngs, and return the result - cls.rv_op - Returns a TensorVariable that represents the symbolic distribution - parametrized by a default set of parameters and a size and rngs arguments - cls.ndim_supp - Returns the support of the symbolic distribution, given the default - parameters. This may not always be constant, for instance if the symbolic - distribution can be defined based on an arbitrary base distribution. - cls.change_size - Returns an equivalent symbolic distribution with a different size. This is - analogous to `pymc.aesaraf.change_rv_size` for `RandomVariable`s. - cls.graph_rvs - Returns base RVs in a symbolic distribution. - Parameters ---------- cls : type @@ -465,9 +469,9 @@ def __new__( raise TypeError(f"Name needs to be a string but got: {name}") if rngs is None: - # Create a temporary rv to obtain number of rngs needed - temp_graph = cls.dist(*args, rngs=None, **kwargs) - rngs = [model.next_rng() for _ in cls.graph_rvs(temp_graph)] + # Instead of passing individual RNG variables we could pass a RandomStream + # and let the classes create as many RNGs as they need + rngs = [model.next_rng() for _ in range(cls.num_rngs(*args, **kwargs))] elif not isinstance(rngs, (list, tuple)): rngs = [rngs] @@ -523,7 +527,6 @@ def dist( The inputs to the `RandomVariable` `Op`. shape : int, tuple, Variable, optional A tuple of sizes for each dimension of the new RV. - An Ellipsis (...) may be inserted in the last position to short-hand refer to all the dimensions that the RV would get if no shape/size/dims were passed at all. 
size : int, tuple, Variable, optional diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py index b613f90bac..a1c6129ebe 100644 --- a/pymc/distributions/mixture.py +++ b/pymc/distributions/mixture.py @@ -205,6 +205,18 @@ def dist(cls, w, comp_dists, **kwargs): w = at.as_tensor_variable(w) return super().dist([w, *comp_dists], **kwargs) + @classmethod + def num_rngs(cls, w, comp_dists, **kwargs): + if not isinstance(comp_dists, (tuple, list)): + # comp_dists is a single component + comp_dists = [comp_dists] + return len(comp_dists) + 1 + + @classmethod + def ndim_supp(cls, weights, *components): + # We already checked that all components have the same support dimensionality + return components[0].owner.op.ndim_supp + @classmethod def rv_op(cls, weights, *components, size=None, rngs=None): # Update rngs if provided @@ -329,11 +341,6 @@ def _resize_components(cls, size, *components): return [change_rv_size(component, size) for component in components] - @classmethod - def ndim_supp(cls, weights, *components): - # We already checked that all components have the same support dimensionality - return components[0].owner.op.ndim_supp - @classmethod def change_size(cls, rv, new_size, expand=False): weights = rv.tag.weights @@ -355,14 +362,6 @@ def change_size(cls, rv, new_size, expand=False): return cls.rv_op(weights, *components, rngs=rngs, size=None) - @classmethod - def graph_rvs(cls, rv): - # We return rv, which is itself a pseudo RandomVariable, that contains a - # mix_indexes_ RV in its inner graph. We want super().dist() to generate - # (components + 1) rngs for us, and it will do so based on how many elements - # we return here - return (*rv.tag.components, rv) - @_get_measurable_outputs.register(MarginalMixtureRV) def _get_measurable_outputs_MarginalMixtureRV(op, node): diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index 630b1cbeb9..7649d0069d 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -11,14 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import warnings -from typing import Tuple, Union +from typing import Optional, Tuple, Union +import aesara import aesara.tensor as at import numpy as np +from aeppl.abstract import MeasurableVariable, _get_measurable_outputs +from aeppl.logprob import _logprob from aesara import scan +from aesara.compile.builders import OpFromGraph +from aesara.graph import FunctionGraph, optimize_graph +from aesara.graph.basic import Node from aesara.raise_op import Assert +from aesara.tensor import TensorVariable +from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding from aesara.tensor.random.op import RandomVariable from aesara.tensor.random.utils import normalize_size_param @@ -26,13 +35,12 @@ from pymc.distributions import distribution, multivariate from pymc.distributions.continuous import Flat, Normal, get_tau_sigma from pymc.distributions.dist_math import check_parameters -from pymc.distributions.distribution import moment +from pymc.distributions.distribution import SymbolicDistribution, _moment, moment from pymc.distributions.logprob import ignore_logprob, logp -from pymc.distributions.shape_utils import rv_size_is_none, to_tuple +from pymc.distributions.shape_utils import Shape, rv_size_is_none, to_tuple from pymc.util import check_dist_not_registered __all__ = [ - "AR1", "AR", "GaussianRandomWalk", "GARCH11", @@ -42,6 +50,53 @@ ] +def get_steps_from_shape( + steps: Optional[Union[int, np.ndarray, TensorVariable]], + shape: Optional[Shape], + step_shape_offset: int = 0, +): + """Extract number of steps from shape information + + Parameters + ---------- + steps: + User specified steps for timeseries distribution + shape: + User specified shape for timeseries distribution + step_shape_offset: + Difference between last shape dimension and number of steps in timeseries + distribution, defaults to 0 + + Raises + ------ + ValueError + If neither shape nor steps are provided + + Returns + ------- + steps + Steps, if specified directly by user, or inferred from the last dimension of + shape. When both steps and shape are provided, a symbolic Assert is added + to make sure they are consistent. 
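+
+    Examples
+    --------
+    Illustrative call (assumed values), for a series of 100 values whose first
+    entry is the initial value rather than a step:
+
+    .. code-block:: python
+
+        get_steps_from_shape(steps=None, shape=(10, 100), step_shape_offset=1)  # 99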
+ """ + steps_from_shape = None + if shape is not None: + shape = to_tuple(shape) + if shape[-1] is not ...: + steps_from_shape = shape[-1] - step_shape_offset + if steps is None: + if steps_from_shape is not None: + steps = steps_from_shape + else: + raise ValueError("Must specify steps or shape parameter") + elif steps_from_shape is not None: + # Assert that steps and shape are consistent + steps = Assert(msg="Steps do not match last shape dimension")( + steps, at.eq(steps, steps_from_shape) + ) + return steps + + class GaussianRandomWalkRV(RandomVariable): """ GaussianRandomWalk Random Variable @@ -176,25 +231,7 @@ def dist( mu = at.as_tensor_variable(floatX(mu)) sigma = at.as_tensor_variable(floatX(sigma)) - # Check if shape contains information about number of steps - steps_from_shape = None - shape = kwargs.get("shape", None) - if shape is not None: - shape = to_tuple(shape) - if shape[-1] is not ...: - steps_from_shape = shape[-1] - 1 - - if steps is None: - if steps_from_shape is not None: - steps = steps_from_shape - else: - raise ValueError("Must specify steps or shape parameter") - elif steps_from_shape is not None: - # Assert that steps and shape are consistent - steps = Assert(msg="Steps do not match last shape dimension")( - steps, at.eq(steps, steps_from_shape) - ) - + steps = get_steps_from_shape(steps, kwargs.get("shape", None), step_shape_offset=1) steps = at.as_tensor_variable(intX(steps)) # If no scalar distribution is passed then initialize with a Normal of same mu and sigma @@ -247,62 +284,34 @@ def logp( ) -class AR1(distribution.Continuous): - """ - Autoregressive process with 1 lag. +class AutoRegressiveRV(OpFromGraph): + """A placeholder used to specify a log-likelihood for an AR sub-graph.""" - Parameters - ---------- - k: tensor - effect of lagged value on current value - tau_e: tensor - precision for innovations - """ + default_output = 1 + ar_order: int + constant_term: bool - def __init__(self, k, tau_e, *args, **kwargs): + def __init__(self, *args, ar_order, constant_term, **kwargs): + self.ar_order = ar_order + self.constant_term = constant_term super().__init__(*args, **kwargs) - self.k = k = at.as_tensor_variable(k) - self.tau_e = tau_e = at.as_tensor_variable(tau_e) - self.tau = tau_e * (1 - k**2) - self.mode = at.as_tensor_variable(0.0) - def logp(self, x): - """ - Calculate log-probability of AR1 distribution at specified value. + def update(self, node: Node): + """Return the update mapping for the noise RV.""" + # Since noise is a shared variable it shows up as the last node input + return {node.inputs[-1]: node.outputs[0]} - Parameters - ---------- - x: numeric - Value for which log-probability is calculated. - Returns - ------- - TensorVariable - """ - k = self.k - tau_e = self.tau_e # innovation precision - tau = tau_e * (1 - k**2) # ar1 precision - - x_im1 = x[:-1] - x_i = x[1:] - boundary = Normal.dist(0.0, tau=tau).logp - - innov_like = Normal.dist(k * x_im1, tau=tau_e).logp(x_i) - return boundary(x[0]) + at.sum(innov_like) - - -class AR(distribution.Continuous): - r""" - Autoregressive process with p lags. +class AR(SymbolicDistribution): + r"""Autoregressive process with p lags. .. math:: x_t = \rho_0 + \rho_1 x_{t-1} + \ldots + \rho_p x_{t-p} + \epsilon_t, \epsilon_t \sim N(0,\sigma^2) - The innovation can be parameterized either in terms of precision - or standard deviation. The link between the two parametrizations is - given by + The innovation can be parameterized either in terms of precision or standard + deviation. 
The link between the two parametrizations is given by .. math:: @@ -310,79 +319,266 @@ class AR(distribution.Continuous): Parameters ---------- - rho: tensor - Tensor of autoregressive coefficients. The first dimension is the p lag. - sigma: float - Standard deviation of innovation (sigma > 0). (only required if tau is not specified) - tau: float - Precision of innovation (tau > 0). (only required if sigma is not specified) - constant: bool (optional, default = False) - Whether to include a constant. - init: distribution - distribution for initial values (Defaults to Flat()) - """ + rho: tensor_like of float + Tensor of autoregressive coefficients. The n-th entry in the last dimension is + the coefficient for the n-th lag. + sigma: tensor_like of float, optional + Standard deviation of innovation (sigma > 0). Defaults to 1. Only required if + tau is not specified. + tau: tensor_like of float + Precision of innovation (tau > 0). + constant: bool, optional + Whether the first element of rho should be used as a constant term in the AR + process. Defaults to False + init_dist: unnamed distribution, optional + Scalar or vector distribution for initial values. Defaults to Normal(0, sigma). + Distribution should be created via the `.dist()` API, and have dimension + (*size, ar_order). If not, it will be automatically resized. + + .. warning:: init_dist will be cloned, rendering it independent of the one passed as input. + + ar_order: int, optional + Order of the AR process. Inferred from length of the last dimension of rho, if + possible. ar_order = rho.shape[-1] if constant else rho.shape[-1] - 1 - def __init__(self, rho, sigma=None, tau=None, constant=False, init=None, *args, **kwargs): - super().__init__(*args, **kwargs) - tau, sigma = get_tau_sigma(tau=tau, sigma=sigma) - self.sigma = at.as_tensor_variable(sigma) - self.tau = at.as_tensor_variable(tau) + Notes + ----- + The init distribution will be cloned, rendering it distinct from the one passed as + input. - self.mean = at.as_tensor_variable(0.0) + Examples + -------- + .. code-block:: python - if isinstance(rho, list): - p = len(rho) - else: - try: - shape_ = rho.shape.tag.test_value - except AttributeError: - shape_ = rho.shape + # Create an AR of order 3, with a constant term + with pm.Model() as AR3: + # The first coefficient will be the constant term + coefs = pm.Normal("coefs", 0, size=4) + # We need one init variable for each lag, hence size=3 + init = pm.Normal.dist(5, size=3) + ar3 = pm.AR("ar3", coefs, sigma=1.0, init_dist=init, constant=True, steps=500) - if hasattr(shape_, "size") and shape_.size == 0: - p = 1 - else: - p = shape_[0] + """ + + @classmethod + def dist( + cls, + rho, + sigma=None, + tau=None, + *, + init_dist=None, + steps=None, + constant=False, + ar_order=None, + **kwargs, + ): + _, sigma = get_tau_sigma(tau=tau, sigma=sigma) + sigma = at.as_tensor_variable(floatX(sigma)) + rhos = at.atleast_1d(at.as_tensor_variable(floatX(rho))) - if constant: - self.p = p - 1 + if "init" in kwargs: + warnings.warn( + "init parameter is now called init_dist. Using init will raise an error in a future release.", + FutureWarning, + ) + init_dist = kwargs["init"] + + steps = get_steps_from_shape(steps, kwargs.get("shape", None)) + steps = at.as_tensor_variable(intX(steps), ndim=0) + + if ar_order is None: + # If ar_order is not specified we do constant folding on the shape of rhos + # to retrieve it. For example, this will detect that + # Normal(size=(5, 3)).shape[-1] == 3, which is not known by Aesara before. 
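+            # When constant=True the first entry of rho is the constant term rather
+            # than a lag coefficient, hence the `- int(constant)` below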
+ shape_fg = FunctionGraph( + outputs=[rhos.shape[-1]], + features=[ShapeFeature()], + clone=True, + ) + (folded_shape,) = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs + folded_shape = getattr(folded_shape, "data", None) + if folded_shape is None: + raise ValueError( + "Could not infer ar_order from last dimension of rho. Pass it " + "explictily or make sure rho have a static shape" + ) + ar_order = int(folded_shape) - int(constant) + if ar_order < 1: + raise ValueError( + "Inferred ar_order is smaller than 1. Increase the last dimension " + "of rho or remove constant_term" + ) + + if init_dist is not None: + if not isinstance(init_dist, TensorVariable) or not isinstance( + init_dist.owner.op, RandomVariable + ): + raise ValueError( + f"Init dist must be a distribution created via the `.dist()` API, " + f"got {type(init_dist)}" + ) + check_dist_not_registered(init_dist) + if init_dist.owner.op.ndim_supp > 1: + raise ValueError( + "Init distribution must have a scalar or vector support dimension, ", + f"got ndim_supp={init_dist.owner.op.ndim_supp}.", + ) else: - self.p = p + # Sigma must broadcast with ar_order + init_dist = Normal.dist(sigma=at.shape_padright(sigma), size=(*sigma.shape, ar_order)) - self.constant = constant - self.rho = rho = at.as_tensor_variable(rho) - self.init = init or Flat.dist() + # Tell Aeppl to ignore init_dist, as it will be accounted for in the logp term + init_dist = ignore_logprob(init_dist) - def logp(self, value): - """ - Calculate log-probability of AR distribution at specified value. + return super().dist([rhos, sigma, init_dist, steps, ar_order, constant], **kwargs) - Parameters - ---------- - value: numeric - Value for which log-probability is calculated. + @classmethod + def num_rngs(cls, *args, **kwargs): + return 2 - Returns - ------- - TensorVariable - """ - if self.constant: - x = at.add( - *(self.rho[i + 1] * value[self.p - (i + 1) : -(i + 1)] for i in range(self.p)) - ) - eps = value[self.p :] - self.rho[0] - x + @classmethod + def ndim_supp(cls, *args): + return 1 + + @classmethod + def rv_op(cls, rhos, sigma, init_dist, steps, ar_order, constant_term, size=None, rngs=None): + + if rngs is None: + rngs = [ + aesara.shared(np.random.default_rng(seed)) + for seed in np.random.SeedSequence().spawn(2) + ] + (init_dist_rng, noise_rng) = rngs + # Re-seed init_dist + if init_dist.owner.inputs[0] is not init_dist_rng: + _, *inputs = init_dist.owner.inputs + init_dist = init_dist.owner.op.make_node(init_dist_rng, *inputs).default_output() + + # Init dist should have shape (*size, ar_order) + if size is not None: + batch_size = size + else: + # In this case the size of the init_dist depends on the parameters shape + # The last dimension of rho and init_dist does not matter + batch_size = at.broadcast_shape(sigma, rhos[..., 0], init_dist[..., 0]) + if init_dist.owner.op.ndim_supp == 0: + init_dist_size = (*batch_size, ar_order) else: - if self.p == 1: - x = self.rho * value[:-1] + # In this case the support dimension must cover for ar_order + init_dist_size = batch_size + init_dist = change_rv_size(init_dist, init_dist_size) + + # Create OpFromGraph representing random draws form AR process + # Variables with underscore suffix are dummy inputs into the OpFromGraph + init_ = init_dist.type() + rhos_ = rhos.type() + sigma_ = sigma.type() + steps_ = steps.type() + + rhos_bcast_shape_ = init_.shape + if constant_term: + # In this case init shape is one unit smaller than rhos in the last dimension + rhos_bcast_shape_ = 
(*rhos_bcast_shape_[:-1], rhos_bcast_shape_[-1] + 1) + rhos_bcast_ = at.broadcast_to(rhos_, rhos_bcast_shape_) + + def step(*args): + *prev_xs, reversed_rhos, sigma, rng = args + if constant_term: + mu = reversed_rhos[-1] + at.sum(prev_xs * reversed_rhos[:-1], axis=0) else: - x = at.add( - *(self.rho[i] * value[self.p - (i + 1) : -(i + 1)] for i in range(self.p)) - ) - eps = value[self.p :] - x + mu = at.sum(prev_xs * reversed_rhos, axis=0) + next_rng, new_x = Normal.dist(mu=mu, sigma=sigma, rng=rng).owner.outputs + return new_x, {rng: next_rng} + + # We transpose inputs as scan iterates over first dimension + innov_, innov_updates_ = aesara.scan( + fn=step, + outputs_info=[{"initial": init_.T, "taps": range(-ar_order, 0)}], + non_sequences=[rhos_bcast_.T[::-1], sigma_.T, noise_rng], + n_steps=at.max((0, steps_ - ar_order)), + strict=True, + ) + (noise_next_rng,) = tuple(innov_updates_.values()) + ar_ = at.concatenate([init_, innov_.T], axis=-1) + + ar_op = AutoRegressiveRV( + inputs=[rhos_, sigma_, init_, steps_], + outputs=[noise_next_rng, ar_], + ar_order=ar_order, + constant_term=constant_term, + inline=True, + ) + + ar = ar_op(rhos, sigma, init_dist, steps) + return ar - innov_like = Normal.dist(mu=0.0, tau=self.tau).logp(eps) - init_like = self.init.logp(value[: self.p]) + @classmethod + def change_size(cls, rv, new_size, expand=False): + + if expand: + old_size = rv.shape[:-1] + new_size = at.concatenate([new_size, old_size]) + + init_dist_rng = rv.owner.inputs[2].owner.inputs[0] + noise_rng = rv.owner.inputs[-1] + + op = rv.owner.op + return cls.rv_op( + *rv.owner.inputs, + ar_order=op.ar_order, + constant_term=op.constant_term, + size=new_size, + rngs=(init_dist_rng, noise_rng), + ) - return at.sum(innov_like) + at.sum(init_like) + +MeasurableVariable.register(AutoRegressiveRV) + + +@_get_measurable_outputs.register(AutoRegressiveRV) +def _get_measurable_outputs_ar(op, node): + # This tells Aeppl that the second output is the measurable one + return [node.outputs[1]] + + +@_logprob.register(AutoRegressiveRV) +def ar_logp(op, values, rhos, sigma, init_dist, steps, noise_rng, **kwargs): + (value,) = values + + ar_order = op.ar_order + constant_term = op.constant_term + + # Convolve rhos with values + if constant_term: + expectation = at.add( + rhos[..., 0, None], + *( + rhos[..., i + 1, None] * value[..., ar_order - (i + 1) : -(i + 1)] + for i in range(ar_order) + ), + ) + else: + expectation = at.add( + *( + rhos[..., i, None] * value[..., ar_order - (i + 1) : -(i + 1)] + for i in range(ar_order) + ) + ) + # Compute and collapse logp across time dimension + innov_logp = at.sum( + logp(Normal.dist(0, sigma[..., None]), value[..., ar_order:] - expectation), axis=-1 + ) + init_logp = logp(init_dist, value[..., :ar_order]) + if init_dist.owner.op.ndim_supp == 0: + init_logp = at.sum(init_logp, axis=-1) + return init_logp + innov_logp + + +@_moment.register(AutoRegressiveRV) +def ar_moment(op, rv, rhos, sigma, init_dist, steps, noise_rng): + # Use last entry of init_dist moment as the moment for the whole AR + return at.full_like(rv, moment(init_dist)[..., -1, None]) class GARCH11(distribution.Continuous): diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index cb485a086f..a8d8ea4dd4 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -60,7 +60,6 @@ def polyagamma_cdf(*args, **kwargs): from pymc.aesaraf import floatX, intX from pymc.distributions import ( - AR1, CAR, AsymmetricLaplace, Bernoulli, @@ -834,14 +833,6 @@ def 
mvt_logpdf(value, nu, Sigma, mu=0): return logp_mvt.sum() -def AR1_logpdf(value, k, tau_e): - tau = tau_e * (1 - k**2) - return ( - sp.norm(loc=0, scale=1 / np.sqrt(tau)).logpdf(value[0]) - + sp.norm(loc=k * value[:-1], scale=1 / np.sqrt(tau_e)).logpdf(value[1:]).sum() - ) - - def invlogit(x, eps=sys.float_info.epsilon): return (1.0 - 2.0 * eps) / (1.0 + np.exp(-x)) + eps @@ -2078,11 +2069,6 @@ def test_mvt(self, n): extra_args={"size": 2}, ) - @pytest.mark.parametrize("n", [2, 3, 4]) - @pytest.mark.xfail(reason="Distribution not refactored yet") - def test_AR1(self, n): - check_logp(AR1, Vector(R, n), {"k": Unit, "tau_e": Rplus}, AR1_logpdf) - @pytest.mark.parametrize("n", [2, 3]) def test_wishart(self, n): check_logp( diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index f2ac38bd16..230989000c 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -100,8 +100,6 @@ def test_all_distributions_have_moments(): # Distributions that have not been refactored for V4 yet not_implemented = { - dist_module.timeseries.AR, - dist_module.timeseries.AR1, dist_module.timeseries.GARCH11, dist_module.timeseries.MvGaussianRandomWalk, dist_module.timeseries.MvStudentTRandomWalk, diff --git a/pymc/tests/test_distributions_timeseries.py b/pymc/tests/test_distributions_timeseries.py index c66e197553..d2fe7275d2 100644 --- a/pymc/tests/test_distributions_timeseries.py +++ b/pymc/tests/test_distributions_timeseries.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import aesara import numpy as np import pytest import scipy.stats @@ -18,16 +19,13 @@ import pymc as pm from pymc.aesaraf import floatX -from pymc.distributions.continuous import Flat, Normal -from pymc.distributions.timeseries import ( - AR, - AR1, - GARCH11, - EulerMaruyama, - GaussianRandomWalk, -) +from pymc.distributions.continuous import Flat, HalfNormal, Normal +from pymc.distributions.discrete import Constant +from pymc.distributions.logprob import logp +from pymc.distributions.multivariate import Dirichlet +from pymc.distributions.timeseries import AR, GARCH11, EulerMaruyama, GaussianRandomWalk from pymc.model import Model -from pymc.sampling import sample, sample_posterior_predictive +from pymc.sampling import draw, sample, sample_posterior_predictive from pymc.tests.helpers import select_by_precision from pymc.tests.test_distributions_moments import assert_moment_is_expected from pymc.tests.test_distributions_random import BaseTestDistributionRandom @@ -166,60 +164,198 @@ def test_moment(self, mu, sigma, init, steps, size, expected): assert_moment_is_expected(model, expected) -@pytest.mark.xfail(reason="Timeseries not refactored") -def test_AR(): - # AR1 - data = np.array([0.3, 1, 2, 3, 4]) - phi = np.array([0.99]) - with Model() as t: - y = AR("y", phi, sigma=1, shape=len(data)) - z = Normal("z", mu=phi * data[:-1], sigma=1, shape=len(data) - 1) - ar_like = t["y"].logp({"z": data[1:], "y": data}) - reg_like = t["z"].logp({"z": data[1:], "y": data}) - np.testing.assert_allclose(ar_like, reg_like) +class TestAR: + def test_order1_logp(self): + data = np.array([0.3, 1, 2, 3, 4]) + phi = np.array([0.99]) + with Model() as t: + y = AR("y", phi, sigma=1, init_dist=Flat.dist(), shape=len(data)) + z = Normal("z", mu=phi * data[:-1], sigma=1, shape=len(data) - 1) + ar_like = t.compile_logp(y)({"y": data}) 
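+        # Flat init_dist contributes logp 0 for the first value, so the AR(1)
+        # logp should match a plain Normal regression on the lagged data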
+ reg_like = t.compile_logp(z)({"z": data[1:]}) + np.testing.assert_allclose(ar_like, reg_like) + + with Model() as t_constant: + y = AR( + "y", + np.hstack((0.3, phi)), + sigma=1, + init_dist=Flat.dist(), + shape=len(data), + constant=True, + ) + z = Normal("z", mu=0.3 + phi * data[:-1], sigma=1, shape=len(data) - 1) + ar_like = t_constant.compile_logp(y)({"y": data}) + reg_like = t_constant.compile_logp(z)({"z": data[1:]}) + np.testing.assert_allclose(ar_like, reg_like) + + def test_order2_logp(self): + data = np.array([0.3, 1, 2, 3, 4]) + phi = np.array([0.84, 0.10]) + with Model() as t: + y = AR("y", phi, sigma=1, init_dist=Flat.dist(), shape=len(data)) + z = Normal( + "z", mu=phi[0] * data[1:-1] + phi[1] * data[:-2], sigma=1, shape=len(data) - 2 + ) + ar_like = t.compile_logp(y)({"y": data}) + reg_like = t.compile_logp(z)({"z": data[2:]}) + np.testing.assert_allclose(ar_like, reg_like) + + @pytest.mark.parametrize("constant", (False, True)) + def test_batched_size(self, constant): + ar_order, steps, batch_size = 3, 100, 5 + beta_tp = np.random.randn(batch_size, ar_order + int(constant)) + y_tp = np.random.randn(batch_size, steps) + with Model() as t0: + y = AR("y", beta_tp, shape=(batch_size, steps), initval=y_tp, constant=constant) + with Model() as t1: + for i in range(batch_size): + AR(f"y_{i}", beta_tp[i], sigma=1.0, shape=steps, initval=y_tp[i], constant=constant) + + assert y.owner.op.ar_order == ar_order + + np.testing.assert_allclose( + t0.compile_logp()(t0.initial_point()), + t1.compile_logp()(t1.initial_point()), + ) - # AR1 and AR(1) - with Model() as t: - rho = Normal("rho", 0.0, 1.0) - y1 = AR1("y1", rho, 1.0, observed=data) - y2 = AR("y2", rho, 1.0, init=Normal.dist(0, 1), observed=data) - initial_point = t.initial_point() - np.testing.assert_allclose(y1.logp(initial_point), y2.logp(initial_point)) + y_eval = draw(y, draws=2) + assert y_eval[0].shape == (batch_size, steps) + assert not np.any(np.isclose(y_eval[0], y_eval[1])) + + def test_batched_rhos(self): + ar_order, steps, batch_size = 3, 100, 5 + beta_tp = np.random.randn(batch_size, ar_order) + y_tp = np.random.randn(batch_size, steps) + with Model() as t0: + beta = Normal("beta", 0.0, 1.0, shape=(batch_size, ar_order), initval=beta_tp) + AR("y", beta, sigma=1.0, shape=(batch_size, steps), initval=y_tp) + with Model() as t1: + beta = Normal("beta", 0.0, 1.0, shape=(batch_size, ar_order), initval=beta_tp) + for i in range(batch_size): + AR(f"y_{i}", beta[i], sigma=1.0, shape=steps, initval=y_tp[i]) + + np.testing.assert_allclose( + t0.compile_logp()(t0.initial_point()), + t1.compile_logp()(t1.initial_point()), + ) - # AR1 + constant - with Model() as t: - y = AR("y", np.hstack((0.3, phi)), sigma=1, shape=len(data), constant=True) - z = Normal("z", mu=0.3 + phi * data[:-1], sigma=1, shape=len(data) - 1) - ar_like = t["y"].logp({"z": data[1:], "y": data}) - reg_like = t["z"].logp({"z": data[1:], "y": data}) - np.testing.assert_allclose(ar_like, reg_like) - - # AR2 - phi = np.array([0.84, 0.10]) - with Model() as t: - y = AR("y", phi, sigma=1, shape=len(data)) - z = Normal("z", mu=phi[0] * data[1:-1] + phi[1] * data[:-2], sigma=1, shape=len(data) - 2) - ar_like = t["y"].logp({"z": data[2:], "y": data}) - reg_like = t["z"].logp({"z": data[2:], "y": data}) - np.testing.assert_allclose(ar_like, reg_like) + beta_tp[1] = 0 # Should always be close to zero + y_eval = t0["y"].eval({t0["beta"]: beta_tp}) + assert y_eval.shape == (batch_size, steps) + assert np.all(abs(y_eval[1]) < 5) + + def test_batched_sigma(self): + 
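+        # sigma has batch shape (7, 5): the batched AR logp should match building
+        # one AR per series, and draws should scale with the corresponding sigma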
ar_order, steps, batch_size = 4, 100, (7, 5) + # AR order cannot be inferred from beta_tp because it is not fixed. + # We specify it manually below + beta_tp = aesara.shared(np.random.randn(ar_order)) + sigma_tp = np.abs(np.random.randn(*batch_size)) + y_tp = np.random.randn(*batch_size, steps) + with Model() as t0: + sigma = HalfNormal("sigma", 1.0, shape=batch_size, initval=sigma_tp) + AR( + "y", + beta_tp, + sigma=sigma, + size=batch_size, + steps=steps, + initval=y_tp, + ar_order=ar_order, + ) + with Model() as t1: + sigma = HalfNormal("beta", 1.0, shape=batch_size, initval=sigma_tp) + for i in range(batch_size[0]): + for j in range(batch_size[1]): + AR( + f"y_{i}{j}", + beta_tp, + sigma=sigma[i][j], + shape=steps, + initval=y_tp[i, j], + ar_order=ar_order, + ) + + # Check logp shape + sigma_logp, y_logp = t0.compile_logp(sum=False)(t0.initial_point()) + assert tuple(y_logp.shape) == batch_size + + np.testing.assert_allclose( + sigma_logp.sum() + y_logp.sum(), + t1.compile_logp()(t1.initial_point()), + ) + beta_tp.set_value(np.zeros((ar_order,))) # Should always be close to zero + sigma_tp = np.full(batch_size, [0.01, 0.1, 1, 10, 100]) + y_eval = t0["y"].eval({t0["sigma"]: sigma_tp}) + assert y_eval.shape == (*batch_size, steps) + assert np.allclose(y_eval.std(axis=(0, 2)), [0.01, 0.1, 1, 10, 100], rtol=0.1) + + def test_batched_init_dist(self): + ar_order, steps, batch_size = 3, 100, 5 + beta_tp = aesara.shared(np.random.randn(ar_order), shape=(3,)) + y_tp = np.random.randn(batch_size, steps) + with Model() as t0: + init_dist = Normal.dist(0.0, 0.01, size=(batch_size, ar_order)) + AR("y", beta_tp, sigma=0.01, init_dist=init_dist, steps=steps, initval=y_tp) + with Model() as t1: + for i in range(batch_size): + AR(f"y_{i}", beta_tp, sigma=0.01, shape=steps, initval=y_tp[i]) + + np.testing.assert_allclose( + t0.compile_logp()(t0.initial_point()), + t1.compile_logp()(t1.initial_point()), + ) -@pytest.mark.xfail(reason="Timeseries not refactored") -def test_AR_nd(): - # AR2 multidimensional - p, T, n = 3, 100, 5 - beta_tp = np.random.randn(p, n) - y_tp = np.random.randn(T, n) - with Model() as t0: - beta = Normal("beta", 0.0, 1.0, shape=(p, n), initval=beta_tp) - AR("y", beta, sigma=1.0, shape=(T, n), initval=y_tp) - - with Model() as t1: - beta = Normal("beta", 0.0, 1.0, shape=(p, n), initval=beta_tp) - for i in range(n): - AR("y_%d" % i, beta[:, i], sigma=1.0, shape=T, initval=y_tp[:, i]) - - np.testing.assert_allclose(t0.logp(t0.initial_point()), t1.logp(t1.initial_point())) + # Next values should keep close to previous ones + beta_tp.set_value(np.full((ar_order,), 1 / ar_order)) + # Init dist is cloned when creating the AR, so the original variable is not + # part of the AR graph. 
We retrieve the one actually used manually + init_dist = t0["y"].owner.inputs[2] + init_dist_tp = np.full((batch_size, ar_order), (np.arange(batch_size) * 100)[:, None]) + y_eval = t0["y"].eval({init_dist: init_dist_tp}) + assert y_eval.shape == (batch_size, steps) + assert np.allclose( + y_eval[:, -10:].mean(-1), np.arange(batch_size) * 100, rtol=0.1, atol=0.5 + ) + + def test_constant_random(self): + x = AR.dist( + rho=[100, 0, 0], + sigma=0.1, + init_dist=Normal.dist(-100.0, sigma=0.1), + constant=True, + shape=(6,), + ) + x_eval = x.eval() + assert np.allclose(x_eval[:2], -100, rtol=0.1) + assert np.allclose(x_eval[2:], 100, rtol=0.1) + + def test_multivariate_init_dist(self): + init_dist = Dirichlet.dist(a=np.full((5, 2), [1, 10])) + x = AR.dist(rho=[0, 0], init_dist=init_dist, steps=0) + + x_eval = x.eval() + assert x_eval.shape == (5, 2) + + init_dist_eval = init_dist.eval() + init_dist_logp_eval = logp(init_dist, init_dist_eval).eval() + x_logp_eval = logp(x, init_dist_eval).eval() + assert x_logp_eval.shape == (5,) + assert np.allclose(x_logp_eval, init_dist_logp_eval) + + @pytest.mark.parametrize( + "size, expected", + [ + (None, np.full((2, 7), [[2.0], [4.0]])), + ((5, 2), np.full((5, 2, 7), [[2.0], [4.0]])), + ], + ) + def test_moment(self, size, expected): + with Model() as model: + init_dist = Constant.dist([[1.0, 2.0], [3.0, 4.0]]) + AR("x", rho=[0, 0], init_dist=init_dist, steps=7, size=size) + assert_moment_is_expected(model, expected, check_finite_logp=False) @pytest.mark.xfail(reason="Timeseries not refactored")