Commit 8a6b9b3

Refactor DensityDist into v4
1 parent 37ba9a3 commit 8a6b9b3

12 files changed

+364
-196
lines changed

RELEASE-NOTES.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
- `pm.sample` now returns results as `InferenceData` instead of `MultiTrace` by default (see [#4744](https://github.com/pymc-devs/pymc/pull/4744)).
1111
- `pm.sample_prior_predictive` no longer returns transformed variable values by default. Pass them by name in `var_names` if you want to obtain these draws (see [4769](https://github.com/pymc-devs/pymc/pull/4769)).
1212
- `pm.Bound` interface no longer accepts a callable class as argument, instead it requires an instantiated distribution (created via the `.dist()` API) to be passed as an argument. In addition, Bound no longer returns a class instance but works as a normal PyMC distribution. Finally, it is no longer possible to do predictive random sampling from Bounded variables. Please, consult the new documentation for details on how to use Bounded variables (see [4815](https://github.com/pymc-devs/pymc/pull/4815)).
13+
- `pm.DensityDist` no longer accepts the `logp` as its first positional argument. It is now an optional keyword argument. If you pass a callable as the first positional argument, a `TypeError` will be raised (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
14+
- `pm.DensityDist` now accepts distribution parameters as positional arguments. Passing them as a dictionary in the `observed` keyword argument is no longer supported and will raise an error (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
15+
- The signature of the `logp` and `random` functions that can be passed into a `pm.DensityDist` has been changed (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
1316
- ...
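
The `TypeError` behaviour described in the `pm.DensityDist` entries above can be sketched without PyMC installed. This is a minimal illustration only: `check_dist_params` is a hypothetical stand-in for the guard at the top of `DensityDist.__new__`, and its error message is shortened.

```python
def check_dist_params(*dist_params):
    """Hypothetical stand-in for the guard at the top of DensityDist.__new__."""
    if len(dist_params) > 0 and callable(dist_params[0]):
        raise TypeError("logp must now be passed as a keyword argument")
    return list(dist_params)


def logp(value, mu):
    return -((value - mu) ** 2)


# New-style call shape: parameters are positional, logp would be a keyword.
accepted = check_dist_params(0.0, 1.0)

# Old-style call shape: a callable in first position is rejected.
try:
    check_dist_params(logp)
    rejected = False
except TypeError:
    rejected = True
```

Code written against the old API fails at model-construction time rather than silently treating the callable as a distribution parameter.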
1417

1518
### New Features
@@ -25,6 +28,8 @@
2528
- The `Polya-Gamma` distribution has been added (see [#4531](https://github.com/pymc-devs/pymc/pull/4531)). To make use of this distribution, the [`polyagamma>=1.3.1`](https://pypi.org/project/polyagamma/) library must be installed and available in the user's environment.
2629
- A small change to the mass matrix tuning methods jitter+adapt_diag (the default) and adapt_diag improves performance early on during tuning for some models. [#5004](https://github.com/pymc-devs/pymc/pull/5004)
2730
- New experimental mass matrix tuning method jitter+adapt_diag_grad. [#5004](https://github.com/pymc-devs/pymc/pull/5004)
31+
- `pm.DensityDist` can now accept an optional `logcdf` keyword argument to pass in a function to compute the log of the cumulative distribution function of the distribution (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
32+
- `pm.DensityDist` can now accept an optional `get_moment` keyword argument to pass in a function to compute the moment of the distribution (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
2833
- ...
2934

3035
### Maintenance

docs/source/Probability_Distributions.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,16 @@ An exponential survival function, where :math:`c=0` denotes failure (or non-surv
5858
f(c, t) = \left\{ \begin{array}{l} \exp(-\lambda t), \text{if c=1} \\
5959
\lambda \exp(-\lambda t), \text{if c=0} \end{array} \right.
6060
61-
Such a function can be implemented as a PyMC distribution by writing a function that specifies the log-probability, then passing that function as an argument to the ``DensityDist`` function, which creates an instance of a PyMC distribution with the custom function as its log-probability.
61+
Such a function can be implemented as a PyMC distribution by writing a function that specifies the log-probability, then passing that function as a keyword argument to the ``DensityDist`` function, which creates an instance of a PyMC distribution with the custom function as its log-probability.
6262

6363
For the exponential survival function, this is:
6464

6565
::
6666

67-
def logp(failure, value):
68-
return (failure * log(λ) - λ * value).sum()
67+
def logp(value, t, λ):
68+
return (value * log(λ) - λ * t).sum()
6969

70-
exp_surv = pm.DensityDist('exp_surv', logp, observed={'failure':failure, 'value':t})
70+
exp_surv = pm.DensityDist('exp_surv', t, λ, logp=logp, observed=failure)
7171

7272
Similarly, if a random number generator is required, a function returning random numbers corresponding to the probability distribution can be passed as the ``random`` argument.
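
As a sanity check on the new ``logp(value, t, λ)`` signature, the same quantity can be evaluated in plain Python. This is a sketch under stated assumptions: ``exp_surv_logp`` and the sample data are hypothetical, the indicator is used exactly as in the ``logp`` snippet above (an entry of 1 contributes ``log(λ)``), and plain floats stand in for Aesara tensors.

```python
import math


def exp_surv_logp(failure, t, lam):
    """Sum of failure_i * log(lam) - lam * t_i over all observations."""
    return sum(f * math.log(lam) - lam * ti for f, ti in zip(failure, t))


# Hypothetical data: three subjects, two of them with indicator 1.
failure = [1, 0, 1]
t = [2.0, 5.0, 1.0]
lam = 0.5

total = exp_surv_logp(failure, t, lam)

# Closed form for this data: 2 * log(0.5) - 0.5 * (2 + 5 + 1)
expected = 2 * math.log(lam) - lam * sum(t)
```

The observed quantity comes first and the distribution parameters follow in the order they were passed to ``DensityDist``, which is the ordering the new API requires.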
7373

pymc/distributions/distribution.py

Lines changed: 207 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,18 @@
1313
# limitations under the License.
1414
import contextvars
1515
import functools
16-
import multiprocessing
1716
import sys
1817
import types
1918
import warnings
2019

2120
from abc import ABCMeta
2221
from functools import singledispatch
23-
from typing import Optional
22+
from typing import Callable, Optional, Sequence
2423

2524
import aesara
25+
import numpy as np
2626

27+
from aesara.tensor.basic import as_tensor_variable
2728
from aesara.tensor.random.op import RandomVariable
2829
from aesara.tensor.random.var import RandomStateSharedVariable
2930
from aesara.tensor.var import TensorVariable
@@ -41,12 +42,14 @@
4142
maybe_resize,
4243
resize_from_dims,
4344
resize_from_observed,
45+
to_tuple,
4446
)
4547
from pymc.printing import str_for_dist
4648
from pymc.util import UNSET
4749
from pymc.vartypes import string_types
4850

4951
__all__ = [
52+
"DensityDistRV",
5053
"DensityDist",
5154
"Distribution",
5255
"Continuous",
@@ -387,96 +390,241 @@ class NoDistribution(Distribution):
387390
"""
388391

389392

390-
class DensityDist(Distribution):
391-
"""Distribution based on a given log density function.
393+
class DensityDistRV(RandomVariable):
394+
"""
395+
Base class for DensityDistRV
396+
397+
This should be subclassed when defining custom DensityDist objects.
398+
"""
399+
400+
name = "DensityDistRV"
401+
_print_name = ("DensityDist", "\\operatorname{DensityDist}")
402+
403+
@classmethod
404+
def rng_fn(cls, rng, *args):
405+
args = list(args)
406+
size = args.pop(-1)
407+
return cls._random_fn(*args, rng=rng, size=size)
392408

393-
A distribution with the passed log density function is created.
394-
Requires a custom random function passed as kwarg `random` to
395-
enable prior or posterior predictive sampling.
396409

410+
class DensityDist(NoDistribution):
411+
"""A distribution that can be used to wrap black-box log density functions.
412+
413+
Creates a Distribution and registers the supplied log density function to be used
414+
for inference. It is also possible to supply a `random` method in order to be able
415+
to sample from the prior or posterior predictive distributions.
397416
"""
398417

399-
def __init__(
400-
self,
401-
logp,
402-
shape=(),
403-
dtype=None,
404-
initval=0,
405-
random=None,
406-
wrap_random_with_dist_shape=True,
407-
check_shape_in_random=True,
408-
*args,
418+
def __new__(
419+
cls,
420+
name: str,
421+
*dist_params,
422+
logp: Optional[Callable] = None,
423+
logcdf: Optional[Callable] = None,
424+
random: Optional[Callable] = None,
425+
get_moment: Optional[Callable] = None,
426+
ndim_supp: int = 0,
427+
ndims_params: Optional[Sequence[int]] = None,
428+
dtype: str = "floatX",
409429
**kwargs,
410430
):
411431
"""
412432
Parameters
413433
----------
414-
415-
logp: callable
416-
A callable that has the following signature ``logp(value)`` and
417-
returns an Aesara tensor that represents the distribution's log
418-
probability density.
419-
shape: tuple (Optional): defaults to `()`
420-
The shape of the distribution. The default value indicates a scalar.
421-
If the distribution is *not* scalar-valued, the programmer should pass
422-
a value here.
423-
dtype: None, str (Optional)
424-
The dtype of the distribution.
425-
initval: number or array (Optional)
426-
The ``initval`` of the RV's tensor that follow the ``DensityDist``
427-
distribution.
428-
args, kwargs: (Optional)
429-
These are passed to the parent class' ``__init__``.
434+
name : str
435+
dist_params : Tuple
436+
A sequence of the distribution's parameters. These will be converted into
437+
aesara tensors internally. These parameters could be other ``RandomVariable``
438+
instances.
439+
logp : Optional[Callable]
440+
A callable that calculates the log density of some given observed ``value``
441+
conditioned on certain distribution parameter values. It must have the
442+
following signature: ``logp(value, *dist_params)``, where ``value`` is
443+
an Aesara tensor that represents the observed value, and ``dist_params``
444+
are the tensors that hold the values of the distribution parameters.
445+
This function must return an Aesara tensor. If ``None``, a ``NotImplemented``
446+
error will be raised when trying to compute the distribution's logp.
447+
logcdf : Optional[Callable]
448+
A callable that calculates the log cumulative probability of some given observed
449+
``value`` conditioned on certain distribution parameter values. It must have the
450+
following signature: ``logcdf(value, *dist_params)``, where ``value`` is
451+
an Aesara tensor that represents the observed value, and ``dist_params``
452+
are the tensors that hold the values of the distribution parameters.
453+
This function must return an Aesara tensor. If ``None``, a ``NotImplemented``
454+
error will be raised when trying to compute the distribution's logcdf.
455+
random : Optional[Callable]
456+
A callable that can be used to generate random draws from the distribution.
457+
It must have the following signature: ``random(*dist_params, rng=None, size=None)``.
458+
The distribution parameters are passed as positional arguments in the
459+
same order as they are supplied when the ``DensityDist`` is constructed.
460+
The keyword arguments are ``rng``, which will provide the random variable's
461+
associated :py:class:`~numpy.random.Generator`, and ``size``, that will represent
462+
the desired size of the random draw. If ``None``, a ``NotImplemented``
463+
error will be raised when trying to draw random samples from the distribution's
464+
prior or posterior predictive.
465+
get_moment : Optional[Callable]
466+
A callable that can be used to compute the moments of the distribution.
467+
It must have the following signature: ``get_moment(rv, size, *rv_inputs)``.
468+
The distribution's :class:`~aesara.tensor.random.op.RandomVariable` is passed
469+
as the first argument ``rv``. ``size`` is the random variable's size implied
470+
by the ``dims``, ``size`` and parameters supplied to the distribution. Finally,
471+
``rv_inputs`` is the sequence of the distribution parameters, in the same order
472+
as they were supplied when the DensityDist was created. If ``None``, a
473+
``NotImplemented`` error will be raised when trying to compute the moment of the
474+
distribution.
475+
ndim_supp : int
476+
The number of dimensions in the support of the distribution. Defaults to assuming
477+
a scalar distribution, i.e. ``ndim_supp = 0``.
478+
ndims_params : Optional[Sequence[int]]
479+
The list of number of dimensions in the support of each of the distribution's
480+
parameters. If ``None``, it is assumed that all parameters are scalars, hence
481+
the number of dimensions of their support will be 0.
482+
dtype : str
483+
The dtype of the distribution. All draws and observations passed into the distribution
484+
will be cast to this dtype.
485+
kwargs :
486+
Extra keyword arguments are passed to the parent's class ``__new__`` method.
430487
431488
Examples
432489
--------
433490
.. code-block:: python
434491
492+
def logp(value, mu):
493+
return -(value - mu)**2
494+
435495
with pm.Model():
436496
mu = pm.Normal('mu',0,1)
437-
normal_dist = pm.Normal.dist(mu, 1)
438497
pm.DensityDist(
439498
'density_dist',
440-
normal_dist.logp,
499+
mu,
500+
logp=logp,
441501
observed=np.random.randn(100),
442502
)
443503
idata = pm.sample(100)
444504
445505
.. code-block:: python
446506
507+
def logp(value, mu):
508+
return -(value - mu)**2
509+
510+
def random(mu, rng=None, size=None):
511+
return rng.normal(loc=mu, scale=1, size=size)
512+
447513
with pm.Model():
448514
mu = pm.Normal('mu', 0 , 1)
449-
normal_dist = pm.Normal.dist(mu, 1, shape=3)
450515
dens = pm.DensityDist(
451516
'density_dist',
452-
normal_dist.logp,
517+
mu,
518+
logp=logp,
519+
random=random,
453520
observed=np.random.randn(100, 3),
454-
shape=3,
521+
size=(100, 3),
455522
)
456523
prior = pm.sample_prior_predictive(10)['density_dist']
457524
assert prior.shape == (10, 100, 3)
458525
459526
"""
460-
if dtype is None:
527+
528+
if dist_params is None:
529+
dist_params = []
530+
elif len(dist_params) > 0 and callable(dist_params[0]):
531+
raise TypeError(
532+
"The DensityDist API has changed, you are using the old API "
533+
"where logp was the first positional argument. In the current API, "
534+
"the logp is a keyword argument, amongst other changes. Please refer "
535+
"to the API documentation for more information on how to use the "
536+
"new DensityDist API."
537+
)
538+
dist_params = [as_tensor_variable(param) for param in dist_params]
539+
540+
# Assume scalar ndims_params
541+
if ndims_params is None:
542+
ndims_params = [0] * len(dist_params)
543+
544+
if logp is None:
545+
logp = default_not_implemented(name, "logp")
546+
547+
if logcdf is None:
548+
logcdf = default_not_implemented(name, "logcdf")
549+
550+
if random is None:
551+
random = default_not_implemented(name, "random")
552+
553+
if get_moment is None:
554+
get_moment = default_not_implemented(name, "get_moment")
555+
556+
rv_op = type(
557+
f"DensityDist_{name}",
558+
(DensityDistRV,),
559+
dict(
560+
name=f"DensityDist_{name}",
561+
inplace=False,
562+
ndim_supp=ndim_supp,
563+
ndims_params=ndims_params,
564+
dtype=dtype,
565+
# Specific to DensityDist
566+
_random_fn=random,
567+
),
568+
)()
569+
570+
# Register custom logp
571+
rv_type = type(rv_op)
572+
573+
@_logp.register(rv_type)
574+
def density_dist_logp(op, rv, rvs_to_values, *dist_params, **kwargs):
575+
value_var = rvs_to_values.get(rv, rv)
576+
return logp(
577+
value_var,
578+
*dist_params,
579+
)
580+
581+
@_logcdf.register(rv_type)
582+
def density_dist_logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
583+
value_var = rvs_to_values.get(var, var)
584+
return logcdf(value_var, *dist_params, **kwargs)
585+
586+
@_get_moment.register(rv_type)
587+
def density_dist_get_moment(op, rv, size, *rv_inputs):
588+
return get_moment(rv, size, *rv_inputs)
589+
590+
cls.rv_op = rv_op
591+
return super().__new__(cls, name, *dist_params, **kwargs)
592+
593+
@classmethod
594+
def dist(cls, *args, **kwargs):
595+
output = super().dist(args, **kwargs)
596+
if cls.rv_op.dtype == "floatX":
461597
dtype = aesara.config.floatX
462-
super().__init__(shape, dtype, initval, *args, **kwargs)
463-
self.logp = logp
464-
if type(self.logp) == types.MethodType:
465-
if PLATFORM != "linux":
466-
warnings.warn(
467-
"You are passing a bound method as logp for DensityDist, this can lead to "
468-
"errors when sampling on platforms other than Linux. Consider using a "
469-
"plain function instead, or subclass Distribution."
470-
)
471-
elif type(multiprocessing.get_context()) != multiprocessing.context.ForkContext:
472-
warnings.warn(
473-
"You are passing a bound method as logp for DensityDist, this can lead to "
474-
"errors when sampling when multiprocessing cannot rely on forking. Consider using a "
475-
"plain function instead, or subclass Distribution."
476-
)
477-
self.rand = random
478-
self.wrap_random_with_dist_shape = wrap_random_with_dist_shape
479-
self.check_shape_in_random = check_shape_in_random
598+
else:
599+
dtype = cls.rv_op.dtype
600+
ndim_supp = cls.rv_op.ndim_supp
601+
if not hasattr(output.tag, "test_value"):
602+
size = to_tuple(kwargs.get("size", None)) + (1,) * ndim_supp
603+
output.tag.test_value = np.zeros(size, dtype)
604+
return output
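
The pattern used in `__new__` above (build a one-off `RandomVariable` subclass with `type(...)`, instantiate it, then register the user's callables against that new type) can be sketched with the standard library alone. This is an illustration, not the PyMC implementation: `BaseOp`, `logp_dispatch` and `make_op` are hypothetical names standing in for `DensityDistRV`, `_logp` and the body of `__new__`.

```python
from functools import singledispatch


class BaseOp:
    """Hypothetical stand-in for DensityDistRV."""

    name = "BaseOp"


@singledispatch
def logp_dispatch(op, value, *params):
    raise NotImplementedError(f"no logp registered for {type(op).__name__}")


def make_op(name, logp):
    # Build a one-off subclass, instantiate it, and register logp for that subclass only.
    op = type(f"DensityDist_{name}", (BaseOp,), {"name": f"DensityDist_{name}"})()

    @logp_dispatch.register(type(op))
    def _(op, value, *params):
        return logp(value, *params)

    return op


op_a = make_op("a", lambda value, mu: -((value - mu) ** 2))
op_b = make_op("b", lambda value, mu: -abs(value - mu))

result_a = logp_dispatch(op_a, 3.0, 1.0)  # dispatches to the squared-error logp
result_b = logp_dispatch(op_b, 3.0, 1.0)  # dispatches to the absolute-error logp
```

Because every call creates a fresh type, two distributions with different `logp` functions do not overwrite each other's registrations, which is presumably why the commit generates a new subclass per `DensityDist` rather than registering on `DensityDistRV` directly.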
605+
606+
607+
def default_not_implemented(rv_name, method_name):
608+
if method_name == "random":
609+
# This is a hack to catch the NotImplementedError when creating the RV without random
610+
# If the message starts with "Cannot sample from", then it uses the test_value as
611+
# the initial_val.
612+
message = (
613+
f"Cannot sample from the DensityDist '{rv_name}' because the {method_name} "
614+
"keyword argument was not provided when the distribution was "
616+
f"constructed. Please re-build your model and provide a callable "
617+
f"to '{rv_name}'s {method_name} keyword argument.\n"
618+
)
619+
else:
620+
message = (
621+
f"Attempted to run {method_name} on the DensityDist '{rv_name}', "
622+
f"but this method had not been provided when the distribution was "
623+
f"constructed. Please re-build your model and provide a callable "
624+
f"to '{rv_name}'s {method_name} keyword argument.\n"
625+
)
626+
627+
def func(*args, **kwargs):
628+
raise NotImplementedError(message)
480629

481-
def _distr_parameters_for_repr(self):
482-
return []
630+
return func
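
The effect of `default_not_implemented` is that a missing callable only fails when it is actually used, not when the distribution is built. A minimal standalone sketch, with the messages shortened for illustration (the wording below is not the wording in the commit):

```python
def default_not_implemented(rv_name, method_name):
    """Simplified copy of the helper above, with shortened messages."""
    if method_name == "random":
        message = (
            f"Cannot sample from the DensityDist '{rv_name}': "
            f"no {method_name} callable was provided."
        )
    else:
        message = (
            f"Attempted to run {method_name} on the DensityDist '{rv_name}', "
            "but no callable was provided."
        )

    def func(*args, **kwargs):
        raise NotImplementedError(message)

    return func


# Creating the placeholder never raises.
placeholder = default_not_implemented("density_dist", "random")

# Calling it does, whatever arguments it receives.
try:
    placeholder(0.0, rng=None, size=None)
    error_text = ""
except NotImplementedError as err:
    error_text = str(err)
```

The "Cannot sample from" prefix matters because, as the comment in the helper notes, that prefix is used to detect this case and fall back to the test value as the initial value.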
