Skip to content

Commit 9d4a3d7

Browse files
committed
Tweaks to SymbolicRandomVariables
* Allow signature to handle rng and size arguments explicitly. * Parse ndim_supp and ndims_params from class signature * Move rv_op method to the SymbolicRandomVariable class and get rid of dummy inputs logic (it was needed in previous versions of PyTensor) * Fix errors in automatic signature of CustomDist * Allow dispatch methods without filtering of inputs for SymbolicRandomVariable distributions
1 parent d81473f commit 9d4a3d7

File tree

10 files changed

+820
-657
lines changed

10 files changed

+820
-657
lines changed

pymc/distributions/censored.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,19 @@
1616

1717
from pytensor.tensor import TensorVariable
1818
from pytensor.tensor.random.op import RandomVariable
19+
from pytensor.tensor.random.utils import normalize_size_param
1920

2021
from pymc.distributions.distribution import (
2122
Distribution,
2223
SymbolicRandomVariable,
2324
_support_point,
2425
)
25-
from pymc.distributions.shape_utils import _change_dist_size, change_dist_size
26+
from pymc.distributions.shape_utils import (
27+
_change_dist_size,
28+
change_dist_size,
29+
implicit_size_from_params,
30+
rv_size_is_none,
31+
)
2632
from pymc.util import check_dist_not_registered
2733

2834

@@ -31,9 +37,27 @@ class CensoredRV(SymbolicRandomVariable):
3137

3238
inline_logprob = True
3339
signature = "(),(),()->()"
34-
ndim_supp = 0
3540
_print_name = ("Censored", "\\operatorname{Censored}")
3641

42+
@classmethod
43+
def rv_op(cls, dist, lower, upper, *, size=None):
44+
# We don't allow passing `rng` because we don't fully control the rng of the components!
45+
lower = pt.constant(-np.inf) if lower is None else pt.as_tensor(lower)
46+
upper = pt.constant(np.inf) if upper is None else pt.as_tensor(upper)
47+
size = normalize_size_param(size)
48+
49+
if rv_size_is_none(size):
50+
size = implicit_size_from_params(dist, lower, upper, ndims_params=cls.ndims_params)
51+
52+
# Censoring is achieved by clipping the base distribution between lower and upper
53+
dist = change_dist_size(dist, size)
54+
censored_rv = pt.clip(dist, lower, upper)
55+
56+
return CensoredRV(
57+
inputs=[dist, lower, upper],
58+
outputs=[censored_rv],
59+
)(dist, lower, upper)
60+
3761

3862
class Censored(Distribution):
3963
r"""
@@ -85,6 +109,7 @@ class Censored(Distribution):
85109
"""
86110

87111
rv_type = CensoredRV
112+
rv_op = CensoredRV.rv_op
88113

89114
@classmethod
90115
def dist(cls, dist, lower, upper, **kwargs):
@@ -101,24 +126,6 @@ def dist(cls, dist, lower, upper, **kwargs):
101126
check_dist_not_registered(dist)
102127
return super().dist([dist, lower, upper], **kwargs)
103128

104-
@classmethod
105-
def rv_op(cls, dist, lower=None, upper=None, size=None):
106-
lower = pt.constant(-np.inf) if lower is None else pt.as_tensor_variable(lower)
107-
upper = pt.constant(np.inf) if upper is None else pt.as_tensor_variable(upper)
108-
109-
# When size is not specified, dist may have to be broadcasted according to lower/upper
110-
dist_shape = size if size is not None else pt.broadcast_shape(dist, lower, upper)
111-
dist = change_dist_size(dist, dist_shape)
112-
113-
# Censoring is achieved by clipping the base distribution between lower and upper
114-
dist_, lower_, upper_ = dist.type(), lower.type(), upper.type()
115-
censored_rv_ = pt.clip(dist_, lower_, upper_)
116-
117-
return CensoredRV(
118-
inputs=[dist_, lower_, upper_],
119-
outputs=[censored_rv_],
120-
)(dist, lower, upper)
121-
122129

123130
@_change_dist_size.register(CensoredRV)
124131
def change_censored_size(cls, dist, new_size, expand=False):

0 commit comments

Comments
 (0)