Skip to content

Commit c5115ee

Browse files
committed
Move ZeroSumTransform methods inside respective class
1 parent 270e840 commit c5115ee

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

pymc/distributions/transforms.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -281,43 +281,43 @@ class ZeroSumTransform(Transform):
281281
def __init__(self, zerosum_axes):
282282
self.zerosum_axes = tuple(int(axis) for axis in zerosum_axes)
283283

284+
@staticmethod
285+
def extend_axis(array, axis):
286+
n = pm.floatX(array.shape[axis] + 1)
287+
sum_vals = array.sum(axis, keepdims=True)
288+
norm = sum_vals / (pt.sqrt(n) + n)
289+
fill_val = norm - sum_vals / pt.sqrt(n)
290+
291+
out = pt.concatenate([array, fill_val], axis=axis)
292+
return out - norm
293+
294+
@staticmethod
295+
def extend_axis_rev(array, axis):
296+
normalized_axis = normalize_axis_tuple(axis, array.ndim)[0]
297+
298+
n = pm.floatX(array.shape[normalized_axis])
299+
last = pt.take(array, [-1], axis=normalized_axis)
300+
301+
sum_vals = -last * pt.sqrt(n)
302+
norm = sum_vals / (pt.sqrt(n) + n)
303+
slice_before = (slice(None, None),) * normalized_axis
304+
305+
return array[slice_before + (slice(None, -1),)] + norm
306+
284307
def forward(self, value, *rv_inputs):
285308
for axis in self.zerosum_axes:
286-
value = extend_axis_rev(value, axis=axis)
309+
value = self.extend_axis_rev(value, axis=axis)
287310
return value
288311

289312
def backward(self, value, *rv_inputs):
290313
for axis in self.zerosum_axes:
291-
value = extend_axis(value, axis=axis)
314+
value = self.extend_axis(value, axis=axis)
292315
return value
293316

294317
def log_jac_det(self, value, *rv_inputs):
295318
return pt.constant(0.0)
296319

297320

298-
def extend_axis(array, axis):
299-
n = pm.floatX(array.shape[axis] + 1)
300-
sum_vals = array.sum(axis, keepdims=True)
301-
norm = sum_vals / (pt.sqrt(n) + n)
302-
fill_val = norm - sum_vals / pt.sqrt(n)
303-
304-
out = pt.concatenate([array, fill_val], axis=axis)
305-
return out - norm
306-
307-
308-
def extend_axis_rev(array, axis):
309-
normalized_axis = normalize_axis_tuple(axis, array.ndim)[0]
310-
311-
n = pm.floatX(array.shape[normalized_axis])
312-
last = pt.take(array, [-1], axis=normalized_axis)
313-
314-
sum_vals = -last * pt.sqrt(n)
315-
norm = sum_vals / (pt.sqrt(n) + n)
316-
slice_before = (slice(None, None),) * normalized_axis
317-
318-
return array[slice_before + (slice(None, -1),)] + norm
319-
320-
321321
log_exp_m1 = LogExpM1()
322322
log_exp_m1.__doc__ = """
323323
Instantiation of :class:`pymc.distributions.transforms.LogExpM1`

0 commit comments

Comments
 (0)