
Commit 39eceb7

lucianopaztwiecki authored and committed
Refactor DensityDist into v4
1 parent 00e6eb9 commit 39eceb7

12 files changed, +364 -196 lines changed

RELEASE-NOTES.md

Lines changed: 5 additions & 0 deletions
@@ -10,6 +10,9 @@
1010
- `pm.sample` now returns results as `InferenceData` instead of `MultiTrace` by default (see [#4744](https://github.com/pymc-devs/pymc/pull/4744)).
1111
- `pm.sample_prior_predictive` no longer returns transformed variable values by default. Pass them by name in `var_names` if you want to obtain these draws (see [4769](https://github.com/pymc-devs/pymc/pull/4769)).
1212
- `pm.Bound` interface no longer accepts a callable class as argument, instead it requires an instantiated distribution (created via the `.dist()` API) to be passed as an argument. In addition, Bound no longer returns a class instance but works as a normal PyMC distribution. Finally, it is no longer possible to do predictive random sampling from Bounded variables. Please, consult the new documentation for details on how to use Bounded variables (see [4815](https://github.com/pymc-devs/pymc/pull/4815)).
13+
- `pm.DensityDist` no longer accepts the `logp` as its first positional argument. It is now an optional keyword argument. If you pass a callable as the first positional argument, a `TypeError` will be raised (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
14+
- `pm.DensityDist` now accepts distribution parameters as positional arguments. Passing them as a dictionary in the `observed` keyword argument is no longer supported and will raise an error (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
15+
- The signature of the `logp` and `random` functions that can be passed into a `pm.DensityDist` has been changed (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
1316
- ...
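
The change to the first positional argument can be illustrated with a minimal sketch that does not need PyMC. `validate_dist_params` is a hypothetical helper, not part of the library; it only mirrors the guard described in the note above, assuming the check is simply whether the first positional parameter is callable.

```python
def validate_dist_params(*dist_params):
    """Hypothetical stand-in for the guard: reject a callable first parameter."""
    if len(dist_params) > 0 and callable(dist_params[0]):
        raise TypeError(
            "logp must be passed as a keyword argument, "
            "not as the first positional argument."
        )
    return list(dist_params)


def logp(value, mu):
    # Unnormalized log density, used here only as a callable to pass around.
    return -((value - mu) ** 2)


# New-style call: distribution parameters are positional.
params = validate_dist_params(0.0, 1.0)

# Old-style call: a callable in first position is rejected.
try:
    validate_dist_params(logp, 0.0)
    old_style_rejected = False
except TypeError:
    old_style_rejected = True
```

Under this sketch, code written against the old API fails immediately at construction time rather than silently treating the `logp` function as a distribution parameter.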
1417

1518
### New Features
@@ -25,6 +28,8 @@
2528
- The `Polya-Gamma` distribution has been added (see [#4531](https://github.com/pymc-devs/pymc/pull/4531)). To make use of this distribution, the [`polyagamma>=1.3.1`](https://pypi.org/project/polyagamma/) library must be installed and available in the user's environment.
2629
- A small change to the mass matrix tuning methods jitter+adapt_diag (the default) and adapt_diag improves performance early on during tuning for some models. [#5004](https://github.com/pymc-devs/pymc/pull/5004)
2730
- New experimental mass matrix tuning method jitter+adapt_diag_grad. [#5004](https://github.com/pymc-devs/pymc/pull/5004)
31+
- `pm.DensityDist` can now accept an optional `logcdf` keyword argument to pass in a function to compute the log of the cumulative distribution function of the distribution (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
32+
- `pm.DensityDist` can now accept an optional `get_moment` keyword argument to pass in a function to compute the moment of the distribution (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
2833
- ...
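
As an illustration of the kind of function the `logcdf` keyword expects, here is a plain-Python sketch of a log-CDF for an exponential distribution, using the `logcdf(value, *dist_params)` argument order documented in this commit. In a real model the function would have to operate on and return Aesara tensors; the `math`-based version and the name `exponential_logcdf` are assumptions made only so the example runs without PyMC.

```python
import math


def exponential_logcdf(value, lam):
    """Return log(1 - exp(-lam * value)), or -inf where the CDF is zero."""
    if value <= 0:
        return -math.inf
    # -expm1(-x) computes 1 - exp(-x) accurately for small x.
    return math.log(-math.expm1(-lam * value))


lam = 2.0
median = math.log(2) / lam  # the CDF equals 0.5 at the median
logcdf_at_median = exponential_logcdf(median, lam)
logcdf_below_support = exponential_logcdf(-1.0, lam)
```

The value comes first and the distribution parameters follow in the order they were given to the distribution, which is the same convention the new `logp` signature uses.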
2934

3035
### Maintenance

docs/source/Probability_Distributions.rst

Lines changed: 4 additions & 4 deletions
@@ -58,16 +58,16 @@ An exponential survival function, where :math:`c=0` denotes failure (or non-surv
5858
f(c, t) = \left\{ \begin{array}{l} \exp(-\lambda t), \text{if c=1} \\
5959
\lambda \exp(-\lambda t), \text{if c=0} \end{array} \right.
6060
61-
Such a function can be implemented as a PyMC distribution by writing a function that specifies the log-probability, then passing that function as an argument to the ``DensityDist`` function, which creates an instance of a PyMC distribution with the custom function as its log-probability.
61+
Such a function can be implemented as a PyMC distribution by writing a function that specifies the log-probability, then passing that function as a keyword argument to the ``DensityDist`` function, which creates an instance of a PyMC distribution with the custom function as its log-probability.
6262

6363
For the exponential survival function, this is:
6464

6565
::
6666

67-
def logp(failure, value):
68-
return (failure * log(λ) - λ * value).sum()
67+
def logp(value, t, λ):
68+
return (value * log(λ) - λ * t).sum()
6969

70-
exp_surv = pm.DensityDist('exp_surv', logp, observed={'failure':failure, 'value':t})
70+
exp_surv = pm.DensityDist('exp_surv', t, λ, logp=logp, observed=failure)
7171

7272
Similarly, if a random number generator is required, a function returning random numbers corresponding to the probability distribution can be passed as the ``random`` argument.
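
To make the arithmetic of the updated ``logp`` concrete, the following plain-Python sketch evaluates the same expression, ``(value * log(λ) - λ * t).sum()``, on a few made-up observations. The function name ``exp_surv_logp``, the sample data, and the use of ``math`` instead of Aesara tensors are assumptions for illustration only.

```python
import math


def exp_surv_logp(value, t, lam):
    """Sum value_i * log(lam) - lam * t_i over all observations."""
    return sum(v * math.log(lam) - lam * ti for v, ti in zip(value, t))


# Made-up data: the observed indicator and the observation times.
failure = [1, 0, 1]
t = [2.0, 1.5, 0.5]

total = exp_surv_logp(failure, t, lam=0.5)
# Two observations contribute log(lam); all three contribute -lam * t_i.
expected = 2 * math.log(0.5) - 0.5 * sum(t)
```

Note that the observed quantity (``failure``) is the first argument, while ``t`` and ``λ`` follow as distribution parameters, matching the new ``DensityDist`` call above.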
7373

pymc/distributions/distribution.py

Lines changed: 207 additions & 59 deletions
@@ -13,17 +13,18 @@
1313
# limitations under the License.
1414
import contextvars
1515
import functools
16-
import multiprocessing
1716
import sys
1817
import types
1918
import warnings
2019

2120
from abc import ABCMeta
2221
from functools import singledispatch
23-
from typing import Optional
22+
from typing import Callable, Optional, Sequence
2423

2524
import aesara
25+
import numpy as np
2626

27+
from aesara.tensor.basic import as_tensor_variable
2728
from aesara.tensor.random.op import RandomVariable
2829
from aesara.tensor.random.var import RandomStateSharedVariable
2930
from aesara.tensor.var import TensorVariable
@@ -41,12 +42,14 @@
4142
maybe_resize,
4243
resize_from_dims,
4344
resize_from_observed,
45+
to_tuple,
4446
)
4547
from pymc.printing import str_for_dist
4648
from pymc.util import UNSET
4749
from pymc.vartypes import string_types
4850

4951
__all__ = [
52+
"DensityDistRV",
5053
"DensityDist",
5154
"Distribution",
5255
"Continuous",
@@ -389,96 +392,241 @@ class NoDistribution(Distribution):
389392
"""
390393

391394

392-
class DensityDist(Distribution):
393-
"""Distribution based on a given log density function.
395+
class DensityDistRV(RandomVariable):
396+
"""
397+
Base class for DensityDistRV
398+
399+
This should be subclassed when defining custom DensityDist objects.
400+
"""
401+
402+
name = "DensityDistRV"
403+
_print_name = ("DensityDist", "\\operatorname{DensityDist}")
404+
405+
@classmethod
406+
def rng_fn(cls, rng, *args):
407+
args = list(args)
408+
size = args.pop(-1)
409+
return cls._random_fn(*args, rng=rng, size=size)
394410

395-
A distribution with the passed log density function is created.
396-
Requires a custom random function passed as kwarg `random` to
397-
enable prior or posterior predictive sampling.
398411

412+
class DensityDist(NoDistribution):
413+
"""A distribution that can be used to wrap black-box log density functions.
414+
415+
Creates a Distribution and registers the supplied log density function to be used
416+
for inference. It is also possible to supply a `random` method in order to be able
417+
to sample from the prior or posterior predictive distributions.
399418
"""
400419

401-
def __init__(
402-
self,
403-
logp,
404-
shape=(),
405-
dtype=None,
406-
initval=0,
407-
random=None,
408-
wrap_random_with_dist_shape=True,
409-
check_shape_in_random=True,
410-
*args,
420+
def __new__(
421+
cls,
422+
name: str,
423+
*dist_params,
424+
logp: Optional[Callable] = None,
425+
logcdf: Optional[Callable] = None,
426+
random: Optional[Callable] = None,
427+
get_moment: Optional[Callable] = None,
428+
ndim_supp: int = 0,
429+
ndims_params: Optional[Sequence[int]] = None,
430+
dtype: str = "floatX",
411431
**kwargs,
412432
):
413433
"""
414434
Parameters
415435
----------
416-
417-
logp: callable
418-
A callable that has the following signature ``logp(value)`` and
419-
returns an Aesara tensor that represents the distribution's log
420-
probability density.
421-
shape: tuple (Optional): defaults to `()`
422-
The shape of the distribution. The default value indicates a scalar.
423-
If the distribution is *not* scalar-valued, the programmer should pass
424-
a value here.
425-
dtype: None, str (Optional)
426-
The dtype of the distribution.
427-
initval: number or array (Optional)
428-
The ``initval`` of the RV's tensor that follow the ``DensityDist``
429-
distribution.
430-
args, kwargs: (Optional)
431-
These are passed to the parent class' ``__init__``.
436+
name : str
437+
dist_params : Tuple
438+
A sequence of the distribution's parameters. These will be converted into
439+
aesara tensors internally. These parameters could be other ``RandomVariable``
440+
instances.
441+
logp : Optional[Callable]
442+
A callable that calculates the log density of some given observed ``value``
443+
conditioned on certain distribution parameter values. It must have the
444+
following signature: ``logp(value, *dist_params)``, where ``value`` is
445+
an Aesara tensor that represents the observed value, and ``dist_params``
446+
are the tensors that hold the values of the distribution parameters.
447+
This function must return an Aesara tensor. If ``None``, a ``NotImplementedError``
448+
will be raised when trying to compute the distribution's logp.
449+
logcdf : Optional[Callable]
450+
A callable that calculates the log cumulative probability of some given observed
451+
``value`` conditioned on certain distribution parameter values. It must have the
452+
following signature: ``logcdf(value, *dist_params)``, where ``value`` is
453+
an Aesara tensor that represents the observed value, and ``dist_params``
454+
are the tensors that hold the values of the distribution parameters.
455+
This function must return an Aesara tensor. If ``None``, a ``NotImplementedError``
456+
will be raised when trying to compute the distribution's logcdf.
457+
random : Optional[Callable]
458+
A callable that can be used to generate random draws from the distribution.
459+
It must have the following signature: ``random(*dist_params, rng=None, size=None)``.
460+
The distribution parameters are passed as positional arguments in the
461+
same order as they are supplied when the ``DensityDist`` is constructed.
462+
The keyword arguments are ``rng``, which will provide the random variable's
463+
associated :py:class:`~numpy.random.Generator`, and ``size``, that will represent
464+
the desired size of the random draw. If ``None``, a ``NotImplementedError``
465+
will be raised when trying to draw random samples from the distribution's
466+
prior or posterior predictive.
467+
get_moment : Optional[Callable]
468+
A callable that can be used to compute the moments of the distribution.
469+
It must have the following signature: ``get_moment(rv, size, *rv_inputs)``.
470+
The distribution's :class:`~aesara.tensor.random.op.RandomVariable` is passed
471+
as the first argument ``rv``. ``size`` is the random variable's size implied
472+
by the ``dims``, ``size`` and parameters supplied to the distribution. Finally,
473+
``rv_inputs`` is the sequence of the distribution parameters, in the same order
474+
as they were supplied when the DensityDist was created. If ``None``, a
475+
``NotImplementedError`` will be raised when trying to compute the
476+
distribution's moment.
477+
ndim_supp : int
478+
The number of dimensions in the support of the distribution. Defaults to assuming
479+
a scalar distribution, i.e. ``ndim_supp = 0``.
480+
ndims_params : Optional[Sequence[int]]
481+
The list of number of dimensions in the support of each of the distribution's
482+
parameters. If ``None``, it is assumed that all parameters are scalars, hence
483+
the number of dimensions of their support will be 0.
484+
dtype : str
485+
The dtype of the distribution. All draws and observations passed into the distribution
486+
will be cast to this dtype.
487+
kwargs :
488+
Extra keyword arguments are passed to the parent's class ``__new__`` method.
432489
433490
Examples
434491
--------
435492
.. code-block:: python
436493
494+
def logp(value, mu):
495+
return -(value - mu)**2
496+
437497
with pm.Model():
438498
mu = pm.Normal('mu',0,1)
439-
normal_dist = pm.Normal.dist(mu, 1)
440499
pm.DensityDist(
441500
'density_dist',
442-
normal_dist.logp,
501+
mu,
502+
logp=logp,
443503
observed=np.random.randn(100),
444504
)
445505
idata = pm.sample(100)
446506
447507
.. code-block:: python
448508
509+
def logp(value, mu):
510+
return -(value - mu)**2
511+
512+
def random(mu, rng=None, size=None):
513+
return rng.normal(loc=mu, scale=1, size=size)
514+
449515
with pm.Model():
450516
mu = pm.Normal('mu', 0 , 1)
451-
normal_dist = pm.Normal.dist(mu, 1, shape=3)
452517
dens = pm.DensityDist(
453518
'density_dist',
454-
normal_dist.logp,
519+
mu,
520+
logp=logp,
521+
random=random,
455522
observed=np.random.randn(100, 3),
456-
shape=3,
523+
size=(100, 3),
457524
)
458525
prior = pm.sample_prior_predictive(10)['density_dist']
459526
assert prior.shape == (10, 100, 3)
460527
461528
"""
462-
if dtype is None:
529+
530+
if dist_params is None:
531+
dist_params = []
532+
elif len(dist_params) > 0 and callable(dist_params[0]):
533+
raise TypeError(
534+
"The DensityDist API has changed, you are using the old API "
535+
"where logp was the first positional argument. In the current API, "
536+
"the logp is a keyword argument, amongst other changes. Please refer "
537+
"to the API documentation for more information on how to use the "
538+
"new DensityDist API."
539+
)
540+
dist_params = [as_tensor_variable(param) for param in dist_params]
541+
542+
# Assume scalar ndims_params
543+
if ndims_params is None:
544+
ndims_params = [0] * len(dist_params)
545+
546+
if logp is None:
547+
logp = default_not_implemented(name, "logp")
548+
549+
if logcdf is None:
550+
logcdf = default_not_implemented(name, "logcdf")
551+
552+
if random is None:
553+
random = default_not_implemented(name, "random")
554+
555+
if get_moment is None:
556+
get_moment = default_not_implemented(name, "get_moment")
557+
558+
rv_op = type(
559+
f"DensityDist_{name}",
560+
(DensityDistRV,),
561+
dict(
562+
name=f"DensityDist_{name}",
563+
inplace=False,
564+
ndim_supp=ndim_supp,
565+
ndims_params=ndims_params,
566+
dtype=dtype,
567+
# Specific to DensityDist
568+
_random_fn=random,
569+
),
570+
)()
571+
572+
# Register custom logp
573+
rv_type = type(rv_op)
574+
575+
@_logp.register(rv_type)
576+
def density_dist_logp(op, rv, rvs_to_values, *dist_params, **kwargs):
577+
value_var = rvs_to_values.get(rv, rv)
578+
return logp(
579+
value_var,
580+
*dist_params,
581+
)
582+
583+
@_logcdf.register(rv_type)
584+
def density_dist_logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
585+
value_var = rvs_to_values.get(var, var)
586+
return logcdf(value_var, *dist_params, **kwargs)
587+
588+
@_get_moment.register(rv_type)
589+
def density_dist_get_moment(op, rv, size, *rv_inputs):
590+
return get_moment(rv, size, *rv_inputs)
591+
592+
cls.rv_op = rv_op
593+
return super().__new__(cls, name, *dist_params, **kwargs)
594+
595+
@classmethod
596+
def dist(cls, *args, **kwargs):
597+
output = super().dist(args, **kwargs)
598+
if cls.rv_op.dtype == "floatX":
463599
dtype = aesara.config.floatX
464-
super().__init__(shape, dtype, initval, *args, **kwargs)
465-
self.logp = logp
466-
if type(self.logp) == types.MethodType:
467-
if PLATFORM != "linux":
468-
warnings.warn(
469-
"You are passing a bound method as logp for DensityDist, this can lead to "
470-
"errors when sampling on platforms other than Linux. Consider using a "
471-
"plain function instead, or subclass Distribution."
472-
)
473-
elif type(multiprocessing.get_context()) != multiprocessing.context.ForkContext:
474-
warnings.warn(
475-
"You are passing a bound method as logp for DensityDist, this can lead to "
476-
"errors when sampling when multiprocessing cannot rely on forking. Consider using a "
477-
"plain function instead, or subclass Distribution."
478-
)
479-
self.rand = random
480-
self.wrap_random_with_dist_shape = wrap_random_with_dist_shape
481-
self.check_shape_in_random = check_shape_in_random
600+
else:
601+
dtype = cls.rv_op.dtype
602+
ndim_supp = cls.rv_op.ndim_supp
603+
if not hasattr(output.tag, "test_value"):
604+
size = to_tuple(kwargs.get("size", None)) + (1,) * ndim_supp
605+
output.tag.test_value = np.zeros(size, dtype)
606+
return output
607+
608+
609+
def default_not_implemented(rv_name, method_name):
610+
if method_name == "random":
611+
# This is a hack to catch the NotImplementedError when creating the RV without random
612+
# If the message starts with "Cannot sample from", then it uses the test_value as
613+
# the initial_val.
614+
message = (
615+
f"Cannot sample from the DensityDist '{rv_name}' because the {method_name} "
616+
"keyword argument was not provided when the distribution was "
617+
f"but this method had not been provided when the distribution was "
618+
f"constructed. Please re-build your model and provide a callable "
619+
f"to '{rv_name}'s {method_name} keyword argument.\n"
620+
)
621+
else:
622+
message = (
623+
f"Attempted to run {method_name} on the DensityDist '{rv_name}', "
624+
f"but this method had not been provided when the distribution was "
625+
f"constructed. Please re-build your model and provide a callable "
626+
f"to '{rv_name}'s {method_name} keyword argument.\n"
627+
)
628+
629+
def func(*args, **kwargs):
630+
raise NotImplementedError(message)
482631

483-
def _distr_parameters_for_repr(self):
484-
return []
632+
return func
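
The interplay between the dynamically created `DensityDistRV` subclass, its `rng_fn`, and the `default_not_implemented` fallback can be sketched without Aesara. Everything below is a simplified stand-in: `SketchRV` and `make_rv_type` are hypothetical names, the fallback message is shortened, and the standard library's `random.Random` replaces the NumPy generator that the real code would pass as `rng`.

```python
import random


def default_not_implemented(rv_name, method_name):
    """Return a placeholder that raises when the missing callable is used."""
    message = (
        f"Attempted to run {method_name} on the DensityDist '{rv_name}', "
        "but this method had not been provided."
    )

    def func(*args, **kwargs):
        raise NotImplementedError(message)

    return func


class SketchRV:
    """Hypothetical stand-in for DensityDistRV (no Aesara involved)."""

    _random_fn = None

    @classmethod
    def rng_fn(cls, rng, *args):
        # The last positional argument is the size; the rest are parameters.
        args = list(args)
        size = args.pop(-1)
        return cls._random_fn(*args, rng=rng, size=size)


def make_rv_type(name, random_fn=None):
    """Build a subclass that carries the user's random function."""
    if random_fn is None:
        random_fn = default_not_implemented(name, "random")
    return type(f"SketchRV_{name}", (SketchRV,), dict(_random_fn=random_fn))


def normal_draws(mu, rng=None, size=None):
    # Follows the documented signature: random(*dist_params, rng=None, size=None).
    return [rng.gauss(mu, 1.0) for _ in range(size)]


with_random = make_rv_type("with_random", normal_draws)
draws = with_random.rng_fn(random.Random(0), 5.0, 3)  # mu=5.0, size=3

without_random = make_rv_type("without_random")
try:
    without_random.rng_fn(random.Random(0), 5.0, 3)
    error_message = None
except NotImplementedError as err:
    error_message = str(err)
```

The point of the placeholder is that a distribution can be constructed without every optional callable, and the error only surfaces when the missing piece is actually needed.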
