16
16
17
17
from pytensor .tensor import TensorVariable
18
18
from pytensor .tensor .random .op import RandomVariable
19
+ from pytensor .tensor .random .utils import normalize_size_param
19
20
20
21
from pymc .distributions .distribution import (
21
22
Distribution ,
22
23
SymbolicRandomVariable ,
23
24
_support_point ,
24
25
)
25
- from pymc .distributions .shape_utils import _change_dist_size , change_dist_size
26
+ from pymc .distributions .shape_utils import (
27
+ _change_dist_size ,
28
+ change_dist_size ,
29
+ implicit_size_from_params ,
30
+ rv_size_is_none ,
31
+ )
26
32
from pymc .util import check_dist_not_registered
27
33
28
34
@@ -31,9 +37,27 @@ class CensoredRV(SymbolicRandomVariable):
31
37
32
38
inline_logprob = True
33
39
signature = "(),(),()->()"
34
- ndim_supp = 0
35
40
_print_name = ("Censored" , "\\ operatorname{Censored}" )
36
41
42
+ @classmethod
43
+ def rv_op (cls , dist , lower , upper , * , size = None ):
44
+ # We don't allow passing `rng` because we don't fully control the rng of the components!
45
+ lower = pt .constant (- np .inf ) if lower is None else pt .as_tensor (lower )
46
+ upper = pt .constant (np .inf ) if upper is None else pt .as_tensor (upper )
47
+ size = normalize_size_param (size )
48
+
49
+ if rv_size_is_none (size ):
50
+ size = implicit_size_from_params (dist , lower , upper , ndims_params = cls .ndims_params )
51
+
52
+ # Censoring is achieved by clipping the base distribution between lower and upper
53
+ dist = change_dist_size (dist , size )
54
+ censored_rv = pt .clip (dist , lower , upper )
55
+
56
+ return CensoredRV (
57
+ inputs = [dist , lower , upper ],
58
+ outputs = [censored_rv ],
59
+ )(dist , lower , upper )
60
+
37
61
38
62
class Censored (Distribution ):
39
63
r"""
@@ -85,6 +109,7 @@ class Censored(Distribution):
85
109
"""
86
110
87
111
rv_type = CensoredRV
112
+ rv_op = CensoredRV .rv_op
88
113
89
114
@classmethod
90
115
def dist (cls , dist , lower , upper , ** kwargs ):
@@ -101,24 +126,6 @@ def dist(cls, dist, lower, upper, **kwargs):
101
126
check_dist_not_registered (dist )
102
127
return super ().dist ([dist , lower , upper ], ** kwargs )
103
128
104
- @classmethod
105
- def rv_op (cls , dist , lower = None , upper = None , size = None ):
106
- lower = pt .constant (- np .inf ) if lower is None else pt .as_tensor_variable (lower )
107
- upper = pt .constant (np .inf ) if upper is None else pt .as_tensor_variable (upper )
108
-
109
- # When size is not specified, dist may have to be broadcasted according to lower/upper
110
- dist_shape = size if size is not None else pt .broadcast_shape (dist , lower , upper )
111
- dist = change_dist_size (dist , dist_shape )
112
-
113
- # Censoring is achieved by clipping the base distribution between lower and upper
114
- dist_ , lower_ , upper_ = dist .type (), lower .type (), upper .type ()
115
- censored_rv_ = pt .clip (dist_ , lower_ , upper_ )
116
-
117
- return CensoredRV (
118
- inputs = [dist_ , lower_ , upper_ ],
119
- outputs = [censored_rv_ ],
120
- )(dist , lower , upper )
121
-
122
129
123
130
@_change_dist_size .register (CensoredRV )
124
131
def change_censored_size (cls , dist , new_size , expand = False ):
0 commit comments