-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Refactor DensityDist into v4 #5026
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,17 +13,18 @@ | |
# limitations under the License. | ||
import contextvars | ||
import functools | ||
import multiprocessing | ||
import sys | ||
import types | ||
import warnings | ||
|
||
from abc import ABCMeta | ||
from functools import singledispatch | ||
from typing import Optional | ||
from typing import Callable, Optional, Sequence | ||
|
||
import aesara | ||
import numpy as np | ||
|
||
from aesara.tensor.basic import as_tensor_variable | ||
from aesara.tensor.random.op import RandomVariable | ||
from aesara.tensor.random.var import RandomStateSharedVariable | ||
from aesara.tensor.var import TensorVariable | ||
|
@@ -41,12 +42,14 @@ | |
maybe_resize, | ||
resize_from_dims, | ||
resize_from_observed, | ||
to_tuple, | ||
) | ||
from pymc.printing import str_for_dist | ||
from pymc.util import UNSET | ||
from pymc.vartypes import string_types | ||
|
||
__all__ = [ | ||
"DensityDistRV", | ||
"DensityDist", | ||
"Distribution", | ||
"Continuous", | ||
|
@@ -387,96 +390,241 @@ class NoDistribution(Distribution): | |
""" | ||
|
||
|
||
class DensityDist(Distribution): | ||
"""Distribution based on a given log density function. | ||
class DensityDistRV(RandomVariable): | ||
""" | ||
Base class for DensityDistRV | ||
|
||
This should be subclassed when defining custom DensityDist objects. | ||
""" | ||
|
||
name = "DensityDistRV" | ||
_print_name = ("DensityDist", "\\operatorname{DensityDist}") | ||
|
||
@classmethod | ||
def rng_fn(cls, rng, *args): | ||
args = list(args) | ||
size = args.pop(-1) | ||
return cls._random_fn(*args, rng=rng, size=size) | ||
|
||
A distribution with the passed log density function is created. | ||
Requires a custom random function passed as kwarg `random` to | ||
enable prior or posterior predictive sampling. | ||
|
||
class DensityDist(NoDistribution): | ||
"""A distribution that can be used to wrap black-box log density functions. | ||
|
||
Creates a Distribution and registers the supplied log density function to be used | ||
for inference. It is also possible to supply a `random` method in order to be able | ||
to sample from the prior or posterior predictive distributions. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
logp, | ||
shape=(), | ||
dtype=None, | ||
initval=0, | ||
random=None, | ||
wrap_random_with_dist_shape=True, | ||
check_shape_in_random=True, | ||
*args, | ||
def __new__( | ||
cls, | ||
name: str, | ||
*dist_params, | ||
logp: Optional[Callable] = None, | ||
logcdf: Optional[Callable] = None, | ||
random: Optional[Callable] = None, | ||
get_moment: Optional[Callable] = None, | ||
ndim_supp: int = 0, | ||
ndims_params: Optional[Sequence[int]] = None, | ||
dtype: str = "floatX", | ||
**kwargs, | ||
): | ||
""" | ||
Parameters | ||
---------- | ||
|
||
logp: callable | ||
A callable that has the following signature ``logp(value)`` and | ||
returns an Aesara tensor that represents the distribution's log | ||
probability density. | ||
shape: tuple (Optional): defaults to `()` | ||
The shape of the distribution. The default value indicates a scalar. | ||
If the distribution is *not* scalar-valued, the programmer should pass | ||
a value here. | ||
dtype: None, str (Optional) | ||
The dtype of the distribution. | ||
initval: number or array (Optional) | ||
The ``initval`` of the RV's tensor that follow the ``DensityDist`` | ||
distribution. | ||
args, kwargs: (Optional) | ||
These are passed to the parent class' ``__init__``. | ||
name : str | ||
dist_params : Tuple | ||
A sequence of the distribution's parameter. These will be converted into | ||
aesara tensors internally. These parameters could be other ``RandomVariable`` | ||
michaelosthege marked this conversation as resolved.
Show resolved
Hide resolved
|
||
instances. | ||
logp : Optional[Callable] | ||
A callable that calculates the log density of some given observed ``value`` | ||
conditioned on certain distribution parameter values. It must have the | ||
following signature: ``logp(value, *dist_params)``, where ``value`` is | ||
an Aesara tensor that represents the observed value, and ``dist_params`` | ||
are the tensors that hold the values of the distribution parameters. | ||
This function must return an Aesara tensor. If ``None``, a ``NotImplemented`` | ||
error will be raised when trying to compute the distribution's logp. | ||
logcdf : Optional[Callable] | ||
A callable that calculates the log cummulative probability of some given observed | ||
``value`` conditioned on certain distribution parameter values. It must have the | ||
following signature: ``logcdf(value, *dist_params)``, where ``value`` is | ||
an Aesara tensor that represents the observed value, and ``dist_params`` | ||
are the tensors that hold the values of the distribution parameters. | ||
This function must return an Aesara tensor. If ``None``, a ``NotImplemented`` | ||
error will be raised when trying to compute the distribution's logcdf. | ||
random : Optional[Callable] | ||
A callable that can be used to generate random draws from the distribution. | ||
It must have the following signature: ``random(*dist_params, rng=None, size=None)``. | ||
The distribution parameters are passed as positional arguments in the | ||
same order as they are supplied when the ``DensityDist`` is constructed. | ||
The keyword arguments are ``rnd``, which will provide the random variable's | ||
associated :py:class:`~numpy.random.Generator`, and ``size``, that will represent | ||
the desired size of the random draw. If ``None``, a ``NotImplemented`` | ||
error will be raised when trying to draw random samples from the distribution's | ||
prior or posterior predictive. | ||
get_moment : Optional[Callable] | ||
A callable that can be used to compute the moments of the distribution. | ||
It must have the following signature: ``get_moment(rv, size, *rv_inputs)``. | ||
The distribution's :class:`~aesara.tensor.random.op.RandomVariable` is passed | ||
as the first argument ``rv``. ``size`` is the random variable's size implied | ||
by the ``dims``, ``size`` and parameters supplied to the distribution. Finally, | ||
``rv_inputs`` is the sequence of the distribution parameters, in the same order | ||
as they were supplied when the DensityDist was created. If ``None``, a | ||
``NotImplemented`` error will be raised when trying to draw random samples from | ||
the distribution's prior or posterior predictive. | ||
ndim_supp : int | ||
The number of dimensions in the support of the distribution. Defaults to assuming | ||
a scalar distribution, i.e. ``ndim_supp = 0``. | ||
ndims_params : Optional[Sequence[int]] | ||
The list of number of dimensions in the support of each of the distribution's | ||
parameters. If ``None``, it is assumed that all parameters are scalars, hence | ||
the number of dimensions of their support will be 0. | ||
dtype : str | ||
The dtype of the distribution. All draws and observations passed into the distribution | ||
will be casted onto this dtype. | ||
kwargs : | ||
Extra keyword arguments are passed to the parent's class ``__new__`` method. | ||
|
||
Examples | ||
-------- | ||
.. code-block:: python | ||
|
||
def logp(value, mu): | ||
return -(value - mu)**2 | ||
|
||
with pm.Model(): | ||
mu = pm.Normal('mu',0,1) | ||
normal_dist = pm.Normal.dist(mu, 1) | ||
pm.DensityDist( | ||
'density_dist', | ||
normal_dist.logp, | ||
mu, | ||
logp=logp, | ||
observed=np.random.randn(100), | ||
) | ||
idata = pm.sample(100) | ||
|
||
.. code-block:: python | ||
|
||
def logp(value, mu): | ||
return -(value - mu)**2 | ||
|
||
def random(mu, rng=None, size=None): | ||
return rng.normal(loc=mu, scale=1, size=size) | ||
|
||
with pm.Model(): | ||
mu = pm.Normal('mu', 0 , 1) | ||
normal_dist = pm.Normal.dist(mu, 1, shape=3) | ||
dens = pm.DensityDist( | ||
'density_dist', | ||
normal_dist.logp, | ||
mu, | ||
logp=logp, | ||
random=random, | ||
observed=np.random.randn(100, 3), | ||
shape=3, | ||
size=(100, 3), | ||
) | ||
prior = pm.sample_prior_predictive(10)['density_dist'] | ||
assert prior.shape == (10, 100, 3) | ||
|
||
""" | ||
if dtype is None: | ||
|
||
if dist_params is None: | ||
dist_params = [] | ||
elif len(dist_params) > 0 and callable(dist_params[0]): | ||
raise TypeError( | ||
"The DensityDist API has changed, you are using the old API " | ||
"where logp was the first positional argument. In the current API, " | ||
"the logp is a keyword argument, amongst other changes. Please refer " | ||
"to the API documentation for more information on how to use the " | ||
"new DensityDist API." | ||
) | ||
dist_params = [as_tensor_variable(param) for param in dist_params] | ||
|
||
# Assume scalar ndims_params | ||
if ndims_params is None: | ||
ndims_params = [0] * len(dist_params) | ||
|
||
if logp is None: | ||
logp = default_not_implemented(name, "logp") | ||
|
||
if logcdf is None: | ||
logcdf = default_not_implemented(name, "logcdf") | ||
|
||
if random is None: | ||
random = default_not_implemented(name, "random") | ||
|
||
if get_moment is None: | ||
get_moment = default_not_implemented(name, "get_moment") | ||
|
||
rv_op = type( | ||
f"DensityDist_{name}", | ||
(DensityDistRV,), | ||
dict( | ||
name=f"DensityDist_{name}", | ||
inplace=False, | ||
ndim_supp=ndim_supp, | ||
ndims_params=ndims_params, | ||
dtype=dtype, | ||
# Specifc to DensityDist | ||
_random_fn=random, | ||
), | ||
)() | ||
|
||
# Register custom logp | ||
rv_type = type(rv_op) | ||
|
||
@_logp.register(rv_type) | ||
def density_dist_logp(op, rv, rvs_to_values, *dist_params, **kwargs): | ||
value_var = rvs_to_values.get(rv, rv) | ||
return logp( | ||
value_var, | ||
*dist_params, | ||
) | ||
|
||
@_logcdf.register(rv_type) | ||
def density_dist_logcdf(op, var, rvs_to_values, *dist_params, **kwargs): | ||
value_var = rvs_to_values.get(var, var) | ||
return logcdf(value_var, *dist_params, **kwargs) | ||
|
||
@_get_moment.register(rv_type) | ||
def density_dist_get_moment(op, rv, size, *rv_inputs): | ||
return get_moment(rv, size, *rv_inputs) | ||
|
||
cls.rv_op = rv_op | ||
return super().__new__(cls, name, *dist_params, **kwargs) | ||
|
||
@classmethod | ||
def dist(cls, *args, **kwargs): | ||
output = super().dist(args, **kwargs) | ||
if cls.rv_op.dtype == "floatX": | ||
dtype = aesara.config.floatX | ||
super().__init__(shape, dtype, initval, *args, **kwargs) | ||
self.logp = logp | ||
if type(self.logp) == types.MethodType: | ||
if PLATFORM != "linux": | ||
warnings.warn( | ||
"You are passing a bound method as logp for DensityDist, this can lead to " | ||
"errors when sampling on platforms other than Linux. Consider using a " | ||
"plain function instead, or subclass Distribution." | ||
) | ||
elif type(multiprocessing.get_context()) != multiprocessing.context.ForkContext: | ||
warnings.warn( | ||
"You are passing a bound method as logp for DensityDist, this can lead to " | ||
"errors when sampling when multiprocessing cannot rely on forking. Consider using a " | ||
"plain function instead, or subclass Distribution." | ||
) | ||
self.rand = random | ||
self.wrap_random_with_dist_shape = wrap_random_with_dist_shape | ||
self.check_shape_in_random = check_shape_in_random | ||
else: | ||
dtype = cls.rv_op.dtype | ||
ndim_supp = cls.rv_op.ndim_supp | ||
if not hasattr(output.tag, "test_value"): | ||
size = to_tuple(kwargs.get("size", None)) + (1,) * ndim_supp | ||
output.tag.test_value = np.zeros(size, dtype) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new initval framework will no longer fall back to the test_value. Furthermore this assignment doesn't have a lot of dimensionality flexibility, for example with symbolic size. My recommendation is to take out the assignment and also the hack below. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, but the new initial value framework has not been finished, so I’ve used the same hack that Flat and HalfFlat use to get their initial values. I don’t think that this PR should have to wait for the new initial values framework to merge it. |
||
return output | ||
|
||
|
||
def default_not_implemented(rv_name, method_name): | ||
if method_name == "random": | ||
# This is a hack to catch the NotImplementedError when creating the RV without random | ||
# If the message starts with "Cannot sample from", then it uses the test_value as | ||
# the initial_val. | ||
message = ( | ||
f"Cannot sample from the DensityDist '{rv_name}' because the {method_name} " | ||
"keyword argument was not provided when the distribution was " | ||
f"but this method had not been provided when the distribution was " | ||
f"constructed. Please re-build your model and provide a callable " | ||
f"to '{rv_name}'s {method_name} keyword argument.\n" | ||
) | ||
else: | ||
message = ( | ||
f"Attempted to run {method_name} on the DensityDist '{rv_name}', " | ||
f"but this method had not been provided when the distribution was " | ||
f"constructed. Please re-build your model and provide a callable " | ||
f"to '{rv_name}'s {method_name} keyword argument.\n" | ||
) | ||
|
||
def func(*args, **kwargs): | ||
raise NotImplementedError(message) | ||
|
||
def _distr_parameters_for_repr(self): | ||
return [] | ||
return func |
Uh oh!
There was an error while loading. Please reload this page.