Skip to content

Refactor AR distribution #5734

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@
from pymc.distributions.simulator import Simulator
from pymc.distributions.timeseries import (
AR,
AR1,
GARCH11,
GaussianRandomWalk,
MvGaussianRandomWalk,
Expand Down Expand Up @@ -169,7 +168,6 @@
"WishartBartlett",
"LKJCholeskyCov",
"LKJCorr",
"AR1",
"AR",
"AsymmetricLaplace",
"GaussianRandomWalk",
Expand Down
32 changes: 16 additions & 16 deletions pymc/distributions/censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ def dist(cls, dist, lower, upper, **kwargs):
check_dist_not_registered(dist)
return super().dist([dist, lower, upper], **kwargs)

@classmethod
def num_rngs(cls, *args, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def num_rngs(cls, *args, **kwargs):
def num_rngs(cls, *args, **kwargs):
"""Refer to base class SymbolicDistribution for documentation."""

return 1

@classmethod
def ndim_supp(cls, *dist_params):
Copy link
Member

@canyon289 canyon289 May 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def ndim_supp(cls, *dist_params):
def ndim_supp(cls, *dist_params):
"""Refer to base class SymbolicDistribution for documentation."""

return 0

@classmethod
def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):

Expand All @@ -96,24 +104,12 @@ def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
rv_out.tag.upper = upper

if rngs is not None:
rv_out = cls.change_rngs(rv_out, rngs)
rv_out = cls._change_rngs(rv_out, rngs)

return rv_out

@classmethod
def ndim_supp(cls, *dist_params):
return 0

@classmethod
def change_size(cls, rv, new_size, expand=False):
dist = rv.tag.dist
lower = rv.tag.lower
upper = rv.tag.upper
new_dist = change_rv_size(dist, new_size, expand=expand)
return cls.rv_op(new_dist, lower, upper)

@classmethod
def change_rngs(cls, rv, new_rngs):
def _change_rngs(cls, rv, new_rngs):
(new_rng,) = new_rngs
dist_node = rv.tag.dist.owner
lower = rv.tag.lower
Expand All @@ -123,8 +119,12 @@ def change_rngs(cls, rv, new_rngs):
return cls.rv_op(new_dist, lower, upper)

@classmethod
def graph_rvs(cls, rv):
return (rv.tag.dist,)
def change_size(cls, rv, new_size, expand=False):
dist = rv.tag.dist
lower = rv.tag.lower
upper = rv.tag.upper
new_dist = change_rv_size(dist, new_size, expand=expand)
return cls.rv_op(new_dist, lower, upper)


@_moment.register(Clip)
Expand Down
71 changes: 37 additions & 34 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,40 @@ def dist(


class SymbolicDistribution:
"""Symbolic statistical distribution

While traditional PyMC distributions are represented by a single RandomVariable
graph, Symbolic distributions correspond to a larger graph that contains one or
more RandomVariables and an arbitrary number of deterministic operations, which
represent their own kind of distribution.

The graphs returned by symbolic distributions can be evaluated directly to
obtain valid draws and can further be parsed by Aeppl to derive the
corresponding logp at runtime.

Check pymc.distributions.Censored for an example of a symbolic distribution.

Symbolic distributions must implement the following classmethods:
cls.dist
Performs input validation and converts optional alternative parametrizations
to a canonical parametrization. It should call `super().dist()`, passing a
list with the default parameters as the first and only non-keyword argument,
followed by other keyword arguments like size and rngs, and return the result.
cls.num_rngs
Returns the number of rngs given the same arguments passed by the user when
calling the distribution
cls.ndim_supp
Returns the support dimensionality of the symbolic distribution, given the
default set of parameters. This may not always be constant, for instance if the
symbolic distribution can be defined based on an arbitrary base distribution.
cls.rv_op
Returns a TensorVariable that represents the symbolic distribution
parametrized by a default set of parameters and by the size and rngs arguments
cls.change_size
Returns an equivalent symbolic distribution with a different size. This is
analogous to `pymc.aesaraf.change_rv_size` for `RandomVariable`s.
"""

def __new__(
cls,
name: str,
Expand All @@ -379,36 +413,6 @@ def __new__(
"""Adds a TensorVariable corresponding to a PyMC symbolic distribution to the
current model.

While traditional PyMC distributions are represented by a single RandomVariable
graph, Symbolic distributions correspond to a larger graph that contains one or
more RandomVariables and an arbitrary number of deterministic operations, which
represent their own kind of distribution.

The graphs returned by symbolic distributions can be evaluated directly to
obtain valid draws and can further be parsed by Aeppl to derive the
corresponding logp at runtime.

Check pymc.distributions.Censored for an example of a symbolic distribution.

Symbolic distributions must implement the following classmethods:
cls.dist
Performs input validation and converts optional alternative parametrizations
to a canonical parametrization. It should call `super().dist()`, passing a
list with the default parameters as the first and only non keyword argument,
followed by other keyword arguments like size and rngs, and return the result
cls.rv_op
Returns a TensorVariable that represents the symbolic distribution
parametrized by a default set of parameters and a size and rngs arguments
cls.ndim_supp
Returns the support of the symbolic distribution, given the default
parameters. This may not always be constant, for instance if the symbolic
distribution can be defined based on an arbitrary base distribution.
cls.change_size
Returns an equivalent symbolic distribution with a different size. This is
analogous to `pymc.aesaraf.change_rv_size` for `RandomVariable`s.
cls.graph_rvs
Returns base RVs in a symbolic distribution.

Parameters
----------
cls : type
Expand Down Expand Up @@ -465,9 +469,9 @@ def __new__(
raise TypeError(f"Name needs to be a string but got: {name}")

if rngs is None:
# Create a temporary rv to obtain number of rngs needed
temp_graph = cls.dist(*args, rngs=None, **kwargs)
rngs = [model.next_rng() for _ in cls.graph_rvs(temp_graph)]
# Instead of passing individual RNG variables we could pass a RandomStream
# and let the classes create as many RNGs as they need
rngs = [model.next_rng() for _ in range(cls.num_rngs(*args, **kwargs))]
elif not isinstance(rngs, (list, tuple)):
rngs = [rngs]

Expand Down Expand Up @@ -523,7 +527,6 @@ def dist(
The inputs to the `RandomVariable` `Op`.
shape : int, tuple, Variable, optional
A tuple of sizes for each dimension of the new RV.

An Ellipsis (...) may be inserted in the last position to short-hand refer to
all the dimensions that the RV would get if no shape/size/dims were passed at all.
size : int, tuple, Variable, optional
Expand Down
25 changes: 12 additions & 13 deletions pymc/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,18 @@ def dist(cls, w, comp_dists, **kwargs):
w = at.as_tensor_variable(w)
return super().dist([w, *comp_dists], **kwargs)

@classmethod
def num_rngs(cls, w, comp_dists, **kwargs):
if not isinstance(comp_dists, (tuple, list)):
# comp_dists is a single component
comp_dists = [comp_dists]
return len(comp_dists) + 1

@classmethod
def ndim_supp(cls, weights, *components):
# We already checked that all components have the same support dimensionality
return components[0].owner.op.ndim_supp

@classmethod
def rv_op(cls, weights, *components, size=None, rngs=None):
# Update rngs if provided
Expand Down Expand Up @@ -329,11 +341,6 @@ def _resize_components(cls, size, *components):

return [change_rv_size(component, size) for component in components]

@classmethod
def ndim_supp(cls, weights, *components):
# We already checked that all components have the same support dimensionality
return components[0].owner.op.ndim_supp

@classmethod
def change_size(cls, rv, new_size, expand=False):
weights = rv.tag.weights
Expand All @@ -355,14 +362,6 @@ def change_size(cls, rv, new_size, expand=False):

return cls.rv_op(weights, *components, rngs=rngs, size=None)

@classmethod
def graph_rvs(cls, rv):
# We return rv, which is itself a pseudo RandomVariable, that contains a
# mix_indexes_ RV in its inner graph. We want super().dist() to generate
# (components + 1) rngs for us, and it will do so based on how many elements
# we return here
return (*rv.tag.components, rv)


@_get_measurable_outputs.register(MarginalMixtureRV)
def _get_measurable_outputs_MarginalMixtureRV(op, node):
Expand Down
Loading