18
18
from aesara .tensor import TensorVariable
19
19
from aesara .tensor .random .op import RandomVariable
20
20
21
+ from pymc .aesaraf import change_rv_size
21
22
from pymc .distributions .distribution import SymbolicDistribution , _moment
22
23
from pymc .util import check_dist_not_registered
23
24
@@ -74,10 +75,13 @@ def dist(cls, dist, lower, upper, **kwargs):
74
75
75
76
@classmethod
76
77
def rv_op (cls , dist , lower = None , upper = None , size = None , rngs = None ):
77
- if lower is None :
78
- lower = at .constant (- np .inf )
79
- if upper is None :
80
- upper = at .constant (np .inf )
78
+
79
+ lower = at .constant (- np .inf ) if lower is None else at .as_tensor_variable (lower )
80
+ upper = at .constant (np .inf ) if upper is None else at .as_tensor_variable (upper )
81
+
82
+ # When size is not specified, dist may have to be broadcasted according to lower/upper
83
+ dist_shape = size if size is not None else at .broadcast_shape (dist , lower , upper )
84
+ dist = change_rv_size (dist , dist_shape )
81
85
82
86
# Censoring is achieved by clipping the base distribution between lower and upper
83
87
rv_out = at .clip (dist , lower , upper )
@@ -88,8 +92,6 @@ def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
88
92
rv_out .tag .lower = lower
89
93
rv_out .tag .upper = upper
90
94
91
- if size is not None :
92
- rv_out = cls .change_size (rv_out , size )
93
95
if rngs is not None :
94
96
rv_out = cls .change_rngs (rv_out , rngs )
95
97
@@ -101,12 +103,10 @@ def ndim_supp(cls, *dist_params):
101
103
102
104
@classmethod
103
105
def change_size (cls , rv , new_size , expand = False ):
104
- dist_node = rv .tag .dist . owner
106
+ dist = rv .tag .dist
105
107
lower = rv .tag .lower
106
108
upper = rv .tag .upper
107
- rng , old_size , dtype , * dist_params = dist_node .inputs
108
- new_size = new_size if not expand else tuple (new_size ) + tuple (old_size )
109
- new_dist = dist_node .op .make_node (rng , new_size , dtype , * dist_params ).default_output ()
109
+ new_dist = change_rv_size (dist , new_size , expand = expand )
110
110
return cls .rv_op (new_dist , lower , upper )
111
111
112
112
@classmethod
0 commit comments