 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import singledispatch
+from itertools import chain
 from typing import Generator, List, Optional, Tuple, Union

 import aesara.tensor as aet
@@ ... @@
 ]


+@singledispatch
+def logp_transform(op, inputs):
+    return None
+
+
 def _get_scaling(total_size, shape, ndim):
     """
     Gets scaling constant for logp
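The new `logp_transform` stub above is a `functools.singledispatch` function, so per-`Op` default transforms can be registered from elsewhere in the codebase. A minimal sketch of how such a registration works; `PositiveRV` and `LogTransform` are placeholder names I made up, not names from this diff:

```python
from functools import singledispatch


@singledispatch
def logp_transform(op, inputs):
    # Fallback used when no transform is registered for type(op)
    return None


class PositiveRV:  # stand-in for a `RandomVariable` subclass
    pass


class LogTransform:  # stand-in for a pymc3 transform object
    pass


@logp_transform.register(PositiveRV)
def _logp_transform_positive(op, inputs):
    # A positive-valued RV would default to a log transform
    return LogTransform()


assert logp_transform(PositiveRV(), []) is not None  # dispatches on op type
assert logp_transform(object(), []) is None          # falls back to the default
```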
@@ -135,7 +141,6 @@ def change_rv_size(

 def rv_log_likelihood_args(
     rv_var: TensorVariable,
-    rv_value: Optional[TensorVariable] = None,
     transformed: Optional[bool] = True,
 ) -> Tuple[TensorVariable, TensorVariable]:
     """Get a `RandomVariable` and its corresponding log-likelihood `TensorVariable` value.
@@ -146,38 +151,24 @@ def rv_log_likelihood_args(
         A variable corresponding to a `RandomVariable`, whether directly or
         indirectly (e.g. an observed variable that's the output of an
         `Observed` `Op`).
-    rv_value
-        The measure-space input `TensorVariable` (i.e. "input" to a
-        log-likelihood).
     transformed
         When ``True``, return the transformed value var.

     Returns
     =======
     The first value in the tuple is the `RandomVariable`, and the second is the
-    measure-space variable that corresponds with the latter. The first is used
-    to determine the log likelihood graph and the second is the "input"
-    parameter to that graph. In the case of an observed `RandomVariable`, the
-    "input" is actual data; in all other cases, it's just another
-    `TensorVariable`.
+    measure-space variable that corresponds with the latter (i.e. the "value"
+    variable).

     """

-    if rv_value is None:
-        if rv_var.owner and isinstance(rv_var.owner.op, Observed):
-            rv_var, rv_value = rv_var.owner.inputs
-        elif hasattr(rv_var.tag, "value_var"):
-            rv_value = rv_var.tag.value_var
-        else:
-            return rv_var, None
-
-    rv_value = aet.as_tensor_variable(rv_value)
-
-    transform = getattr(rv_value.tag, "transform", None)
-    if transformed and transform:
-        rv_value = transform.forward(rv_value)
-
-    return rv_var, rv_value
+    if rv_var.owner and isinstance(rv_var.owner.op, Observed):
+        return tuple(rv_var.owner.inputs)
+    elif hasattr(rv_var.tag, "value_var"):
+        rv_value = rv_var.tag.value_var
+        return rv_var, rv_value
+    else:
+        return rv_var, None


 def rv_ancestors(graphs: List[TensorVariable]) -> Generator[TensorVariable, None, None]:
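For orientation, the `value_var` lookup that the rewritten branches rely on uses Aesara's free-form `tag` scratch attribute. A sketch of that convention, assuming `aet.random.normal` builds a `RandomVariable` graph the way it is used elsewhere in this codebase (variable names here are illustrative):

```python
import aesara.tensor as aet

# A sample-space variable and its measure-space ("value") counterpart
rv_var = aet.random.normal(0, 1)
value_var = aet.scalar("x_value")

# The attribute `rv_log_likelihood_args` inspects
rv_var.tag.value_var = value_var

# With the tag set, the `hasattr` branch yields (rv_var, value_var);
# with neither an `Observed` owner nor a tag, the result is (rv_var, None).
```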
@@ -197,22 +188,53 @@ def strip_observed(x: TensorVariable) -> TensorVariable:
     return x


-def sample_to_measure_vars(graphs: List[TensorVariable]) -> List[TensorVariable]:
-    """Replace `RandomVariable` terms in graphs with their measure-space counterparts."""
+def sample_to_measure_vars(
+    graphs: List[TensorVariable],
+) -> Tuple[List[TensorVariable], List[TensorVariable]]:
+    """Replace sample-space variables in graphs with their measure-space counterparts.
+
+    Sample-space variables are `TensorVariable` outputs of `RandomVariable`
+    `Op`s.  Measure-space variables are `TensorVariable`s that correspond to
+    the value of a sample-space variable in a likelihood function (e.g. ``x``
+    in ``p(X = x)``, where ``X`` is the corresponding sample-space variable).
+    (``x`` is also the variable found in ``rv_var.tag.value_var``, so this
+    function could also be called ``sample_to_value_vars``.)
+
+    Parameters
+    ==========
+    graphs
+        The graphs in which random variables are to be replaced by their
+        measure variables.
+
+    Returns
+    =======
+    Tuple containing the transformed graphs and a ``dict`` of the replacements
+    that were made.
+    """
     replace = {}
-    for anc in rv_ancestors(graphs):
-        measure_var = getattr(anc.tag, "value_var", None)
-        if measure_var is not None:
-            replace[anc] = measure_var
+    for anc in chain(rv_ancestors(graphs), graphs):
+
+        if not (anc.owner and isinstance(anc.owner.op, RandomVariable)):
+            continue
+
+        _, value_var = rv_log_likelihood_args(anc)
+
+        if value_var is not None:
+            replace[anc] = value_var
+
+    if replace:
+        measure_graphs = clone_replace(graphs, replace=replace)
+    else:
+        measure_graphs = graphs

-    dist_params = clone_replace(graphs, replace=replace)
-    return dist_params
+    return measure_graphs, replace


 def logpt(
     rv_var: TensorVariable,
     rv_value: Optional[TensorVariable] = None,
-    jacobian: bool = True,
+    jacobian: Optional[bool] = True,
+    transformed: Optional[bool] = True,
     scaling: Optional[bool] = True,
     **kwargs,
 ) -> TensorVariable:
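The replacement machinery in `sample_to_measure_vars` boils down to `clone_replace`, which rebuilds a graph with selected variables swapped out. A toy illustration, assuming the top-level `aesara.clone_replace` behaves like the `clone_replace` this module imports:

```python
import aesara
import aesara.tensor as aet

x = aet.scalar("x")              # stands in for a RandomVariable output
x_value = aet.scalar("x_value")  # its measure-space counterpart
y = 2 * x + 1                    # a graph that references `x`

# The same swap `sample_to_measure_vars` performs for every RV
# ancestor that carries a value variable
(y_measure,) = aesara.clone_replace([y], replace={x: x_value})
```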
@@ -228,29 +250,40 @@ def logpt(
     rv_var
         The `RandomVariable` output that determines the log-likelihood graph.
     rv_value
-        The input variable for the log-likelihood graph.
+        The input variable for the log-likelihood graph.  If `rv_value` is
+        a transformed variable, its transformations will be applied.
+        If no value is provided, `rv_var.tag.value_var` will be checked and,
+        when available, used.
     jacobian
         Whether or not to include the Jacobian term.
+    transformed
+        Return the transformed version of the log-likelihood graph.
     scaling
         A scaling term to apply to the generated log-likelihood graph.

     """

-    rv_var, rv_value = rv_log_likelihood_args(rv_var, rv_value)
+    rv_var, rv_value_var = rv_log_likelihood_args(rv_var)
+
+    if rv_value is None:
+        rv_value = rv_value_var
+    else:
+        rv_value = aet.as_tensor(rv_value)
+
     rv_node = rv_var.owner

     if not rv_node:
         raise TypeError("rv_var must be the output of a RandomVariable Op")

     if not isinstance(rv_node.op, RandomVariable):

+        # This will probably need another generic function...
         if isinstance(rv_node.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1)):

             raise NotImplementedError("Missing value support is incomplete")

             # "Flatten" and sum an array of indexed RVs' log-likelihoods
             rv_var, missing_values = rv_node.inputs
-            rv_value = rv_var.tag.value_var

             missing_values = missing_values.data
             logp_var = aet.sum(
@@ -268,28 +301,40 @@

         return aet.zeros_like(rv_var)

+    if rv_value_var is None:
+        raise NotImplementedError(f"The log-likelihood for {rv_var} is undefined")
+
+    # This case should be reached when `rv_var` is either the result of an
+    # `Observed` or a `RandomVariable` `Op`
     rng, size, dtype, *dist_params = rv_node.inputs

-    dist_params = sample_to_measure_vars(dist_params)
+    dist_params, replacements = sample_to_measure_vars(dist_params)

-    if jacobian:
-        logp_var = _logp(rv_node.op, rv_value, *dist_params, **kwargs)
-    else:
-        logp_var = _logp_nojac(rv_node.op, rv_value, *dist_params, **kwargs)
+    logp_var = _logp(rv_node.op, rv_value_var, *dist_params, **kwargs)

-    # Replace `RandomVariable` ancestors with their corresponding
-    # log-likelihood input variables
-    lik_replacements = [
-        (v, v.tag.value_var)
-        for v in ancestors([logp_var])
-        if v.owner and isinstance(v.owner.op, RandomVariable) and getattr(v.tag, "value_var", None)
-    ]
+    # If any of the measure vars are transformed measure-space variables
+    # (signified by having a `transform` value in their tags), then we apply
+    # their transforms and add their Jacobians (when enabled)
+    if transformed:
+        logp_var = transform_logp(
+            logp_var,
+            tuple(replacements.values()) + (rv_value_var,),
+        )
+
+        transform = getattr(rv_value_var.tag, "transform", None)
+
+        if transform and jacobian:
+            transformed_jacobian = transform.jacobian_det(rv_value_var)
+            if transformed_jacobian:
+                if logp_var.ndim > transformed_jacobian.ndim:
+                    logp_var = logp_var.sum(axis=-1)
+                logp_var += transformed_jacobian

-    (logp_var,) = clone_replace([logp_var], replace=lik_replacements)
+    (logp_var,) = clone_replace([logp_var], replace={rv_value_var: rv_value})

     if scaling:
         logp_var *= _get_scaling(
-            getattr(rv_var.tag, "total_size", None), rv_value.shape, rv_value.ndim
+            getattr(rv_var.tag, "total_size", None), rv_value_var.shape, rv_value_var.ndim
         )

     if rv_var.name is not None:
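The Jacobian bookkeeping in the hunk above is the usual change-of-variables correction: when sampling happens in an unconstrained space ``y = forward(x)``, the density must gain ``log|d backward(y)/dy|``. A plain NumPy sketch for an Exponential variable under a log transform (the distribution choice and the names are mine, for intuition only):

```python
import numpy as np


def exponential_logp_unconstrained(y, lam=1.0):
    """logp of Exponential(lam) evaluated in the unconstrained space y = log(x)."""
    x = np.exp(y)                  # backward transform, cf. `transform.backward`
    logp_x = np.log(lam) - lam * x
    log_jacobian = y               # log|d exp(y)/dy| = y, cf. `jacobian_det`
    return logp_x + log_jacobian
```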
@@ -298,6 +343,25 @@
     return logp_var


+def transform_logp(logp_var: TensorVariable, inputs: List[TensorVariable]) -> TensorVariable:
+    """Transform the inputs of a log-likelihood graph."""
+    trans_replacements = {}
+    for measure_var in inputs:
+
+        transform = getattr(measure_var.tag, "transform", None)
+
+        if transform is None:
+            continue
+
+        trans_rv_value = transform.backward(measure_var)
+        trans_replacements[measure_var] = trans_rv_value
+
+    if trans_replacements:
+        (logp_var,) = clone_replace([logp_var], trans_replacements)
+
+    return logp_var
+
+
 @singledispatch
 def _logp(op, value, *dist_params, **kwargs):
     """Create a log-likelihood graph.
@@ -310,20 +374,27 @@ def _logp(op, value, *dist_params, **kwargs):
     return aet.zeros_like(value)


-def logcdf(rv_var, rv_value, **kwargs):
+def logcdf(rv_var, rv_value, transformed=True, jacobian=True, **kwargs):
     """Create a log-CDF graph."""

-    rv_var, rv_value = rv_log_likelihood_args(rv_var, rv_value)
+    rv_var, rv_value = rv_log_likelihood_args(rv_var)
     rv_node = rv_var.owner

     if not rv_node:
         raise TypeError()

     rng, size, dtype, *dist_params = rv_node.inputs

-    dist_params = sample_to_measure_vars(dist_params)
+    dist_params, replacements = sample_to_measure_vars(dist_params)
+
+    logp_var = _logcdf(rv_node.op, rv_value, *dist_params, **kwargs)

-    return _logcdf(rv_node.op, rv_value, *dist_params, **kwargs)
+    if transformed:
+        logp_var = transform_logp(
+            logp_var, tuple(replacements.values()) + (rv_value,)
+        )
+
+    return logp_var


 @singledispatch
@@ -338,38 +409,6 @@ def _logcdf(op, value, *args, **kwargs):
     raise NotImplementedError()


-def logp_nojac(rv_var, rv_value=None, **kwargs):
-    """Create a graph of the log-likelihood that doesn't include the Jacobian."""
-
-    rv_var, rv_value = rv_log_likelihood_args(rv_var, rv_value)
-    rv_node = rv_var.owner
-
-    if not rv_node:
-        raise TypeError()
-
-    rng, size, dtype, *dist_params = rv_node.inputs
-
-    dist_params = sample_to_measure_vars(dist_params)
-
-    return _logp_nojac(rv_node.op, rv_value, **kwargs)
-
-
-@singledispatch
-def _logp_nojac(op, value, *args, **kwargs):
-    """Return the logp, but do not include a jacobian term for transforms.
-
-    If we use different parametrizations for the same distribution, we
-    need to add the determinant of the jacobian of the transformation
-    to make sure the densities still describe the same distribution.
-    However, MAP estimates are not invariant with respect to the
-    parameterization, we need to exclude the jacobian terms in this case.
-
-    This function should be overwritten in base classes for transformed
-    distributions.
-    """
-    return logpt(op, value, *args, **kwargs)
-
-
 def logpt_sum(rv_var: TensorVariable, rv_value: Optional[TensorVariable] = None, **kwargs):
     """Return the sum of the logp values for the given observations.