@@ -106,28 +106,34 @@ def rv_op(cls, dist, lower, upper, max_n_steps, *, size=None):
106
106
]
107
107
graph_inputs = [* rv_inputs , lower , upper ]
108
108
109
- rv = dist .owner .op .make_node (* rv_inputs ).default_output ()
109
+ # Variables with `_` suffix identify dummy inputs for the OpFromGraph
110
+ graph_inputs_ = [
111
+ inp .type () if not isinstance (inp .type , RandomType ) else inp for inp in graph_inputs
112
+ ]
113
+ * rv_inputs_ , lower_ , upper_ = graph_inputs_
114
+
115
+ rv_ = dist .owner .op .make_node (* rv_inputs_ ).default_output ()
110
116
111
117
# Try to use inverted cdf sampling
112
118
# truncated_rv = icdf(rv, draw(uniform(cdf(lower), cdf(upper))))
113
119
try :
114
- logcdf_lower , logcdf_upper = cls ._create_logcdf_exprs (rv , rv , lower , upper )
120
+ logcdf_lower_ , logcdf_upper_ = TruncatedRV ._create_logcdf_exprs (
121
+ rv_ , rv_ , lower_ , upper_
122
+ )
115
123
# We use the first RNG from the base RV, so we don't have to introduce a new one
116
124
# This is not problematic because the RNG won't be used in the RV logcdf graph
117
- uniform_rng = next (inp for inp in rv_inputs if isinstance (inp .type , RandomType ))
118
- uniform_next_rng , uniform = pt .random .uniform (
119
- pt .exp (logcdf_lower ),
120
- pt .exp (logcdf_upper ),
121
- rng = uniform_rng ,
122
- size = rv .shape ,
125
+ uniform_rng_ = next (inp_ for inp_ in rv_inputs_ if isinstance (inp_ .type , RandomType ))
126
+ uniform_next_rng_ , uniform_ = pt .random .uniform (
127
+ pt .exp (logcdf_lower_ ),
128
+ pt .exp (logcdf_upper_ ),
129
+ rng = uniform_rng_ ,
130
+ size = rv_ .shape ,
123
131
).owner .outputs
124
- # So icdf does not see the random graph of uniform
125
- uniform_type = uniform .type ()
126
- truncated_rv = graph_replace (icdf (rv , uniform_type ), {uniform_type : uniform })
132
+ truncated_rv_ = icdf (rv_ , uniform_ , warn_rvs = False )
127
133
return TruncatedRV (
128
134
base_rv_op = dist .owner .op ,
129
- inputs = graph_inputs ,
130
- outputs = [truncated_rv , uniform_next_rng ],
135
+ inputs = graph_inputs_ ,
136
+ outputs = [truncated_rv_ , uniform_next_rng_ ],
131
137
ndim_supp = 0 ,
132
138
max_n_steps = max_n_steps ,
133
139
)(* graph_inputs )
@@ -154,25 +160,25 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs):
154
160
155
161
return (
156
162
(truncated_rv , reject_draws ),
157
- collect_default_updates (new_truncated_rv , inputs = rv_inputs ),
163
+ collect_default_updates (new_truncated_rv ),
158
164
until (~ pt .any (reject_draws )),
159
165
)
160
166
161
- (truncated_rv , reject_draws_ ), updates = scan (
167
+ (truncated_rv_ , reject_draws_ ), updates = scan (
162
168
loop_fn ,
163
169
outputs_info = [
164
- pt .zeros_like (rv ),
165
- pt .ones_like (rv , dtype = bool ),
170
+ pt .zeros_like (rv_ ),
171
+ pt .ones_like (rv_ , dtype = bool ),
166
172
],
167
- non_sequences = [lower , upper , * rv_inputs ],
173
+ non_sequences = [lower_ , upper_ , * rv_inputs_ ],
168
174
n_steps = max_n_steps ,
169
175
strict = True ,
170
176
)
171
177
172
- truncated_rv = truncated_rv [- 1 ]
173
- convergence = ~ pt .any (reject_draws_ [- 1 ])
174
- truncated_rv = TruncationCheck (f"Truncation did not converge in { max_n_steps } steps" )(
175
- truncated_rv , convergence
178
+ truncated_rv_ = truncated_rv_ [- 1 ]
179
+ convergence_ = ~ pt .any (reject_draws_ [- 1 ])
180
+ truncated_rv_ = TruncationCheck (f"Truncation did not converge in { max_n_steps } steps" )(
181
+ truncated_rv_ , convergence_
176
182
)
177
183
178
184
# Sort updates of each RNG so that they show in the same order as the input RNGs
@@ -184,8 +190,8 @@ def sort_updates(update):
184
190
185
191
return TruncatedRV (
186
192
base_rv_op = dist .owner .op ,
187
- inputs = graph_inputs ,
188
- outputs = [truncated_rv , * next_rngs ],
193
+ inputs = graph_inputs_ ,
194
+ outputs = [truncated_rv_ , * next_rngs ],
189
195
ndim_supp = 0 ,
190
196
max_n_steps = max_n_steps ,
191
197
)(* graph_inputs )
0 commit comments