Skip to content

Commit 21b289a

Browse files
committed
Use dispatching for default transform instead of overriding __new__
1 parent f0cdb1f commit 21b289a

File tree

4 files changed

+93
-79
lines changed

4 files changed

+93
-79
lines changed

docs/source/contributing/developer_guide_implementing_distribution.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,8 @@ class Blah(PositiveContinuous):
193193

194194
Some notes:
195195

196-
1. A distribution should at the very least inherit from {class}`~pymc.distributions.Discrete` or {class}`~pymc.distributions.Continuous`. For the latter, more specific subclasses exist: `PositiveContinuous`, `UnitContinuous`, `BoundedContinuous`, `CircularContinuous`, which specify default transformations for the variables. If you need to specify a one-time custom transform you can also override the `__new__` method, as is done for the {class}`~pymc.distributions.multivariate.Dirichlet`.
197-
1. If a distribution does not have a corresponding `random` implementation, a `RandomVariable` should still be created that raises a `NotImplementedError`. This is the case for the {class}`~pymc.distributions.continuous.Flat`. In this case it will be necessary to provide a standard `initval` by
198-
overriding `__new__`.
196+
1. A distribution should at the very least inherit from {class}`~pymc.distributions.Discrete` or {class}`~pymc.distributions.Continuous`. For the latter, more specific subclasses exist: `PositiveContinuous`, `UnitContinuous`, `BoundedContinuous`, `CircularContinuous`, `SimplexContinuous`, which specify default transformations for the variables. If you need to specify a one-time custom transform you can also create a `_default_transform` dispatch function as is done for the {class}`~pymc.distributions.multivariate.LKJCholeskyCov`.
197+
1. If a distribution does not have a corresponding `random` implementation, a `RandomVariable` should still be created that raises a `NotImplementedError`. This is the case for the {class}`~pymc.distributions.continuous.Flat`. In this case it will be necessary to provide a `moment` method.
199198
1. As mentioned above, `PyMC` v4.x works in a very {term}`functional <Functional Programming>` way, and all the information that is needed in the `logp` and `logcdf` methods is expected to be "carried" via the `RandomVariable` inputs. You may pass numerical arguments that are not strictly needed for the `rng_fn` method but are used in the `logp` and `logcdf` methods. Just keep in mind whether this affects the correct shape inference behavior of the `RandomVariable`. If specialized non-numeric information is needed you might need to define your custom`_logp` and `_logcdf` {term}`Dispatching` functions, but this should be done as a last resort.
200199
1. The `logcdf` method is not a requirement, but it's a nice plus!
201200
1. Currently only one moment is supported in the `moment` method, and probably the "higher-order" one is the most useful (that is `mean` > `median` > `mode`)... You might need to truncate the moment if you are dealing with a discrete distribution.

pymc/distributions/bound.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@
2020
from aesara.tensor.var import TensorVariable
2121

2222
from pymc.aesaraf import floatX, intX
23-
from pymc.distributions.continuous import BoundedContinuous
23+
from pymc.distributions.continuous import BoundedContinuous, bounded_cont_transform
2424
from pymc.distributions.dist_math import check_parameters
2525
from pymc.distributions.distribution import Continuous, Discrete
2626
from pymc.distributions.logprob import logp
2727
from pymc.distributions.shape_utils import to_tuple
28+
from pymc.distributions.transforms import _default_transform
2829
from pymc.model import modelcontext
2930
from pymc.util import check_dist_not_registered
3031

@@ -82,6 +83,11 @@ def logp(value, distribution, lower, upper):
8283
)
8384

8485

86+
@_default_transform.register(BoundRV)
87+
def bound_default_transform(op, rv):
88+
return bounded_cont_transform(op, rv, _ContinuousBounded.bound_args_indices)
89+
90+
8591
class DiscreteBoundRV(BoundRV):
8692
name = "discrete_bound"
8793
dtype = "int64"
@@ -94,8 +100,8 @@ class _DiscreteBounded(Discrete):
94100
rv_op = discrete_boundrv
95101

96102
def __new__(cls, *args, **kwargs):
97-
transform = kwargs.get("transform", None)
98-
if transform is not None:
103+
kwargs.setdefault("transform", None)
104+
if kwargs.get("transform") is not None:
99105
raise ValueError("Cannot transform discrete variable.")
100106
return super().__new__(cls, *args, **kwargs)
101107

pymc/distributions/continuous.py

Lines changed: 57 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def polyagamma_cdf(*args, **kwargs):
8989
from pymc.distributions.shape_utils import rv_size_is_none
9090
from pymc.distributions.transforms import _default_transform
9191
from pymc.math import invlogit, logdiffexp, logit
92-
from pymc.util import UNSET
9392

9493
__all__ = [
9594
"Uniform",
@@ -140,6 +139,13 @@ class CircularContinuous(Continuous):
140139
"""Base class for circular continuous distributions"""
141140

142141

142+
class BoundedContinuous(Continuous):
143+
"""Base class for bounded continuous distributions"""
144+
145+
# Indices of the arguments that define the lower and upper bounds of the distribution
146+
bound_args_indices: Optional[List[int]] = None
147+
148+
143149
@_default_transform.register(PositiveContinuous)
144150
def pos_cont_transform(op, rv):
145151
return transforms.log
@@ -155,48 +161,34 @@ def circ_cont_transform(op, rv):
155161
return transforms.circular
156162

157163

158-
class BoundedContinuous(Continuous):
159-
"""Base class for bounded continuous distributions"""
160-
161-
# Indices of the arguments that define the lower and upper bounds of the distribution
162-
bound_args_indices: Optional[List[int]] = None
163-
164-
def __new__(cls, *args, **kwargs):
165-
transform = kwargs.get("transform", UNSET)
166-
if transform is UNSET:
167-
kwargs["transform"] = cls.default_transform()
168-
return super().__new__(cls, *args, **kwargs)
169-
170-
@classmethod
171-
def default_transform(cls):
172-
if cls.bound_args_indices is None:
173-
raise ValueError(
174-
f"Must specify bound_args_indices for {cls.__name__} bounded distribution"
175-
)
164+
@_default_transform.register(BoundedContinuous)
165+
def bounded_cont_transform(op, rv, bound_args_indices=None):
166+
if bound_args_indices is None:
167+
raise ValueError(f"Must specify bound_args_indices for {op} bounded distribution")
176168

177-
def transform_params(*args):
169+
def transform_params(*args):
178170

179-
lower, upper = None, None
180-
if cls.bound_args_indices[0] is not None:
181-
lower = args[cls.bound_args_indices[0]]
182-
if cls.bound_args_indices[1] is not None:
183-
upper = args[cls.bound_args_indices[1]]
171+
lower, upper = None, None
172+
if bound_args_indices[0] is not None:
173+
lower = args[bound_args_indices[0]]
174+
if bound_args_indices[1] is not None:
175+
upper = args[bound_args_indices[1]]
184176

185-
if lower is not None:
186-
if isinstance(lower, TensorConstant) and np.all(lower.value == -np.inf):
187-
lower = None
188-
else:
189-
lower = at.as_tensor_variable(lower)
177+
if lower is not None:
178+
if isinstance(lower, TensorConstant) and np.all(lower.value == -np.inf):
179+
lower = None
180+
else:
181+
lower = at.as_tensor_variable(lower)
190182

191-
if upper is not None:
192-
if isinstance(upper, TensorConstant) and np.all(upper.value == np.inf):
193-
upper = None
194-
else:
195-
upper = at.as_tensor_variable(upper)
183+
if upper is not None:
184+
if isinstance(upper, TensorConstant) and np.all(upper.value == np.inf):
185+
upper = None
186+
else:
187+
upper = at.as_tensor_variable(upper)
196188

197-
return lower, upper
189+
return lower, upper
198190

199-
return transforms.Interval(bounds_fn=transform_params)
191+
return transforms.Interval(bounds_fn=transform_params)
200192

201193

202194
def assert_negative_support(var, label, distname, value=-1e-6):
@@ -338,6 +330,11 @@ def logcdf(value, lower, upper):
338330
)
339331

340332

333+
@_default_transform.register(Uniform)
334+
def uniform_default_transform(op, rv):
335+
return bounded_cont_transform(op, rv, Uniform.bound_args_indices)
336+
337+
341338
class FlatRV(RandomVariable):
342339
name = "flat"
343340
ndim_supp = 0
@@ -788,6 +785,11 @@ def logp(
788785
return check_parameters(logp, *bounds)
789786

790787

788+
@_default_transform.register(TruncatedNormal)
789+
def truncated_normal_default_transform(op, rv):
790+
return bounded_cont_transform(op, rv, TruncatedNormal.bound_args_indices)
791+
792+
791793
class HalfNormal(PositiveContinuous):
792794
r"""
793795
Half-normal log-likelihood.
@@ -2065,6 +2067,11 @@ def logcdf(
20652067
return check_parameters(res, 0 < alpha, 0 < m, msg="alpha > 0, m > 0")
20662068

20672069

2070+
@_default_transform.register(Pareto)
2071+
def pareto_default_transform(op, rv):
2072+
return bounded_cont_transform(op, rv, Pareto.bound_args_indices)
2073+
2074+
20682075
class Cauchy(Continuous):
20692076
r"""
20702077
Cauchy log-likelihood.
@@ -3245,6 +3252,11 @@ def logcdf(value, lower, c, upper):
32453252
)
32463253

32473254

3255+
@_default_transform.register(Triangular)
3256+
def triangular_default_transform(op, rv):
3257+
return bounded_cont_transform(op, rv, Triangular.bound_args_indices)
3258+
3259+
32483260
class Gumbel(Continuous):
32493261
r"""
32503262
Univariate Gumbel log-likelihood.
@@ -3763,17 +3775,6 @@ class Interpolated(BoundedContinuous):
37633775

37643776
rv_op = interpolated
37653777

3766-
def __new__(cls, *args, **kwargs):
3767-
transform = kwargs.get("transform", UNSET)
3768-
if transform is UNSET:
3769-
3770-
def transform_params(*params):
3771-
_, _, _, x_points, _, _ = params
3772-
return floatX(x_points[0]), floatX(x_points[-1])
3773-
3774-
kwargs["transform"] = transforms.Interval(bounds_fn=transform_params)
3775-
return super().__new__(cls, *args, **kwargs)
3776-
37773778
@classmethod
37783779
def dist(cls, x_points, pdf_points, *args, **kwargs):
37793780

@@ -3827,8 +3828,14 @@ def logp(value, x_points, pdf_points, cdf_points):
38273828

38283829
return at.log(interp_op(value) / Z)
38293830

3830-
def _distr_parameters_for_repr(self):
3831-
return []
3831+
3832+
@_default_transform.register(Interpolated)
3833+
def interpolated_default_transform(op, rv):
3834+
def transform_params(*params):
3835+
_, _, _, x_points, _, _ = params
3836+
return floatX(x_points[0]), floatX(x_points[-1])
3837+
3838+
return transforms.Interval(bounds_fn=transform_params)
38323839

38333840

38343841
class MoyalRV(RandomVariable):

pymc/distributions/multivariate.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@
6262
rv_size_is_none,
6363
to_tuple,
6464
)
65-
from pymc.distributions.transforms import Interval
65+
from pymc.distributions.transforms import Interval, _default_transform
6666
from pymc.math import kron_diag, kron_dot
67-
from pymc.util import UNSET, check_dist_not_registered
67+
from pymc.util import check_dist_not_registered
6868

6969
__all__ = [
7070
"MvNormal",
@@ -83,6 +83,16 @@
8383
"StickBreakingWeights",
8484
]
8585

86+
87+
class SimplexContinuous(Continuous):
88+
"""Base class for simplex continuous distributions"""
89+
90+
91+
@_default_transform.register(SimplexContinuous)
92+
def simplex_cont_transform(op, rv):
93+
return transforms.simplex
94+
95+
8696
# Step methods and advi do not catch LinAlgErrors at the
8797
# moment. We work around that by using a cholesky op
8898
# that returns a nan as first entry instead of raising
@@ -408,7 +418,7 @@ def logp(value, nu, mu, cov):
408418
)
409419

410420

411-
class Dirichlet(Continuous):
421+
class Dirichlet(SimplexContinuous):
412422
r"""
413423
Dirichlet log-likelihood.
414424
@@ -434,10 +444,6 @@ class Dirichlet(Continuous):
434444
"""
435445
rv_op = dirichlet
436446

437-
def __new__(cls, name, *args, **kwargs):
438-
kwargs.setdefault("transform", transforms.simplex)
439-
return super().__new__(cls, name, *args, **kwargs)
440-
441447
@classmethod
442448
def dist(cls, a, **kwargs):
443449
a = at.as_tensor_variable(a)
@@ -1169,12 +1175,7 @@ class _LKJCholeskyCov(Continuous):
11691175
rv_op = _ljk_cholesky_cov
11701176

11711177
def __new__(cls, name, eta, n, sd_dist, **kwargs):
1172-
transform = kwargs.get("transform", UNSET)
1173-
if transform is UNSET:
1174-
kwargs["transform"] = transforms.CholeskyCovPacked(n)
1175-
11761178
check_dist_not_registered(sd_dist)
1177-
11781179
return super().__new__(cls, name, eta, n, sd_dist, **kwargs)
11791180

11801181
@classmethod
@@ -1269,6 +1270,12 @@ def logp(value, n, eta, sd_dist):
12691270
return norm + logp_lkj + logp_sd + det_invjac
12701271

12711272

1273+
@_default_transform.register(_LKJCholeskyCov)
1274+
def lkjcholeskycov_default_transform(op, rv):
1275+
_, _, _, n, _, _ = rv.owner.inputs
1276+
return transforms.CholeskyCovPacked(n)
1277+
1278+
12721279
class LKJCholeskyCov:
12731280
r"""Wrapper class for covariance matrix with LKJ distributed correlations.
12741281
@@ -1551,12 +1558,6 @@ class LKJCorr(BoundedContinuous):
15511558

15521559
rv_op = lkjcorr
15531560

1554-
def __new__(cls, *args, **kwargs):
1555-
transform = kwargs.get("transform", UNSET)
1556-
if transform is UNSET:
1557-
kwargs["transform"] = Interval(floatX(-1.0), floatX(1.0))
1558-
return super().__new__(cls, *args, **kwargs)
1559-
15601561
@classmethod
15611562
def dist(cls, n, eta, **kwargs):
15621563
n = at.as_tensor_variable(intX(n))
@@ -1610,6 +1611,11 @@ def logp(value, n, eta):
16101611
)
16111612

16121613

1614+
@_default_transform.register(LKJCorr)
1615+
def lkjcorr_default_transform(op, rv):
1616+
return Interval(floatX(-1.0), floatX(1.0))
1617+
1618+
16131619
class MatrixNormalRV(RandomVariable):
16141620
name = "matrixnormal"
16151621
ndim_supp = 2
@@ -2261,7 +2267,7 @@ def rng_fn(cls, rng, alpha, K, size):
22612267
stickbreakingweights = StickBreakingWeightsRV()
22622268

22632269

2264-
class StickBreakingWeights(Continuous):
2270+
class StickBreakingWeights(SimplexContinuous):
22652271
r"""
22662272
Likelihood of truncated stick-breaking weights. The weights are generated from a
22672273
stick-breaking proceduce where :math:`x_k = v_k \prod_{\ell < k} (1 - v_\ell)` for
@@ -2298,10 +2304,6 @@ class StickBreakingWeights(Continuous):
22982304
"""
22992305
rv_op = stickbreakingweights
23002306

2301-
def __new__(cls, name, *args, **kwargs):
2302-
kwargs.setdefault("transform", transforms.simplex)
2303-
return super().__new__(cls, name, *args, **kwargs)
2304-
23052307
@classmethod
23062308
def dist(cls, alpha, K, *args, **kwargs):
23072309
alpha = at.as_tensor_variable(floatX(alpha))

0 commit comments

Comments
 (0)