|
13 | 13 | # limitations under the License.
|
14 | 14 | import contextvars
|
15 | 15 | import functools
|
16 |
| -import multiprocessing |
17 | 16 | import sys
|
18 | 17 | import types
|
19 | 18 | import warnings
|
20 | 19 |
|
21 | 20 | from abc import ABCMeta
|
22 | 21 | from functools import singledispatch
|
23 |
| -from typing import Optional |
| 22 | +from typing import Callable, Optional, Sequence |
24 | 23 |
|
25 | 24 | import aesara
|
| 25 | +import numpy as np |
26 | 26 |
|
| 27 | +from aesara.tensor.basic import as_tensor_variable |
27 | 28 | from aesara.tensor.random.op import RandomVariable
|
28 | 29 | from aesara.tensor.random.var import RandomStateSharedVariable
|
29 | 30 | from aesara.tensor.var import TensorVariable
|
|
41 | 42 | maybe_resize,
|
42 | 43 | resize_from_dims,
|
43 | 44 | resize_from_observed,
|
| 45 | + to_tuple, |
44 | 46 | )
|
45 | 47 | from pymc.printing import str_for_dist
|
46 | 48 | from pymc.util import UNSET
|
47 | 49 | from pymc.vartypes import string_types
|
48 | 50 |
|
49 | 51 | __all__ = [
|
| 52 | + "DensityDistRV", |
50 | 53 | "DensityDist",
|
51 | 54 | "Distribution",
|
52 | 55 | "Continuous",
|
@@ -387,96 +390,241 @@ class NoDistribution(Distribution):
|
387 | 390 | """
|
388 | 391 |
|
389 | 392 |
|
390 |
| -class DensityDist(Distribution): |
391 |
| - """Distribution based on a given log density function. |
class DensityDistRV(RandomVariable):
    """
    Common parent ``RandomVariable`` for the ops generated by ``DensityDist``.

    Subclass this when defining custom DensityDist objects. Subclasses are
    expected to provide a ``_random_fn`` callable, which ``rng_fn`` delegates to.
    """

    name = "DensityDistRV"
    _print_name = ("DensityDist", "\\operatorname{DensityDist}")

    @classmethod
    def rng_fn(cls, rng, *args):
        """Draw samples by delegating to the subclass' ``_random_fn``.

        The last positional argument is taken as the draw ``size``; every
        argument before it is forwarded positionally, unchanged.
        """
        forwarded = list(args)
        # The trailing entry is the size; strip it off before forwarding.
        size = forwarded.pop()
        return cls._random_fn(*forwarded, rng=rng, size=size)
392 | 408 |
|
393 |
| - A distribution with the passed log density function is created. |
394 |
| - Requires a custom random function passed as kwarg `random` to |
395 |
| - enable prior or posterior predictive sampling. |
396 | 409 |
|
class DensityDist(NoDistribution):
    """A distribution that can be used to wrap black-box log density functions.

    Creates a Distribution and registers the supplied log density function to be used
    for inference. It is also possible to supply a `random` method in order to be able
    to sample from the prior or posterior predictive distributions.
    """

    def __new__(
        cls,
        name: str,
        *dist_params,
        logp: Optional[Callable] = None,
        logcdf: Optional[Callable] = None,
        random: Optional[Callable] = None,
        get_moment: Optional[Callable] = None,
        ndim_supp: int = 0,
        ndims_params: Optional[Sequence[int]] = None,
        dtype: str = "floatX",
        **kwargs,
    ):
        """
        Parameters
        ----------
        name : str
        dist_params : Tuple
            A sequence of the distribution's parameters. These will be converted into
            aesara tensors internally. These parameters could be other ``RandomVariable``
            instances.
        logp : Optional[Callable]
            A callable that calculates the log density of some given observed ``value``
            conditioned on certain distribution parameter values. It must have the
            following signature: ``logp(value, *dist_params)``, where ``value`` is
            an Aesara tensor that represents the observed value, and ``dist_params``
            are the tensors that hold the values of the distribution parameters.
            This function must return an Aesara tensor. If ``None``, a
            ``NotImplementedError`` will be raised when trying to compute the
            distribution's logp.
        logcdf : Optional[Callable]
            A callable that calculates the log cumulative probability of some given observed
            ``value`` conditioned on certain distribution parameter values. It must have the
            following signature: ``logcdf(value, *dist_params)``, where ``value`` is
            an Aesara tensor that represents the observed value, and ``dist_params``
            are the tensors that hold the values of the distribution parameters.
            This function must return an Aesara tensor. If ``None``, a
            ``NotImplementedError`` will be raised when trying to compute the
            distribution's logcdf.
        random : Optional[Callable]
            A callable that can be used to generate random draws from the distribution.
            It must have the following signature: ``random(*dist_params, rng=None, size=None)``.
            The distribution parameters are passed as positional arguments in the
            same order as they are supplied when the ``DensityDist`` is constructed.
            The keyword arguments are ``rng``, which will provide the random variable's
            associated :py:class:`~numpy.random.Generator`, and ``size``, that will represent
            the desired size of the random draw. If ``None``, a ``NotImplementedError``
            will be raised when trying to draw random samples from the distribution's
            prior or posterior predictive.
        get_moment : Optional[Callable]
            A callable that can be used to compute the moments of the distribution.
            It must have the following signature: ``get_moment(rv, size, *rv_inputs)``.
            The distribution's :class:`~aesara.tensor.random.op.RandomVariable` is passed
            as the first argument ``rv``. ``size`` is the random variable's size implied
            by the ``dims``, ``size`` and parameters supplied to the distribution. Finally,
            ``rv_inputs`` is the sequence of the distribution parameters, in the same order
            as they were supplied when the DensityDist was created. If ``None``, a
            ``NotImplementedError`` will be raised when trying to compute the
            distribution's moment.
        ndim_supp : int
            The number of dimensions in the support of the distribution. Defaults to assuming
            a scalar distribution, i.e. ``ndim_supp = 0``.
        ndims_params : Optional[Sequence[int]]
            The list of number of dimensions in the support of each of the distribution's
            parameters. If ``None``, it is assumed that all parameters are scalars, hence
            the number of dimensions of their support will be 0.
        dtype : str
            The dtype of the distribution. All draws and observations passed into the
            distribution will be cast to this dtype.
        kwargs :
            Extra keyword arguments are passed to the parent's class ``__new__`` method.

        Examples
        --------
            .. code-block:: python

                def logp(value, mu):
                    return -(value - mu)**2

                with pm.Model():
                    mu = pm.Normal('mu',0,1)
                    pm.DensityDist(
                        'density_dist',
                        mu,
                        logp=logp,
                        observed=np.random.randn(100),
                    )
                    idata = pm.sample(100)

            .. code-block:: python

                def logp(value, mu):
                    return -(value - mu)**2

                def random(mu, rng=None, size=None):
                    return rng.normal(loc=mu, scale=1, size=size)

                with pm.Model():
                    mu = pm.Normal('mu', 0 , 1)
                    dens = pm.DensityDist(
                        'density_dist',
                        mu,
                        logp=logp,
                        random=random,
                        observed=np.random.randn(100, 3),
                        size=(100, 3),
                    )
                    prior = pm.sample_prior_predictive(10)['density_dist']
                assert prior.shape == (10, 100, 3)

        """

        # ``*dist_params`` is always a tuple (never ``None``), so the only
        # thing to guard against is the legacy call style in which the logp
        # callable was passed as the first positional argument.
        if dist_params and callable(dist_params[0]):
            raise TypeError(
                "The DensityDist API has changed, you are using the old API "
                "where logp was the first positional argument. In the current API, "
                "the logp is a keyword argument, amongst other changes. Please refer "
                "to the API documentation for more information on how to use the "
                "new DensityDist API."
            )
        dist_params = [as_tensor_variable(param) for param in dist_params]

        # Assume scalar ndims_params
        if ndims_params is None:
            ndims_params = [0] * len(dist_params)

        # Any callable that was not supplied is replaced by a stub that raises
        # ``NotImplementedError`` when it is eventually called.
        if logp is None:
            logp = default_not_implemented(name, "logp")

        if logcdf is None:
            logcdf = default_not_implemented(name, "logcdf")

        if random is None:
            random = default_not_implemented(name, "random")

        if get_moment is None:
            get_moment = default_not_implemented(name, "get_moment")

        # Build a dedicated RandomVariable subclass for this variable and
        # instantiate it right away; ``rv_op`` is the op instance.
        rv_op = type(
            f"DensityDist_{name}",
            (DensityDistRV,),
            dict(
                name=f"DensityDist_{name}",
                inplace=False,
                ndim_supp=ndim_supp,
                ndims_params=ndims_params,
                dtype=dtype,
                # Specific to DensityDist
                _random_fn=random,
            ),
        )()

        # Register custom logp
        rv_type = type(rv_op)

        @_logp.register(rv_type)
        def density_dist_logp(op, rv, rvs_to_values, *dist_params, **kwargs):
            # Fall back to the random variable itself when no value variable
            # has been mapped to it.
            value_var = rvs_to_values.get(rv, rv)
            return logp(
                value_var,
                *dist_params,
            )

        @_logcdf.register(rv_type)
        def density_dist_logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
            value_var = rvs_to_values.get(var, var)
            return logcdf(value_var, *dist_params, **kwargs)

        @_get_moment.register(rv_type)
        def density_dist_get_moment(op, rv, size, *rv_inputs):
            return get_moment(rv, size, *rv_inputs)

        # NOTE: the op is stored on ``cls`` itself, so a later construction
        # through the same class replaces this attribute.
        cls.rv_op = rv_op
        return super().__new__(cls, name, *dist_params, **kwargs)

    @classmethod
    def dist(cls, *args, **kwargs):
        """Create the variable through the parent ``dist`` and ensure it has a test value.

        The positional arguments are handed to the parent ``dist`` as a single
        sequence. If the resulting variable has no ``tag.test_value``, one made
        of zeros is attached, shaped from the ``size`` keyword argument plus
        one length-1 axis per support dimension of ``cls.rv_op``.
        """
        output = super().dist(args, **kwargs)
        # "floatX" is a placeholder that resolves to aesara's configured float type.
        if cls.rv_op.dtype == "floatX":
            dtype = aesara.config.floatX
        else:
            dtype = cls.rv_op.dtype
        ndim_supp = cls.rv_op.ndim_supp
        if not hasattr(output.tag, "test_value"):
            size = to_tuple(kwargs.get("size", None)) + (1,) * ndim_supp
            output.tag.test_value = np.zeros(size, dtype)
        return output
| 605 | + |
| 606 | + |
def default_not_implemented(rv_name, method_name):
    """Build a stand-in for a ``DensityDist`` callable that was not supplied.

    Parameters
    ----------
    rv_name : str
        Name of the ``DensityDist`` variable; interpolated into the error message.
    method_name : str
        Name of the keyword argument that was left out (e.g. ``"logp"``,
        ``"logcdf"``, ``"random"`` or ``"get_moment"``).

    Returns
    -------
    Callable
        A function that accepts any positional and keyword arguments and
        always raises ``NotImplementedError`` with a descriptive message.
    """
    if method_name == "random":
        # This is a hack to catch the NotImplementedError when creating the RV without random
        # If the message starts with "Cannot sample from", then it uses the test_value as
        # the initial_val. The prefix must therefore be kept as is.
        message = (
            f"Cannot sample from the DensityDist '{rv_name}' because the {method_name} "
            "keyword argument was not provided when the distribution was "
            "constructed. Please re-build your model and provide a callable "
            f"to '{rv_name}'s {method_name} keyword argument.\n"
        )
    else:
        message = (
            f"Attempted to run {method_name} on the DensityDist '{rv_name}', "
            f"but this method had not been provided when the distribution was "
            f"constructed. Please re-build your model and provide a callable "
            f"to '{rv_name}'s {method_name} keyword argument.\n"
        )

    def func(*args, **kwargs):
        raise NotImplementedError(message)

    return func
0 commit comments