@@ -281,43 +281,43 @@ class ZeroSumTransform(Transform):
281
281
def __init__ (self , zerosum_axes ):
282
282
self .zerosum_axes = tuple (int (axis ) for axis in zerosum_axes )
283
283
284
+ @staticmethod
285
+ def extend_axis (array , axis ):
286
+ n = pm .floatX (array .shape [axis ] + 1 )
287
+ sum_vals = array .sum (axis , keepdims = True )
288
+ norm = sum_vals / (pt .sqrt (n ) + n )
289
+ fill_val = norm - sum_vals / pt .sqrt (n )
290
+
291
+ out = pt .concatenate ([array , fill_val ], axis = axis )
292
+ return out - norm
293
+
294
+ @staticmethod
295
+ def extend_axis_rev (array , axis ):
296
+ normalized_axis = normalize_axis_tuple (axis , array .ndim )[0 ]
297
+
298
+ n = pm .floatX (array .shape [normalized_axis ])
299
+ last = pt .take (array , [- 1 ], axis = normalized_axis )
300
+
301
+ sum_vals = - last * pt .sqrt (n )
302
+ norm = sum_vals / (pt .sqrt (n ) + n )
303
+ slice_before = (slice (None , None ),) * normalized_axis
304
+
305
+ return array [slice_before + (slice (None , - 1 ),)] + norm
306
+
284
307
def forward (self , value , * rv_inputs ):
285
308
for axis in self .zerosum_axes :
286
- value = extend_axis_rev (value , axis = axis )
309
+ value = self . extend_axis_rev (value , axis = axis )
287
310
return value
288
311
289
312
def backward (self , value , * rv_inputs ):
290
313
for axis in self .zerosum_axes :
291
- value = extend_axis (value , axis = axis )
314
+ value = self . extend_axis (value , axis = axis )
292
315
return value
293
316
294
317
def log_jac_det (self , value , * rv_inputs ):
295
318
return pt .constant (0.0 )
296
319
297
320
298
- def extend_axis (array , axis ):
299
- n = pm .floatX (array .shape [axis ] + 1 )
300
- sum_vals = array .sum (axis , keepdims = True )
301
- norm = sum_vals / (pt .sqrt (n ) + n )
302
- fill_val = norm - sum_vals / pt .sqrt (n )
303
-
304
- out = pt .concatenate ([array , fill_val ], axis = axis )
305
- return out - norm
306
-
307
-
308
- def extend_axis_rev (array , axis ):
309
- normalized_axis = normalize_axis_tuple (axis , array .ndim )[0 ]
310
-
311
- n = pm .floatX (array .shape [normalized_axis ])
312
- last = pt .take (array , [- 1 ], axis = normalized_axis )
313
-
314
- sum_vals = - last * pt .sqrt (n )
315
- norm = sum_vals / (pt .sqrt (n ) + n )
316
- slice_before = (slice (None , None ),) * normalized_axis
317
-
318
- return array [slice_before + (slice (None , - 1 ),)] + norm
319
-
320
-
321
321
log_exp_m1 = LogExpM1 ()
322
322
log_exp_m1 .__doc__ = """
323
323
Instantiation of :class:`pymc.distributions.transforms.LogExpM1`
0 commit comments